diff --git a/main.go b/main.go index 37714b88..43db5770 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "os" "regexp" "strings" - "sync" "time" "github.com/aler9/gortsplib" @@ -60,14 +59,11 @@ type args struct { } type program struct { - args args - protocols map[streamProtocol]struct{} - mutex sync.RWMutex - rtspl *serverTcpListener - rtpl *serverUdpListener - rtcpl *serverUdpListener - clients map[*serverClient]struct{} - publishers map[string]*serverClient + args args + protocols map[streamProtocol]struct{} + rtspl *serverTcpListener + rtpl *serverUdpListener + rtcpl *serverUdpListener } func newProgram(args args) (*program, error) { @@ -132,10 +128,8 @@ func newProgram(args args) (*program, error) { log.Printf("rtsp-simple-server %s", Version) p := &program{ - args: args, - protocols: protocols, - clients: make(map[*serverClient]struct{}), - publishers: make(map[string]*serverClient), + args: args, + protocols: protocols, } var err error @@ -163,7 +157,7 @@ func newProgram(args args) (*program, error) { } func (p *program) forwardTrack(path string, id int, flow trackFlow, frame []byte) { - for c := range p.clients { + for c := range p.rtspl.clients { if c.path == path && c.state == _CLIENT_STATE_PLAY { if c.streamProtocol == _STREAM_PROTOCOL_UDP { if flow == _TRACK_FLOW_RTP { diff --git a/server-client.go b/server-client.go index df24645a..36ad8d22 100644 --- a/server-client.go +++ b/server-client.go @@ -128,30 +128,30 @@ func newServerClient(p *program, nconn net.Conn) *serverClient { write: make(chan *gortsplib.InterleavedFrame), } - c.p.mutex.Lock() - c.p.clients[c] = struct{}{} - c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + c.p.rtspl.clients[c] = struct{}{} + c.p.rtspl.mutex.Unlock() return c } func (c *serverClient) close() error { // already deleted - if _, ok := c.p.clients[c]; !ok { + if _, ok := c.p.rtspl.clients[c]; !ok { return nil } - delete(c.p.clients, c) + delete(c.p.rtspl.clients, c) c.conn.NetConn().Close() close(c.write) if c.path != "" { - if pub, ok := c.p.publishers[c.path]; ok && pub == c { - delete(c.p.publishers, c.path) + if pub, ok := c.p.rtspl.publishers[c.path]; ok && pub == c { + delete(c.p.rtspl.publishers, c.path) // if the publisher has disconnected // close all other connections that share the same path - for oc := range c.p.clients { + for oc := range c.p.rtspl.clients { if oc.path == c.path { oc.close() } @@ -189,8 +189,8 @@ func (c *serverClient) run() { defer c.log("disconnected") defer func() { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() c.close() }() @@ -288,10 +288,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } sdp, err := func() ([]byte, error) { - c.p.mutex.RLock() - defer c.p.mutex.RUnlock() + c.p.rtspl.mutex.RLock() + defer c.p.rtspl.mutex.RUnlock() - pub, ok := c.p.publishers[path] + pub, ok := c.p.rtspl.publishers[path] if !ok { return nil, fmt.Errorf("no one is streaming on path '%s'", path) } @@ -369,16 +369,16 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { sdpParsed, req.Content = sdpFilter(sdpParsed, req.Content) err = func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() - _, ok := c.p.publishers[path] + _, ok := c.p.rtspl.publishers[path] if ok { return fmt.Errorf("another client is already publishing on path '%s'", path) } c.path = path - c.p.publishers[path] = c + c.p.rtspl.publishers[path] = c c.streamSdpText = req.Content c.streamSdpParsed = sdpParsed c.state = _CLIENT_STATE_ANNOUNCE @@ -443,10 +443,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() - pub, ok := c.p.publishers[path] + pub, ok := c.p.rtspl.publishers[path] if !ok { return fmt.Errorf("no one is streaming on path '%s'", path) } @@ -502,10 +502,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() - pub, ok := c.p.publishers[path] + pub, ok := c.p.rtspl.publishers[path] if !ok { return fmt.Errorf("no one is streaming on path '%s'", path) } @@ -590,8 +590,8 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { return fmt.Errorf("client wants to publish tracks with different protocols") @@ -639,8 +639,8 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { var interleaved string err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { return fmt.Errorf("client wants to publish tracks with different protocols") @@ -710,10 +710,10 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() - pub, ok := c.p.publishers[c.path] + pub, ok := c.p.rtspl.publishers[c.path] if !ok { return fmt.Errorf("no one is streaming on path '%s'", c.path) } @@ -747,9 +747,9 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return "tracks" }(), c.streamProtocol) - c.p.mutex.Lock() + c.p.rtspl.mutex.Lock() c.state = _CLIENT_STATE_PLAY - c.p.mutex.Unlock() + c.p.rtspl.mutex.Unlock() // when protocol is TCP, the RTSP connection becomes a RTP connection if c.streamProtocol == _STREAM_PROTOCOL_TCP { @@ -788,9 +788,9 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { c.log("paused") - c.p.mutex.Lock() + c.p.rtspl.mutex.Lock() c.state = _CLIENT_STATE_PRE_PLAY - c.p.mutex.Unlock() + c.p.rtspl.mutex.Unlock() c.conn.WriteResponse(&gortsplib.Response{ StatusCode: gortsplib.StatusOK, @@ -813,8 +813,8 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { } err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() + c.p.rtspl.mutex.Lock() + defer c.p.rtspl.mutex.Unlock() if len(c.streamTracks) != len(c.streamSdpParsed.Medias) { return fmt.Errorf("not all tracks have been setup") @@ -835,9 +835,9 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { }, }) - c.p.mutex.Lock() + c.p.rtspl.mutex.Lock() c.state = _CLIENT_STATE_RECORD - c.p.mutex.Unlock() + c.p.rtspl.mutex.Unlock() c.log("is publishing on path '%s', %d %s via %s", c.path, len(c.streamTracks), func() string { if len(c.streamTracks) == 1 { @@ -863,9 +863,9 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return false } - c.p.mutex.RLock() + c.p.rtspl.mutex.RLock() c.p.forwardTrack(c.path, trackId, trackFlow, frame.Content) - c.p.mutex.RUnlock() + c.p.rtspl.mutex.RUnlock() } } diff --git a/server-tcpl.go b/server-tcpl.go index bede6ded..9a6eca2f 100644 --- a/server-tcpl.go +++ b/server-tcpl.go @@ -3,11 +3,15 @@ package main import ( "log" "net" + "sync" ) type serverTcpListener struct { - p *program - netl *net.TCPListener + p *program + netl *net.TCPListener + mutex sync.RWMutex + clients map[*serverClient]struct{} + publishers map[string]*serverClient } func newServerTcpListener(p *program) (*serverTcpListener, error) { @@ -19,8 +23,10 @@ func newServerTcpListener(p *program) (*serverTcpListener, error) { } s := &serverTcpListener{ - p: p, - netl: netl, + p: p, + netl: netl, + clients: make(map[*serverClient]struct{}), + publishers: make(map[string]*serverClient), } s.log("opened on :%d", p.args.rtspPort) diff --git a/server-udpl.go b/server-udpl.go index 1f1d470d..6a5da3f4 100644 --- a/server-udpl.go +++ b/server-udpl.go @@ -68,12 +68,12 @@ func (l *serverUdpListener) run() { } func() { - l.p.mutex.RLock() - defer l.p.mutex.RUnlock() + l.p.rtspl.mutex.RLock() + defer l.p.rtspl.mutex.RUnlock() // find path and track id from ip and port path, trackId := func() (string, int) { - for _, pub := range l.p.publishers { + for _, pub := range l.p.rtspl.publishers { for i, t := range pub.streamTracks { if !pub.ip().Equal(addr.IP) { continue