Browse Source

use channels instead of mutexes

pull/31/head
aler9 6 years ago
parent
commit
4df4bbba6a
  1. 350
      main.go
  2. 348
      server-client.go
  3. 71
      server-tcpl.go
  4. 60
      server-udpl.go

350
main.go

@ -9,7 +9,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/aler9/gortsplib"
"gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
"gortc.io/sdp"
) )
var Version = "v0.0.0" var Version = "v0.0.0"
@ -38,10 +40,10 @@ func parseIpCidrList(in string) ([]interface{}, error) {
return ret, nil return ret, nil
} }
type trackFlow int type trackFlowType int
const ( const (
_TRACK_FLOW_RTP trackFlow = iota _TRACK_FLOW_RTP trackFlowType = iota
_TRACK_FLOW_RTCP _TRACK_FLOW_RTCP
) )
@ -64,6 +66,110 @@ func (s streamProtocol) String() string {
return "tcp" return "tcp"
} }
type programEvent interface {
isProgramEvent()
}
type programEventClientNew struct {
nconn net.Conn
}
func (programEventClientNew) isProgramEvent() {}
type programEventClientClose struct {
done chan struct{}
client *serverClient
}
func (programEventClientClose) isProgramEvent() {}
type programEventClientGetStreamSdp struct {
path string
res chan []byte
}
func (programEventClientGetStreamSdp) isProgramEvent() {}
type programEventClientAnnounce struct {
res chan error
client *serverClient
path string
sdpText []byte
sdpParsed *sdp.Message
}
func (programEventClientAnnounce) isProgramEvent() {}
type programEventClientSetupPlay struct {
res chan error
client *serverClient
path string
protocol streamProtocol
rtpPort int
rtcpPort int
}
func (programEventClientSetupPlay) isProgramEvent() {}
type programEventClientSetupRecord struct {
res chan error
client *serverClient
protocol streamProtocol
rtpPort int
rtcpPort int
}
func (programEventClientSetupRecord) isProgramEvent() {}
type programEventClientPlay1 struct {
res chan error
client *serverClient
}
func (programEventClientPlay1) isProgramEvent() {}
type programEventClientPlay2 struct {
res chan error
client *serverClient
}
func (programEventClientPlay2) isProgramEvent() {}
type programEventClientPause struct {
res chan error
client *serverClient
}
func (programEventClientPause) isProgramEvent() {}
type programEventClientRecord struct {
res chan error
client *serverClient
}
func (programEventClientRecord) isProgramEvent() {}
type programEventFrameUdp struct {
trackFlowType trackFlowType
addr *net.UDPAddr
buf []byte
}
func (programEventFrameUdp) isProgramEvent() {}
type programEventFrameTcp struct {
path string
trackId int
trackFlowType trackFlowType
buf []byte
}
func (programEventFrameTcp) isProgramEvent() {}
type programEventTerminate struct{}
func (programEventTerminate) isProgramEvent() {}
type args struct { type args struct {
version bool version bool
protocolsStr string protocolsStr string
@ -90,6 +196,11 @@ type program struct {
tcpl *serverTcpListener tcpl *serverTcpListener
udplRtp *serverUdpListener udplRtp *serverUdpListener
udplRtcp *serverUdpListener udplRtcp *serverUdpListener
clients map[*serverClient]struct{}
publishers map[string]*serverClient
events chan programEvent
done chan struct{}
} }
func newProgram(sargs []string) (*program, error) { func newProgram(sargs []string) (*program, error) {
@ -204,6 +315,10 @@ func newProgram(sargs []string) (*program, error) {
protocols: protocols, protocols: protocols,
publishIps: publishIps, publishIps: publishIps,
readIps: readIps, readIps: readIps,
clients: make(map[*serverClient]struct{}),
publishers: make(map[string]*serverClient),
events: make(chan programEvent),
done: make(chan struct{}),
} }
p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP) p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
@ -224,14 +339,243 @@ func newProgram(sargs []string) (*program, error) {
go p.udplRtp.run() go p.udplRtp.run()
go p.udplRtcp.run() go p.udplRtcp.run()
go p.tcpl.run() go p.tcpl.run()
go p.run()
return p, nil return p, nil
} }
func (p *program) close() { func (p *program) run() {
outer:
for rawEvt := range p.events {
switch evt := rawEvt.(type) {
case programEventClientNew:
c := newServerClient(p, evt.nconn)
p.clients[c] = struct{}{}
case programEventClientClose:
// already deleted
if _, ok := p.clients[evt.client]; !ok {
close(evt.done)
continue
}
delete(p.clients, evt.client)
if evt.client.path != "" {
if pub, ok := p.publishers[evt.client.path]; ok && pub == evt.client {
delete(p.publishers, evt.client.path)
// if the publisher has disconnected
// close all other connections that share the same path
for oc := range p.clients {
if oc.path == evt.client.path {
go oc.close()
}
}
}
}
close(evt.done)
case programEventClientGetStreamSdp:
pub, ok := p.publishers[evt.path]
if !ok {
evt.res <- nil
continue
}
evt.res <- pub.streamSdpText
case programEventClientAnnounce:
_, ok := p.publishers[evt.path]
if ok {
evt.res <- fmt.Errorf("another client is already publishing on path '%s'", evt.path)
continue
}
evt.client.path = evt.path
evt.client.streamSdpText = evt.sdpText
evt.client.streamSdpParsed = evt.sdpParsed
evt.client.state = _CLIENT_STATE_ANNOUNCE
p.publishers[evt.path] = evt.client
evt.res <- nil
case programEventClientSetupPlay:
pub, ok := p.publishers[evt.path]
if !ok {
evt.res <- fmt.Errorf("no one is streaming on path '%s'", evt.path)
continue
}
if len(evt.client.streamTracks) >= len(pub.streamSdpParsed.Medias) {
evt.res <- fmt.Errorf("all the tracks have already been setup")
continue
}
evt.client.path = evt.path
evt.client.streamProtocol = evt.protocol
evt.client.streamTracks = append(evt.client.streamTracks, &track{
rtpPort: evt.rtpPort,
rtcpPort: evt.rtcpPort,
})
evt.client.state = _CLIENT_STATE_PRE_PLAY
evt.res <- nil
case programEventClientSetupRecord:
evt.client.streamProtocol = evt.protocol
evt.client.streamTracks = append(evt.client.streamTracks, &track{
rtpPort: evt.rtpPort,
rtcpPort: evt.rtcpPort,
})
evt.client.state = _CLIENT_STATE_PRE_RECORD
evt.res <- nil
case programEventClientPlay1:
pub, ok := p.publishers[evt.client.path]
if !ok {
evt.res <- fmt.Errorf("no one is streaming on path '%s'", evt.client.path)
continue
}
if len(evt.client.streamTracks) != len(pub.streamSdpParsed.Medias) {
evt.res <- fmt.Errorf("not all tracks have been setup")
continue
}
evt.res <- nil
case programEventClientPlay2:
evt.client.state = _CLIENT_STATE_PLAY
evt.res <- nil
case programEventClientPause:
evt.client.state = _CLIENT_STATE_PRE_PLAY
evt.res <- nil
case programEventClientRecord:
evt.client.state = _CLIENT_STATE_RECORD
evt.res <- nil
case programEventFrameUdp:
// find publisher and track id from ip and port
pub, trackId := func() (*serverClient, int) {
for _, pub := range p.publishers {
if pub.streamProtocol != _STREAM_PROTOCOL_UDP ||
pub.state != _CLIENT_STATE_RECORD ||
!pub.ip().Equal(evt.addr.IP) {
continue
}
for i, t := range pub.streamTracks {
if evt.trackFlowType == _TRACK_FLOW_RTP {
if t.rtpPort == evt.addr.Port {
return pub, i
}
} else {
if t.rtcpPort == evt.addr.Port {
return pub, i
}
}
}
}
return nil, -1
}()
if pub == nil {
continue
}
pub.udpLastFrameTime = time.Now()
p.forwardTrack(pub.path, trackId, evt.trackFlowType, evt.buf)
case programEventFrameTcp:
p.forwardTrack(evt.path, evt.trackId, evt.trackFlowType, evt.buf)
case programEventTerminate:
break outer
}
}
go func() {
for rawEvt := range p.events {
switch evt := rawEvt.(type) {
case programEventClientClose:
close(evt.done)
case programEventClientGetStreamSdp:
evt.res <- nil
case programEventClientAnnounce:
evt.res <- fmt.Errorf("terminated")
case programEventClientSetupPlay:
evt.res <- fmt.Errorf("terminated")
case programEventClientSetupRecord:
evt.res <- fmt.Errorf("terminated")
case programEventClientPlay1:
evt.res <- fmt.Errorf("terminated")
case programEventClientPlay2:
evt.res <- fmt.Errorf("terminated")
case programEventClientPause:
evt.res <- fmt.Errorf("terminated")
case programEventClientRecord:
evt.res <- fmt.Errorf("terminated")
}
}
}()
p.tcpl.close() p.tcpl.close()
p.udplRtcp.close() p.udplRtcp.close()
p.udplRtp.close() p.udplRtp.close()
for c := range p.clients {
c.close()
}
close(p.events)
close(p.done)
}
func (p *program) close() {
p.events <- programEventTerminate{}
<-p.done
}
func (p *program) forwardTrack(path string, id int, trackFlowType trackFlowType, frame []byte) {
for c := range p.clients {
if c.path == path && c.state == _CLIENT_STATE_PLAY {
if c.streamProtocol == _STREAM_PROTOCOL_UDP {
if trackFlowType == _TRACK_FLOW_RTP {
p.udplRtp.write <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip(),
Zone: c.zone(),
Port: c.streamTracks[id].rtpPort,
},
buf: frame,
}
} else {
p.udplRtcp.write <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip(),
Zone: c.zone(),
Port: c.streamTracks[id].rtcpPort,
},
buf: frame,
}
}
} else {
c.write <- &gortsplib.InterleavedFrame{
Channel: trackToInterleavedChannel(id, trackFlowType),
Content: frame,
}
}
}
}
} }
func main() { func main() {

348
server-client.go

@ -19,15 +19,15 @@ const (
_UDP_STREAM_DEAD_AFTER = 10 * time.Second _UDP_STREAM_DEAD_AFTER = 10 * time.Second
) )
func interleavedChannelToTrack(channel uint8) (int, trackFlow) { func interleavedChannelToTrack(channel uint8) (int, trackFlowType) {
if (channel % 2) == 0 { if (channel % 2) == 0 {
return int(channel / 2), _TRACK_FLOW_RTP return int(channel / 2), _TRACK_FLOW_RTP
} }
return int((channel - 1) / 2), _TRACK_FLOW_RTCP return int((channel - 1) / 2), _TRACK_FLOW_RTCP
} }
func trackToInterleavedChannel(id int, flow trackFlow) uint8 { func trackToInterleavedChannel(id int, trackFlowType trackFlowType) uint8 {
if flow == _TRACK_FLOW_RTP { if trackFlowType == _TRACK_FLOW_RTP {
return uint8(id * 2) return uint8(id * 2)
} }
return uint8((id * 2) + 1) return uint8((id * 2) + 1)
@ -80,8 +80,9 @@ type serverClient struct {
streamTracks []*track streamTracks []*track
udpLastFrameTime time.Time udpLastFrameTime time.Time
udpCheckStreamTicker *time.Ticker udpCheckStreamTicker *time.Ticker
write chan *gortsplib.InterleavedFrame
done chan struct{} write chan *gortsplib.InterleavedFrame
done chan struct{}
} }
func newServerClient(p *program, nconn net.Conn) *serverClient { func newServerClient(p *program, nconn net.Conn) *serverClient {
@ -97,39 +98,13 @@ func newServerClient(p *program, nconn net.Conn) *serverClient {
done: make(chan struct{}), done: make(chan struct{}),
} }
c.p.tcpl.mutex.Lock()
c.p.tcpl.clients[c] = struct{}{}
c.p.tcpl.mutex.Unlock()
go c.run() go c.run()
return c return c
} }
func (c *serverClient) close() error { func (c *serverClient) close() {
// already deleted
if _, ok := c.p.tcpl.clients[c]; !ok {
return nil
}
delete(c.p.tcpl.clients, c)
c.conn.NetConn().Close() c.conn.NetConn().Close()
close(c.write) <-c.done
if c.path != "" {
if pub, ok := c.p.tcpl.publishers[c.path]; ok && pub == c {
delete(c.p.tcpl.publishers, c.path)
// if the publisher has disconnected
// close all other connections that share the same path
for oc := range c.p.tcpl.clients {
if oc.path == c.path {
oc.close()
}
}
}
}
return nil
} }
func (c *serverClient) log(format string, args ...interface{}) { func (c *serverClient) log(format string, args ...interface{}) {
@ -172,18 +147,12 @@ func (c *serverClient) run() {
} }
} }
func() { c.log("disconnected")
c.p.tcpl.mutex.Lock()
defer c.p.tcpl.mutex.Unlock()
c.close()
}()
if c.udpCheckStreamTicker != nil { if c.udpCheckStreamTicker != nil {
c.udpCheckStreamTicker.Stop() c.udpCheckStreamTicker.Stop()
} }
c.log("disconnected")
func() { func() {
if c.p.args.postScript != "" { if c.p.args.postScript != "" {
postScript := exec.Command(c.p.args.postScript) postScript := exec.Command(c.p.args.postScript)
@ -194,6 +163,12 @@ func (c *serverClient) run() {
} }
}() }()
done := make(chan struct{})
c.p.events <- programEventClientClose{done, c}
<-done
close(c.write)
close(c.done) close(c.done)
} }
@ -202,7 +177,7 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
header := gortsplib.Header{} header := gortsplib.Header{}
if cseq, ok := req.Header["CSeq"]; ok && len(cseq) == 1 { if cseq, ok := req.Header["CSeq"]; ok && len(cseq) == 1 {
header["CSeq"] = []string{cseq[0]} header["CSeq"] = cseq
} }
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
@ -317,7 +292,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Public": []string{strings.Join([]string{ "Public": []string{strings.Join([]string{
string(gortsplib.DESCRIBE), string(gortsplib.DESCRIBE),
string(gortsplib.ANNOUNCE), string(gortsplib.ANNOUNCE),
@ -346,26 +321,18 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return true return true
} }
sdp, err := func() ([]byte, error) { res := make(chan []byte)
c.p.tcpl.mutex.RLock() c.p.events <- programEventClientGetStreamSdp{path, res}
defer c.p.tcpl.mutex.RUnlock() sdp := <-res
if sdp == nil {
pub, ok := c.p.tcpl.publishers[path] c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("no one is streaming on path '%s'", path))
if !ok {
return nil, fmt.Errorf("no one is streaming on path '%s'", path)
}
return pub.streamSdpText, nil
}()
if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
} }
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Content-Base": []string{req.Url.String()}, "Content-Base": []string{req.Url.String()},
"Content-Type": []string{"application/sdp"}, "Content-Type": []string{"application/sdp"},
}, },
@ -404,25 +371,16 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err)) c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err))
return false return false
} }
sdpParsed, req.Content = gortsplib.SDPFilter(sdpParsed, req.Content) sdpParsed, req.Content = gortsplib.SDPFilter(sdpParsed, req.Content)
err = func() error { if len(path) == 0 {
c.p.tcpl.mutex.Lock() c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path can't be empty"))
defer c.p.tcpl.mutex.Unlock() return false
}
_, ok := c.p.tcpl.publishers[path]
if ok {
return fmt.Errorf("another client is already publishing on path '%s'", path)
}
c.path = path res := make(chan error)
c.p.tcpl.publishers[path] = c c.p.events <- programEventClientAnnounce{res, c, path, req.Content, sdpParsed}
c.streamSdpText = req.Content err = <-res
c.streamSdpParsed = sdpParsed
c.state = _CLIENT_STATE_ANNOUNCE
return nil
}()
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
@ -431,7 +389,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
}, },
}) })
return true return true
@ -488,33 +446,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := func() error { if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
c.p.tcpl.mutex.Lock() c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols"))
defer c.p.tcpl.mutex.Unlock() return false
}
pub, ok := c.p.tcpl.publishers[path]
if !ok {
return fmt.Errorf("no one is streaming on path '%s'", path)
}
if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
return fmt.Errorf("client wants to read tracks with different protocols")
}
if len(c.streamTracks) >= len(pub.streamSdpParsed.Medias) {
return fmt.Errorf("all the tracks have already been setup")
}
c.path = path
c.streamProtocol = _STREAM_PROTOCOL_UDP
c.streamTracks = append(c.streamTracks, &track{
rtpPort: rtpPort,
rtcpPort: rtcpPort,
})
c.state = _CLIENT_STATE_PRE_PLAY res := make(chan error)
return nil c.p.events <- programEventClientSetupPlay{res, c, path, _STREAM_PROTOCOL_UDP, rtpPort, rtcpPort}
}() err = <-res
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
@ -523,7 +462,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Transport": []string{strings.Join([]string{ "Transport": []string{strings.Join([]string{
"RTP/AVP/UDP", "RTP/AVP/UDP",
"unicast", "unicast",
@ -547,33 +486,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := func() error { if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
c.p.tcpl.mutex.Lock() c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols"))
defer c.p.tcpl.mutex.Unlock() return false
}
pub, ok := c.p.tcpl.publishers[path]
if !ok {
return fmt.Errorf("no one is streaming on path '%s'", path)
}
if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
return fmt.Errorf("client wants to read tracks with different protocols")
}
if len(c.streamTracks) >= len(pub.streamSdpParsed.Medias) {
return fmt.Errorf("all the tracks have already been setup")
}
c.path = path
c.streamProtocol = _STREAM_PROTOCOL_TCP
c.streamTracks = append(c.streamTracks, &track{
rtpPort: 0,
rtcpPort: 0,
})
c.state = _CLIENT_STATE_PRE_PLAY res := make(chan error)
return nil c.p.events <- programEventClientSetupPlay{res, c, path, _STREAM_PROTOCOL_TCP, 0, 0}
}() err = <-res
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
@ -584,7 +504,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Transport": []string{strings.Join([]string{ "Transport": []string{strings.Join([]string{
"RTP/AVP/TCP", "RTP/AVP/TCP",
"unicast", "unicast",
@ -607,6 +527,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
// after ANNOUNCE, c.path is already set
if path != c.path { if path != c.path {
c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed"))
return false return false
@ -635,27 +556,19 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := func() error { if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
c.p.tcpl.mutex.Lock() c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols"))
defer c.p.tcpl.mutex.Unlock() return false
}
if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP {
return fmt.Errorf("client wants to publish tracks with different protocols")
}
if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
return fmt.Errorf("all the tracks have already been setup")
}
c.streamProtocol = _STREAM_PROTOCOL_UDP if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
c.streamTracks = append(c.streamTracks, &track{ c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup"))
rtpPort: rtpPort, return false
rtcpPort: rtcpPort, }
})
c.state = _CLIENT_STATE_PRE_RECORD res := make(chan error)
return nil c.p.events <- programEventClientSetupRecord{res, c, _STREAM_PROTOCOL_UDP, rtpPort, rtcpPort}
}() err := <-res
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
@ -664,7 +577,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Transport": []string{strings.Join([]string{ "Transport": []string{strings.Join([]string{
"RTP/AVP/UDP", "RTP/AVP/UDP",
"unicast", "unicast",
@ -683,38 +596,31 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
var interleaved string if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
err := func() error { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols"))
c.p.tcpl.mutex.Lock() return false
defer c.p.tcpl.mutex.Unlock() }
if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP {
return fmt.Errorf("client wants to publish tracks with different protocols")
}
if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
return fmt.Errorf("all the tracks have already been setup")
}
interleaved = th.GetValue("interleaved") interleaved := th.GetValue("interleaved")
if interleaved == "" { if interleaved == "" {
return fmt.Errorf("transport header does not contain interleaved field") c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain the interleaved field"))
} return false
}
expInterleaved := fmt.Sprintf("%d-%d", 0+len(c.streamTracks)*2, 1+len(c.streamTracks)*2) expInterleaved := fmt.Sprintf("%d-%d", 0+len(c.streamTracks)*2, 1+len(c.streamTracks)*2)
if interleaved != expInterleaved { if interleaved != expInterleaved {
return fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved) c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved))
} return false
}
c.streamProtocol = _STREAM_PROTOCOL_TCP if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) {
c.streamTracks = append(c.streamTracks, &track{ c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup"))
rtpPort: 0, return false
rtcpPort: 0, }
})
c.state = _CLIENT_STATE_PRE_RECORD res := make(chan error)
return nil c.p.events <- programEventClientSetupRecord{res, c, _STREAM_PROTOCOL_TCP, 0, 0}
}() err := <-res
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
@ -723,7 +629,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Transport": []string{strings.Join([]string{ "Transport": []string{strings.Join([]string{
"RTP/AVP/TCP", "RTP/AVP/TCP",
"unicast", "unicast",
@ -756,33 +662,22 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := func() error { // check publisher existence
c.p.tcpl.mutex.Lock() res := make(chan error)
defer c.p.tcpl.mutex.Unlock() c.p.events <- programEventClientPlay1{res, c}
err := <-res
pub, ok := c.p.tcpl.publishers[c.path]
if !ok {
return fmt.Errorf("no one is streaming on path '%s'", c.path)
}
if len(c.streamTracks) != len(pub.streamSdpParsed.Medias) {
return fmt.Errorf("not all tracks have been setup")
}
return nil
}()
if err != nil { if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err) c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
} }
// first write response, then set state // write response before setting state
// otherwise, in case of TCP connections, RTP packets could be written // otherwise, in case of TCP connections, RTP packets could be sent
// before the response // before the response
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Session": []string{"12345678"}, "Session": []string{"12345678"},
}, },
}) })
@ -794,9 +689,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return "tracks" return "tracks"
}(), c.streamProtocol) }(), c.streamProtocol)
c.p.tcpl.mutex.Lock() // set state
c.state = _CLIENT_STATE_PLAY res = make(chan error)
c.p.tcpl.mutex.Unlock() c.p.events <- programEventClientPlay2{res, c}
<-res
// when protocol is TCP, the RTSP connection becomes a RTP connection // when protocol is TCP, the RTSP connection becomes a RTP connection
if c.streamProtocol == _STREAM_PROTOCOL_TCP { if c.streamProtocol == _STREAM_PROTOCOL_TCP {
@ -836,14 +732,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
c.log("paused") c.log("paused")
c.p.tcpl.mutex.Lock() res := make(chan error)
c.state = _CLIENT_STATE_PRE_PLAY c.p.events <- programEventClientPause{res, c}
c.p.tcpl.mutex.Unlock() <-res
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Session": []string{"12345678"}, "Session": []string{"12345678"},
}, },
}) })
@ -861,25 +757,15 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := func() error { if len(c.streamTracks) != len(c.streamSdpParsed.Medias) {
c.p.tcpl.mutex.Lock() c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("not all tracks have been setup"))
defer c.p.tcpl.mutex.Unlock()
if len(c.streamTracks) != len(c.streamSdpParsed.Medias) {
return fmt.Errorf("not all tracks have been setup")
}
return nil
}()
if err != nil {
c.writeResError(req, gortsplib.StatusBadRequest, err)
return false return false
} }
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusOK, StatusCode: gortsplib.StatusOK,
Header: gortsplib.Header{ Header: gortsplib.Header{
"CSeq": []string{cseq[0]}, "CSeq": cseq,
"Session": []string{"12345678"}, "Session": []string{"12345678"},
}, },
}) })
@ -891,13 +777,13 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return "tracks" return "tracks"
}(), c.streamProtocol) }(), c.streamProtocol)
res := make(chan error)
c.p.events <- programEventClientRecord{res, c}
<-res
// when protocol is TCP, the RTSP connection becomes a RTP connection // when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP data and parse it // receive RTP data and parse it
if c.streamProtocol == _STREAM_PROTOCOL_TCP { if c.streamProtocol == _STREAM_PROTOCOL_TCP {
c.p.tcpl.mutex.Lock()
c.state = _CLIENT_STATE_RECORD
c.p.tcpl.mutex.Unlock()
for { for {
frame, err := c.conn.ReadInterleavedFrame() frame, err := c.conn.ReadInterleavedFrame()
if err != nil { if err != nil {
@ -907,37 +793,27 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
trackId, trackFlow := interleavedChannelToTrack(frame.Channel) trackId, trackFlowType := interleavedChannelToTrack(frame.Channel)
if trackId >= len(c.streamTracks) { if trackId >= len(c.streamTracks) {
c.log("ERR: invalid track id '%d'", trackId) c.log("ERR: invalid track id '%d'", trackId)
return false return false
} }
c.p.tcpl.mutex.RLock() c.p.events <- programEventFrameTcp{
c.p.tcpl.forwardTrack(c.path, trackId, trackFlow, frame.Content) c.path,
c.p.tcpl.mutex.RUnlock() trackId,
trackFlowType,
frame.Content,
}
} }
} else { } else {
c.p.tcpl.mutex.Lock()
c.state = _CLIENT_STATE_RECORD
c.udpLastFrameTime = time.Now() c.udpLastFrameTime = time.Now()
c.udpCheckStreamTicker = time.NewTicker(_UDP_CHECK_STREAM_INTERVAL) c.udpCheckStreamTicker = time.NewTicker(_UDP_CHECK_STREAM_INTERVAL)
c.p.tcpl.mutex.Unlock()
go func() { go func() {
for range c.udpCheckStreamTicker.C { for range c.udpCheckStreamTicker.C {
ok := func() bool { if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER {
c.p.tcpl.mutex.Lock()
defer c.p.tcpl.mutex.Unlock()
if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER {
return false
}
return true
}()
if !ok {
c.log("ERR: stream is dead") c.log("ERR: stream is dead")
c.conn.NetConn().Close() c.conn.NetConn().Close()
break break

71
server-tcpl.go

@ -3,18 +3,13 @@ package main
import ( import (
"log" "log"
"net" "net"
"sync"
"github.com/aler9/gortsplib"
) )
type serverTcpListener struct { type serverTcpListener struct {
p *program p *program
nconn *net.TCPListener nconn *net.TCPListener
mutex sync.RWMutex
clients map[*serverClient]struct{} done chan struct{}
publishers map[string]*serverClient
done chan struct{}
} }
func newServerTcpListener(p *program) (*serverTcpListener, error) { func newServerTcpListener(p *program) (*serverTcpListener, error) {
@ -26,11 +21,9 @@ func newServerTcpListener(p *program) (*serverTcpListener, error) {
} }
l := &serverTcpListener{ l := &serverTcpListener{
p: p, p: p,
nconn: nconn, nconn: nconn,
clients: make(map[*serverClient]struct{}), done: make(chan struct{}),
publishers: make(map[string]*serverClient),
done: make(chan struct{}),
} }
l.log("opened on :%d", p.args.rtspPort) l.log("opened on :%d", p.args.rtspPort)
@ -48,21 +41,7 @@ func (l *serverTcpListener) run() {
break break
} }
newServerClient(l.p, nconn) l.p.events <- programEventClientNew{nconn}
}
// close clients
var doneChans []chan struct{}
func() {
l.mutex.Lock()
defer l.mutex.Unlock()
for c := range l.clients {
c.close()
doneChans = append(doneChans, c.done)
}
}()
for _, c := range doneChans {
<-c
} }
close(l.done) close(l.done)
@ -72,37 +51,3 @@ func (l *serverTcpListener) close() {
l.nconn.Close() l.nconn.Close()
<-l.done <-l.done
} }
func (l *serverTcpListener) forwardTrack(path string, id int, flow trackFlow, frame []byte) {
for c := range l.clients {
if c.path == path && c.state == _CLIENT_STATE_PLAY {
if c.streamProtocol == _STREAM_PROTOCOL_UDP {
if flow == _TRACK_FLOW_RTP {
l.p.udplRtp.write <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip(),
Zone: c.zone(),
Port: c.streamTracks[id].rtpPort,
},
buf: frame,
}
} else {
l.p.udplRtcp.write <- &udpWrite{
addr: &net.UDPAddr{
IP: c.ip(),
Zone: c.zone(),
Port: c.streamTracks[id].rtcpPort,
},
buf: frame,
}
}
} else {
c.write <- &gortsplib.InterleavedFrame{
Channel: trackToInterleavedChannel(id, flow),
Content: frame,
}
}
}
}
}

60
server-udpl.go

@ -12,14 +12,15 @@ type udpWrite struct {
} }
type serverUdpListener struct { type serverUdpListener struct {
p *program p *program
nconn *net.UDPConn nconn *net.UDPConn
flow trackFlow trackFlowType trackFlowType
write chan *udpWrite write chan *udpWrite
done chan struct{} done chan struct{}
} }
func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListener, error) { func newServerUdpListener(p *program, port int, trackFlowType trackFlowType) (*serverUdpListener, error) {
nconn, err := net.ListenUDP("udp", &net.UDPAddr{ nconn, err := net.ListenUDP("udp", &net.UDPAddr{
Port: port, Port: port,
}) })
@ -28,11 +29,11 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe
} }
l := &serverUdpListener{ l := &serverUdpListener{
p: p, p: p,
nconn: nconn, nconn: nconn,
flow: flow, trackFlowType: trackFlowType,
write: make(chan *udpWrite), write: make(chan *udpWrite),
done: make(chan struct{}), done: make(chan struct{}),
} }
l.log("opened on :%d", port) l.log("opened on :%d", port)
@ -41,7 +42,7 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe
func (l *serverUdpListener) log(format string, args ...interface{}) { func (l *serverUdpListener) log(format string, args ...interface{}) {
var label string var label string
if l.flow == _TRACK_FLOW_RTP { if l.trackFlowType == _TRACK_FLOW_RTP {
label = "RTP" label = "RTP"
} else { } else {
label = "RTCP" label = "RTCP"
@ -67,40 +68,11 @@ func (l *serverUdpListener) run() {
break break
} }
func() { l.p.events <- programEventFrameUdp{
l.p.tcpl.mutex.Lock() l.trackFlowType,
defer l.p.tcpl.mutex.Unlock() addr,
buf[:n],
// find publisher and track id from ip and port }
pub, trackId := func() (*serverClient, int) {
for _, pub := range l.p.tcpl.publishers {
if pub.streamProtocol != _STREAM_PROTOCOL_UDP ||
pub.state != _CLIENT_STATE_RECORD ||
!pub.ip().Equal(addr.IP) {
continue
}
for i, t := range pub.streamTracks {
if l.flow == _TRACK_FLOW_RTP {
if t.rtpPort == addr.Port {
return pub, i
}
} else {
if t.rtcpPort == addr.Port {
return pub, i
}
}
}
}
return nil, -1
}()
if pub == nil {
return
}
pub.udpLastFrameTime = time.Now()
l.p.tcpl.forwardTrack(pub.path, trackId, l.flow, buf[:n])
}()
} }
close(l.write) close(l.write)

Loading…
Cancel
Save