diff --git a/internal/clientman/clientman.go b/internal/clientman/clientman.go index 5573e797..8ab87dde 100644 --- a/internal/clientman/clientman.go +++ b/internal/clientman/clientman.go @@ -1,7 +1,6 @@ package clientman import ( - "net" "sync" "time" @@ -9,10 +8,8 @@ import ( "github.com/aler9/rtsp-simple-server/internal/client" "github.com/aler9/rtsp-simple-server/internal/clienthls" - "github.com/aler9/rtsp-simple-server/internal/clientrtmp" "github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/serverhls" - "github.com/aler9/rtsp-simple-server/internal/serverrtmp" "github.com/aler9/rtsp-simple-server/internal/stats" ) @@ -41,7 +38,6 @@ type ClientManager struct { protocols map[base.StreamProtocol]struct{} stats *stats.Stats pathMan PathManager - serverRTMP *serverrtmp.Server serverHLS *serverhls.Server parent Parent @@ -70,7 +66,6 @@ func New( protocols map[base.StreamProtocol]struct{}, stats *stats.Stats, pathMan PathManager, - serverRTMP *serverrtmp.Server, serverHLS *serverhls.Server, parent Parent) *ClientManager { @@ -86,7 +81,6 @@ func New( protocols: protocols, stats: stats, pathMan: pathMan, - serverRTMP: serverRTMP, serverHLS: serverHLS, parent: parent, clients: make(map[client.Client]struct{}), @@ -115,13 +109,6 @@ func (cm *ClientManager) Log(level logger.Level, format string, args ...interfac func (cm *ClientManager) run() { defer close(cm.done) - rtmpAccept := func() chan net.Conn { - if cm.serverRTMP != nil { - return cm.serverRTMP.Accept() - } - return make(chan net.Conn) - }() - hlsRequest := func() chan serverhls.Request { if cm.serverHLS != nil { return cm.serverHLS.Request() @@ -132,21 +119,6 @@ func (cm *ClientManager) run() { outer: for { select { - case nconn := <-rtmpAccept: - c := clientrtmp.New( - cm.rtspAddress, - cm.readTimeout, - cm.writeTimeout, - cm.readBufferCount, - cm.runOnConnect, - cm.runOnConnectRestart, - &cm.wg, - cm.stats, - nconn, - cm.pathMan, - cm) - cm.clients[c] = struct{}{} - case req := <-hlsRequest: c, ok := cm.clientsByHLSPath[req.Path] if !ok { diff --git a/internal/clientrtmp/client.go b/internal/clientrtmp/client.go index 781a377a..5096173e 100644 --- a/internal/clientrtmp/client.go +++ b/internal/clientrtmp/client.go @@ -75,7 +75,7 @@ type PathMan interface { // Parent is implemented by clientman.ClientMan. type Parent interface { Log(logger.Level, string, ...interface{}) - OnClientClose(client.Client) + OnClientClose(*Client) } // Client is a RTMP client. diff --git a/internal/serverrtmp/server.go b/internal/serverrtmp/server.go index ebeb3e5b..a23bdceb 100644 --- a/internal/serverrtmp/server.go +++ b/internal/serverrtmp/server.go @@ -2,9 +2,13 @@ package serverrtmp import ( "net" - "sync/atomic" + "sync" + "time" + "github.com/aler9/rtsp-simple-server/internal/clientrtmp" "github.com/aler9/rtsp-simple-server/internal/logger" + "github.com/aler9/rtsp-simple-server/internal/pathman" + "github.com/aler9/rtsp-simple-server/internal/stats" ) // Parent is implemented by program. @@ -14,19 +18,39 @@ type Parent interface { // Server is a RTMP listener. type Server struct { - parent Parent - - l net.Listener - closed uint32 + readTimeout time.Duration + writeTimeout time.Duration + readBufferCount int + rtspAddress string + runOnConnect string + runOnConnectRestart bool + stats *stats.Stats + pathMan *pathman.PathManager + parent Parent + + l net.Listener + wg sync.WaitGroup + clients map[*clientrtmp.Client]struct{} + + // in + clientClose chan *clientrtmp.Client + terminate chan struct{} // out - accept chan net.Conn - done chan struct{} + done chan struct{} } // New allocates a Server. func New( address string, + readTimeout time.Duration, + writeTimeout time.Duration, + readBufferCount int, + rtspAddress string, + runOnConnect string, + runOnConnectRestart bool, + stats *stats.Stats, + pathMan *pathman.PathManager, parent Parent) (*Server, error) { l, err := net.Listen("tcp", address) @@ -35,55 +59,134 @@ func New( } s := &Server{ - parent: parent, - l: l, - accept: make(chan net.Conn), - done: make(chan struct{}), + readTimeout: readTimeout, + writeTimeout: writeTimeout, + readBufferCount: readBufferCount, + rtspAddress: rtspAddress, + runOnConnect: runOnConnect, + runOnConnectRestart: runOnConnectRestart, + stats: stats, + pathMan: pathMan, + parent: parent, + l: l, + clients: make(map[*clientrtmp.Client]struct{}), + clientClose: make(chan *clientrtmp.Client), + terminate: make(chan struct{}), + done: make(chan struct{}), } - s.log(logger.Info, "opened on %s", address) + s.Log(logger.Info, "listener opened on %s", address) go s.run() return s, nil } -func (s *Server) log(level logger.Level, format string, args ...interface{}) { - s.parent.Log(level, "[RTMP listener] "+format, append([]interface{}{}, args...)...) +// Log is the main logging function. +func (s *Server) Log(level logger.Level, format string, args ...interface{}) { + s.parent.Log(level, "[RTMP] "+format, append([]interface{}{}, args...)...) } // Close closes a Server. func (s *Server) Close() { - go func() { - for co := range s.accept { - co.Close() - } - }() - atomic.StoreUint32(&s.closed, 1) - s.l.Close() + close(s.terminate) <-s.done } func (s *Server) run() { defer close(s.done) + s.wg.Add(1) + clientNew := make(chan net.Conn) + acceptErr := make(chan error) + go func() { + defer s.wg.Done() + acceptErr <- func() error { + for { + conn, err := s.l.Accept() + if err != nil { + return err + } + + clientNew <- conn + } + }() + }() + +outer: for { - nconn, err := s.l.Accept() - if err != nil { - if atomic.LoadUint32(&s.closed) == 1 { - break + select { + case err := <-acceptErr: + s.Log(logger.Warn, "ERR: %s", err) + break outer + + case nconn := <-clientNew: + c := clientrtmp.New( + s.rtspAddress, + s.readTimeout, + s.writeTimeout, + s.readBufferCount, + s.runOnConnect, + s.runOnConnectRestart, + &s.wg, + s.stats, + nconn, + s.pathMan, + s) + s.clients[c] = struct{}{} + + case c := <-s.clientClose: + if _, ok := s.clients[c]; !ok { + continue } - s.log(logger.Warn, "ERR: %s", err) - continue + s.doClientClose(c) + + case <-s.terminate: + break outer } + } - s.accept <- nconn + go func() { + for { + select { + case _, ok := <-acceptErr: + if !ok { + return + } + + case conn, ok := <-clientNew: + if !ok { + return + } + conn.Close() + + case _, ok := <-s.clientClose: + if !ok { + return + } + } + } + }() + + s.l.Close() + + for c := range s.clients { + s.doClientClose(c) } - close(s.accept) + s.wg.Wait() + + close(acceptErr) + close(clientNew) + close(s.clientClose) +} + +func (s *Server) doClientClose(c *clientrtmp.Client) { + delete(s.clients, c) + c.Close() } -// Accept returns a channel to accept incoming connections. -func (s *Server) Accept() chan net.Conn { - return s.accept +// OnClientClose is called by a client. +func (s *Server) OnClientClose(c *clientrtmp.Client) { + s.clientClose <- c } diff --git a/main.go b/main.go index 94ec17f4..3c642ebd 100644 --- a/main.go +++ b/main.go @@ -191,17 +191,6 @@ func (p *program) createResources(initial bool) error { } } - if !p.conf.RTMPDisable { - if p.serverRTMP == nil { - p.serverRTMP, err = serverrtmp.New( - p.conf.RTMPAddress, - p) - if err != nil { - return err - } - } - } - if !p.conf.HLSDisable { if p.serverHLS == nil { p.serverHLS, err = serverhls.New( @@ -239,7 +228,6 @@ func (p *program) createResources(initial bool) error { p.conf.ProtocolsParsed, p.stats, p.pathMan, - p.serverRTMP, p.serverHLS, p) } @@ -303,6 +291,25 @@ func (p *program) createResources(initial bool) error { } } + if !p.conf.RTMPDisable { + if p.serverRTMP == nil { + p.serverRTMP, err = serverrtmp.New( + p.conf.RTMPAddress, + p.conf.ReadTimeout, + p.conf.WriteTimeout, + p.conf.ReadBufferCount, + p.conf.RTSPAddress, + p.conf.RunOnConnect, + p.conf.RunOnConnectRestart, + p.stats, + p.pathMan, + p) + if err != nil { + return err + } + } + } + return nil } @@ -335,15 +342,6 @@ func (p *program) closeResources(newConf *conf.Conf) { closePPROF = true } - closeServerRTMP := false - if newConf == nil || - newConf.RTMPDisable != p.conf.RTMPDisable || - newConf.RTMPAddress != p.conf.RTMPAddress || - newConf.ReadTimeout != p.conf.ReadTimeout || - closeStats { - closeServerRTMP = true - } - closeServerHLS := false if newConf == nil || newConf.HLSDisable != p.conf.HLSDisable || @@ -368,7 +366,6 @@ func (p *program) closeResources(newConf *conf.Conf) { closeClientMan := false if newConf == nil || - closeServerRTMP || closeServerHLS || closePathMan || newConf.HLSSegmentCount != p.conf.HLSSegmentCount || @@ -423,6 +420,21 @@ func (p *program) closeResources(newConf *conf.Conf) { closeServerTLS = true } + closeServerRTMP := false + if newConf == nil || + newConf.RTMPDisable != p.conf.RTMPDisable || + newConf.RTMPAddress != p.conf.RTMPAddress || + newConf.ReadTimeout != p.conf.ReadTimeout || + newConf.WriteTimeout != p.conf.WriteTimeout || + newConf.ReadBufferCount != p.conf.ReadBufferCount || + newConf.RTSPAddress != p.conf.RTSPAddress || + newConf.RunOnConnect != p.conf.RunOnConnect || + newConf.RunOnConnectRestart != p.conf.RunOnConnectRestart || + closeStats || + closePathMan { + closeServerRTMP = true + } + if closeServerTLS && p.serverRTSPTLS != nil { p.serverRTSPTLS.Close() p.serverRTSPTLS = nil