From 3d98bede4a27767e51bd9acd21107663244ee588 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 24 Oct 2020 19:55:47 +0200 Subject: [PATCH] implement configuration dynamic update / hot reloading (#64) --- README.md | 3 + client/client.go | 95 ++++++------- clientman/clientman.go | 19 +-- conf/conf.go | 229 ++++++------------------------- conf/pathconf.go | 175 +++++++++++++++++++++++ confwatcher/confwatcher.go | 74 ++++++++++ go.mod | 1 + go.sum | 4 + main.go | 274 +++++++++++++++++++++++++++++-------- main_test.go | 67 +++++++++ path/path.go | 244 +++++++++++++++++++++------------ pathman/pathman.go | 132 ++++++++++++------ serverudp/server.go | 34 ++--- 13 files changed, 905 insertions(+), 446 deletions(-) create mode 100644 conf/pathconf.go create mode 100644 confwatcher/confwatcher.go diff --git a/README.md b/README.md index cb58368b..88c4c004 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,9 @@ Parameters in maps can be overridden by using underscores, in the following way: RTSP_PATHS_TEST_SOURCE=rtsp://myurl ./rtsp-simple-server ``` +The configuration file can be changed dinamically when the server is running (hot reloading): changes are detected and written over the previous configuration, clients are disconnected only if necessary. + + ### RTSP proxy mode `rtsp-simple-server` is also an RTSP proxy, that is usually deployed in one of these scenarios: diff --git a/client/client.go b/client/client.go index 1836311f..f68350b0 100644 --- a/client/client.go +++ b/client/client.go @@ -448,6 +448,7 @@ func (c *Client) handleRequest(req *base.Request) error { return errRunTerminate } } + c.path = path c.state = statePreRecord @@ -873,7 +874,9 @@ func (c *Client) runInitial() bool { func (c *Client) runWaitingDescribe() bool { select { case res := <-c.describeData: + c.path.OnClientRemove(c) c.path = nil + c.state = stateInitial if res.err != nil { @@ -899,6 +902,7 @@ func (c *Client) runWaitingDescribe() bool { }() c.path.OnClientRemove(c) + c.path = nil c.conn.Close() return false @@ -935,6 +939,9 @@ func (c *Client) runPlay() bool { onReadCmd.Close() } + c.path.OnClientRemove(c) + c.path = nil + return false } @@ -963,15 +970,11 @@ func (c *Client) runPlayUDP() { c.log("ERR: %s", err) } - c.path.OnClientRemove(c) - c.parent.OnClientClose(c) <-c.terminate return case <-c.terminate: - c.path.OnClientRemove(c) - c.conn.Close() <-readDone return @@ -1024,8 +1027,6 @@ func (c *Client) runPlayTCP() { } }() - c.path.OnClientRemove(c) - c.parent.OnClientClose(c) <-c.terminate return @@ -1040,7 +1041,10 @@ func (c *Client) runPlayTCP() { } }() - c.path.OnClientRemove(c) + go func() { + for range c.tcpFrame { + } + }() c.conn.Close() <-readDone @@ -1050,6 +1054,15 @@ func (c *Client) runPlayTCP() { } func (c *Client) runRecord() bool { + c.path.OnClientRecord(c) + + c.log("is publishing to path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { + if len(c.streamTracks) == 1 { + return "track" + } + return "tracks" + }(), c.streamProtocol) + c.rtcpReceivers = make([]*rtcpreceiver.RtcpReceiver, len(c.streamTracks)) for trackId := range c.streamTracks { c.rtcpReceivers[trackId] = rtcpreceiver.New() @@ -1061,22 +1074,30 @@ func (c *Client) runRecord() bool { v := time.Now().Unix() c.udpLastFrameTimes[trackId] = &v } - } - - c.path.OnClientRecord(c) - c.log("is publishing to path '%s', %d %s with %s", c.path.Name(), len(c.streamTracks), func() string { - if len(c.streamTracks) == 1 { - return "track" - } - return "tracks" - }(), c.streamProtocol) - - if c.streamProtocol == gortsplib.StreamProtocolUDP { for trackId, track := range c.streamTracks { c.serverUdpRtp.AddPublisher(c.ip(), track.rtpPort, c, trackId) c.serverUdpRtcp.AddPublisher(c.ip(), track.rtcpPort, c, trackId) } + + // open the firewall by sending packets to the counterpart + for _, track := range c.streamTracks { + c.serverUdpRtp.Write( + []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: track.rtpPort, + }) + + c.serverUdpRtcp.Write( + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, + &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: track.rtcpPort, + }) + } } var onPublishCmd *externalcmd.ExternalCmd @@ -1094,6 +1115,10 @@ func (c *Client) runRecord() bool { c.runRecordTCP() } + if onPublishCmd != nil { + onPublishCmd.Close() + } + if c.streamProtocol == gortsplib.StreamProtocolUDP { for _, track := range c.streamTracks { c.serverUdpRtp.RemovePublisher(c.ip(), track.rtpPort, c) @@ -1101,33 +1126,13 @@ func (c *Client) runRecord() bool { } } - if onPublishCmd != nil { - onPublishCmd.Close() - } + c.path.OnClientRemove(c) + c.path = nil return false } func (c *Client) runRecordUDP() { - // open the firewall by sending packets to the counterpart - for _, track := range c.streamTracks { - c.serverUdpRtp.Write( - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: track.rtpPort, - }) - - c.serverUdpRtcp.Write( - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, - &net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: track.rtcpPort, - }) - } - readDone := make(chan error) go func() { for { @@ -1159,8 +1164,6 @@ func (c *Client) runRecordUDP() { c.log("ERR: %s", err) } - c.path.OnClientRemove(c) - c.parent.OnClientClose(c) <-c.terminate return @@ -1176,8 +1179,6 @@ func (c *Client) runRecordUDP() { c.conn.Close() <-readDone - c.path.OnClientRemove(c) - c.parent.OnClientClose(c) <-c.terminate return @@ -1195,8 +1196,6 @@ func (c *Client) runRecordUDP() { } case <-c.terminate: - c.path.OnClientRemove(c) - c.conn.Close() <-readDone return @@ -1252,8 +1251,6 @@ func (c *Client) runRecordTCP() { c.log("ERR: %s", err) } - c.path.OnClientRemove(c) - c.parent.OnClientClose(c) <-c.terminate return @@ -1271,8 +1268,6 @@ func (c *Client) runRecordTCP() { } }() - c.path.OnClientRemove(c) - c.conn.Close() <-readDone return diff --git a/clientman/clientman.go b/clientman/clientman.go index 6f398a1a..014cb0c7 100644 --- a/clientman/clientman.go +++ b/clientman/clientman.go @@ -20,13 +20,13 @@ type Parent interface { } type ClientManager struct { - stats *stats.Stats - serverUdpRtp *serverudp.Server - serverUdpRtcp *serverudp.Server readTimeout time.Duration writeTimeout time.Duration runOnConnect string protocols map[headers.StreamProtocol]struct{} + stats *stats.Stats + serverUdpRtp *serverudp.Server + serverUdpRtcp *serverudp.Server pathMan *pathman.PathManager serverTcp *servertcp.Server parent Parent @@ -42,25 +42,26 @@ type ClientManager struct { done chan struct{} } -func New(stats *stats.Stats, - serverUdpRtp *serverudp.Server, - serverUdpRtcp *serverudp.Server, +func New( readTimeout time.Duration, writeTimeout time.Duration, runOnConnect string, protocols map[headers.StreamProtocol]struct{}, + stats *stats.Stats, + serverUdpRtp *serverudp.Server, + serverUdpRtcp *serverudp.Server, pathMan *pathman.PathManager, serverTcp *servertcp.Server, parent Parent) *ClientManager { cm := &ClientManager{ - stats: stats, - serverUdpRtp: serverUdpRtp, - serverUdpRtcp: serverUdpRtcp, readTimeout: readTimeout, writeTimeout: writeTimeout, runOnConnect: runOnConnect, protocols: protocols, + stats: stats, + serverUdpRtp: serverUdpRtp, + serverUdpRtcp: serverUdpRtcp, pathMan: pathMan, serverTcp: serverTcp, parent: parent, diff --git a/conf/conf.go b/conf/conf.go index 46e470f4..a67e1c30 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -2,10 +2,7 @@ package conf import ( "fmt" - "net/url" "os" - "regexp" - "strings" "time" "github.com/aler9/gortsplib" @@ -16,29 +13,9 @@ import ( "github.com/aler9/rtsp-simple-server/loghandler" ) -type PathConf struct { - Regexp *regexp.Regexp `yaml:"-"` - Source string `yaml:"source"` - SourceProtocol string `yaml:"sourceProtocol"` - SourceProtocolParsed gortsplib.StreamProtocol `yaml:"-"` - SourceOnDemand bool `yaml:"sourceOnDemand"` - RunOnInit string `yaml:"runOnInit"` - RunOnDemand string `yaml:"runOnDemand"` - RunOnPublish string `yaml:"runOnPublish"` - RunOnRead string `yaml:"runOnRead"` - PublishUser string `yaml:"publishUser"` - PublishPass string `yaml:"publishPass"` - PublishIps []string `yaml:"publishIps"` - PublishIpsParsed []interface{} `yaml:"-"` - ReadUser string `yaml:"readUser"` - ReadPass string `yaml:"readPass"` - ReadIps []string `yaml:"readIps"` - ReadIpsParsed []interface{} `yaml:"-"` -} - type Conf struct { Protocols []string `yaml:"protocols"` - ProtocolsParsed map[gortsplib.StreamProtocol]struct{} `yaml:"-"` + ProtocolsParsed map[gortsplib.StreamProtocol]struct{} `yaml:"-" json:"-"` RtspPort int `yaml:"rtspPort"` RtpPort int `yaml:"rtpPort"` RtcpPort int `yaml:"rtcpPort"` @@ -46,50 +23,16 @@ type Conf struct { ReadTimeout time.Duration `yaml:"readTimeout"` WriteTimeout time.Duration `yaml:"writeTimeout"` AuthMethods []string `yaml:"authMethods"` - AuthMethodsParsed []headers.AuthMethod `yaml:"-"` + AuthMethodsParsed []headers.AuthMethod `yaml:"-" json:"-"` Metrics bool `yaml:"metrics"` Pprof bool `yaml:"pprof"` LogDestinations []string `yaml:"logDestinations"` - LogDestinationsParsed map[loghandler.Destination]struct{} `yaml:"-"` + LogDestinationsParsed map[loghandler.Destination]struct{} `yaml:"-" json:"-"` LogFile string `yaml:"logFile"` Paths map[string]*PathConf `yaml:"paths"` } -func Load(fpath string) (*Conf, error) { - conf := &Conf{} - - // read from file - err := func() error { - // rtsp-simple-server.yml is optional - if fpath == "rtsp-simple-server.yml" { - if _, err := os.Stat(fpath); err != nil { - return nil - } - } - - f, err := os.Open(fpath) - if err != nil { - return err - } - defer f.Close() - - err = yaml.NewDecoder(f).Decode(conf) - if err != nil { - return err - } - - return nil - }() - if err != nil { - return nil, err - } - - // read from environment - err = confenv.Load("RTSP", conf) - if err != nil { - return nil, err - } - +func (conf *Conf) fillAndCheck() error { if len(conf.Protocols) == 0 { conf.Protocols = []string{"udp", "tcp"} } @@ -103,11 +46,11 @@ func Load(fpath string) (*Conf, error) { conf.ProtocolsParsed[gortsplib.StreamProtocolTCP] = struct{}{} default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) + return fmt.Errorf("unsupported protocol: %s", proto) } } if len(conf.ProtocolsParsed) == 0 { - return nil, fmt.Errorf("no protocols provided") + return fmt.Errorf("no protocols provided") } if conf.RtspPort == 0 { @@ -117,13 +60,13 @@ func Load(fpath string) (*Conf, error) { conf.RtpPort = 8000 } if (conf.RtpPort % 2) != 0 { - return nil, fmt.Errorf("rtp port must be even") + return fmt.Errorf("rtp port must be even") } if conf.RtcpPort == 0 { conf.RtcpPort = 8001 } if conf.RtcpPort != (conf.RtpPort + 1) { - return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") + return fmt.Errorf("rtcp and rtp ports must be consecutive") } if conf.ReadTimeout == 0 { @@ -145,7 +88,7 @@ func Load(fpath string) (*Conf, error) { conf.AuthMethodsParsed = append(conf.AuthMethodsParsed, headers.AuthDigest) default: - return nil, fmt.Errorf("unsupported authentication method: %s", method) + return fmt.Errorf("unsupported authentication method: %s", method) } } @@ -165,7 +108,7 @@ func Load(fpath string) (*Conf, error) { conf.LogDestinationsParsed[loghandler.DestinationSyslog] = struct{}{} default: - return nil, fmt.Errorf("unsupported log destination: %s", dest) + return fmt.Errorf("unsupported log destination: %s", dest) } } if conf.LogFile == "" { @@ -190,137 +133,53 @@ func Load(fpath string) (*Conf, error) { pconf = conf.Paths[name] } - if name == "" { - return nil, fmt.Errorf("path name can not be empty") - } - - // normal path - if name[0] != '~' { - err := CheckPathName(name) - if err != nil { - return nil, fmt.Errorf("invalid path name: %s (%s)", err, name) - } - - // regular expression path - } else { - pathRegexp, err := regexp.Compile(name[1:]) - if err != nil { - return nil, fmt.Errorf("invalid regular expression: %s", name[1:]) - } - pconf.Regexp = pathRegexp - } - - if pconf.Source == "" { - pconf.Source = "record" + err := pconf.fillAndCheck(name) + if err != nil { + return err } + } - if strings.HasPrefix(pconf.Source, "rtsp://") { - if pconf.Regexp != nil { - return nil, fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTSP source; use another path") - } - - u, err := url.Parse(pconf.Source) - if err != nil { - return nil, fmt.Errorf("'%s' is not a valid url", pconf.Source) - } - if u.User != nil { - pass, _ := u.User.Password() - user := u.User.Username() - if user != "" && pass == "" || - user == "" && pass != "" { - fmt.Errorf("username and password must be both provided") - } - } - - if pconf.SourceProtocol == "" { - pconf.SourceProtocol = "udp" - } - switch pconf.SourceProtocol { - case "udp": - pconf.SourceProtocolParsed = gortsplib.StreamProtocolUDP - - case "tcp": - pconf.SourceProtocolParsed = gortsplib.StreamProtocolTCP - - default: - return nil, fmt.Errorf("unsupported protocol '%s'", pconf.SourceProtocol) - } + return nil +} - } else if strings.HasPrefix(pconf.Source, "rtmp://") { - if pconf.Regexp != nil { - return nil, fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTMP source; use another path") - } +func Load(fpath string) (*Conf, error) { + conf := &Conf{} - u, err := url.Parse(pconf.Source) - if err != nil { - return nil, fmt.Errorf("'%s' is not a valid url", pconf.Source) - } - if u.User != nil { - pass, _ := u.User.Password() - user := u.User.Username() - if user != "" && pass == "" || - user == "" && pass != "" { - fmt.Errorf("username and password must be both provided") - } + // read from file + err := func() error { + // rtsp-simple-server.yml is optional + if fpath == "rtsp-simple-server.yml" { + if _, err := os.Stat(fpath); err != nil { + return nil } - - } else if pconf.Source == "record" { - - } else { - return nil, fmt.Errorf("unsupported source: '%s'", pconf.Source) } - if pconf.PublishUser != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishUser) { - return nil, fmt.Errorf("publish username must be alphanumeric") - } - } - if pconf.PublishPass != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishPass) { - return nil, fmt.Errorf("publish password must be alphanumeric") - } + f, err := os.Open(fpath) + if err != nil { + return err } + defer f.Close() - if len(pconf.PublishIps) > 0 { - pconf.PublishIpsParsed, err = parseIpCidrList(pconf.PublishIps) - if err != nil { - return nil, err - } - } else { - // the configuration file doesn't use nil dicts - avoid test fails by using nil - pconf.PublishIps = nil + err = yaml.NewDecoder(f).Decode(conf) + if err != nil { + return err } - if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { - return nil, fmt.Errorf("read username and password must be both filled") - } - if pconf.ReadUser != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadUser) { - return nil, fmt.Errorf("read username must be alphanumeric") - } - } - if pconf.ReadPass != "" { - if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadPass) { - return nil, fmt.Errorf("read password must be alphanumeric") - } - } - if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { - return nil, fmt.Errorf("read username and password must be both filled") - } + return nil + }() + if err != nil { + return nil, err + } - if len(pconf.ReadIps) > 0 { - pconf.ReadIpsParsed, err = parseIpCidrList(pconf.ReadIps) - if err != nil { - return nil, err - } - } else { - // the configuration file doesn't use nil dicts - avoid test fails by using nil - pconf.ReadIps = nil - } + // read from environment + err = confenv.Load("RTSP", conf) + if err != nil { + return nil, err + } - if pconf.Regexp != nil && pconf.RunOnInit != "" { - return nil, fmt.Errorf("a path with a regular expression does not support option 'runOnInit'; use another path") - } + err = conf.fillAndCheck() + if err != nil { + return nil, err } return conf, nil diff --git a/conf/pathconf.go b/conf/pathconf.go new file mode 100644 index 00000000..d82ea9a9 --- /dev/null +++ b/conf/pathconf.go @@ -0,0 +1,175 @@ +package conf + +import ( + "encoding/json" + "fmt" + "net/url" + "regexp" + "strings" + + "github.com/aler9/gortsplib" +) + +type PathConf struct { + Regexp *regexp.Regexp `yaml:"-" json:"-"` + Source string `yaml:"source"` + SourceProtocol string `yaml:"sourceProtocol"` + SourceProtocolParsed gortsplib.StreamProtocol `yaml:"-" json:"-"` + SourceOnDemand bool `yaml:"sourceOnDemand"` + RunOnInit string `yaml:"runOnInit"` + RunOnDemand string `yaml:"runOnDemand"` + RunOnPublish string `yaml:"runOnPublish"` + RunOnRead string `yaml:"runOnRead"` + PublishUser string `yaml:"publishUser"` + PublishPass string `yaml:"publishPass"` + PublishIps []string `yaml:"publishIps"` + PublishIpsParsed []interface{} `yaml:"-" json:"-"` + ReadUser string `yaml:"readUser"` + ReadPass string `yaml:"readPass"` + ReadIps []string `yaml:"readIps"` + ReadIpsParsed []interface{} `yaml:"-" json:"-"` +} + +func (pconf *PathConf) fillAndCheck(name string) error { + if name == "" { + return fmt.Errorf("path name can not be empty") + } + + // normal path + if name[0] != '~' { + err := CheckPathName(name) + if err != nil { + return fmt.Errorf("invalid path name: %s (%s)", err, name) + } + + // regular expression path + } else { + pathRegexp, err := regexp.Compile(name[1:]) + if err != nil { + return fmt.Errorf("invalid regular expression: %s", name[1:]) + } + pconf.Regexp = pathRegexp + } + + if pconf.Source == "" { + pconf.Source = "record" + } + + if strings.HasPrefix(pconf.Source, "rtsp://") { + if pconf.Regexp != nil { + return fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTSP source; use another path") + } + + u, err := url.Parse(pconf.Source) + if err != nil { + return fmt.Errorf("'%s' is not a valid url", pconf.Source) + } + if u.User != nil { + pass, _ := u.User.Password() + user := u.User.Username() + if user != "" && pass == "" || + user == "" && pass != "" { + fmt.Errorf("username and password must be both provided") + } + } + + if pconf.SourceProtocol == "" { + pconf.SourceProtocol = "udp" + } + switch pconf.SourceProtocol { + case "udp": + pconf.SourceProtocolParsed = gortsplib.StreamProtocolUDP + + case "tcp": + pconf.SourceProtocolParsed = gortsplib.StreamProtocolTCP + + default: + return fmt.Errorf("unsupported protocol '%s'", pconf.SourceProtocol) + } + + } else if strings.HasPrefix(pconf.Source, "rtmp://") { + if pconf.Regexp != nil { + return fmt.Errorf("a path with a regular expression (or path 'all') cannot have a RTMP source; use another path") + } + + u, err := url.Parse(pconf.Source) + if err != nil { + return fmt.Errorf("'%s' is not a valid url", pconf.Source) + } + if u.User != nil { + pass, _ := u.User.Password() + user := u.User.Username() + if user != "" && pass == "" || + user == "" && pass != "" { + fmt.Errorf("username and password must be both provided") + } + } + + } else if pconf.Source == "record" { + + } else { + return fmt.Errorf("unsupported source: '%s'", pconf.Source) + } + + if pconf.PublishUser != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishUser) { + return fmt.Errorf("publish username must be alphanumeric") + } + } + if pconf.PublishPass != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.PublishPass) { + return fmt.Errorf("publish password must be alphanumeric") + } + } + + if len(pconf.PublishIps) > 0 { + var err error + pconf.PublishIpsParsed, err = parseIpCidrList(pconf.PublishIps) + if err != nil { + return err + } + } else { + // the configuration file doesn't use nil dicts - avoid test fails by using nil + pconf.PublishIps = nil + } + + if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { + return fmt.Errorf("read username and password must be both filled") + } + if pconf.ReadUser != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadUser) { + return fmt.Errorf("read username must be alphanumeric") + } + } + if pconf.ReadPass != "" { + if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(pconf.ReadPass) { + return fmt.Errorf("read password must be alphanumeric") + } + } + if pconf.ReadUser != "" && pconf.ReadPass == "" || pconf.ReadUser == "" && pconf.ReadPass != "" { + return fmt.Errorf("read username and password must be both filled") + } + + if len(pconf.ReadIps) > 0 { + var err error + pconf.ReadIpsParsed, err = parseIpCidrList(pconf.ReadIps) + if err != nil { + return err + } + } else { + // the configuration file doesn't use nil dicts - avoid test fails by using nil + pconf.ReadIps = nil + } + + if pconf.Regexp != nil && pconf.RunOnInit != "" { + return fmt.Errorf("a path with a regular expression does not support option 'runOnInit'; use another path") + } + + return nil +} + +func (pconf *PathConf) Equal(other *PathConf) bool { + a, _ := json.Marshal(pconf) + b, _ := json.Marshal(pconf) + return string(a) == string(b) +} diff --git a/confwatcher/confwatcher.go b/confwatcher/confwatcher.go new file mode 100644 index 00000000..f9c2f189 --- /dev/null +++ b/confwatcher/confwatcher.go @@ -0,0 +1,74 @@ +package confwatcher + +import ( + "os" + "time" + + "github.com/fsnotify/fsnotify" +) + +type ConfWatcher struct { + inner *fsnotify.Watcher + + // out + signal chan struct{} + done chan struct{} +} + +func New(confPath string) (*ConfWatcher, error) { + inner, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + if _, err := os.Stat(confPath); err == nil { + err := inner.Add(confPath) + if err != nil { + inner.Close() + return nil, err + } + } + + w := &ConfWatcher{ + inner: inner, + signal: make(chan struct{}), + done: make(chan struct{}), + } + + go w.run() + return w, nil +} + +func (w *ConfWatcher) Close() { + go func() { + for range w.signal { + } + }() + w.inner.Close() + <-w.done +} + +func (w *ConfWatcher) run() { + defer close(w.done) + +outer: + for { + select { + case event := <-w.inner.Events: + if (event.Op & fsnotify.Write) == fsnotify.Write { + // wait some additional time to avoid EOF + time.Sleep(10 * time.Millisecond) + w.signal <- struct{}{} + } + + case <-w.inner.Errors: + break outer + } + } + + close(w.signal) +} + +func (w *ConfWatcher) Watch() chan struct{} { + return w.signal +} diff --git a/go.mod b/go.mod index b29e52a7..7e33b089 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect github.com/aler9/gortsplib v0.0.0-20201017143703-0b7201de6890 github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.4.9 github.com/notedit/rtmp v0.0.2 github.com/pion/rtp v1.6.1 // indirect github.com/stretchr/testify v1.6.1 diff --git a/go.sum b/go.sum index 16a8b5dc..05410dd4 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/notedit/rtmp v0.0.2 h1:5+to4yezKATiJgnrcETu9LbV5G/QsWkOV9Ts2M/p33w= github.com/notedit/rtmp v0.0.2/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc= github.com/pion/randutil v0.0.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= @@ -30,6 +32,8 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 h1:L2auWcuQIvxz9xSEqzESnV/QN/gNRXNApHi3fYwl2w0= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/main.go b/main.go index f3953979..77370415 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "reflect" "sync/atomic" "github.com/aler9/gortsplib" @@ -11,6 +12,7 @@ import ( "github.com/aler9/rtsp-simple-server/clientman" "github.com/aler9/rtsp-simple-server/conf" + "github.com/aler9/rtsp-simple-server/confwatcher" "github.com/aler9/rtsp-simple-server/loghandler" "github.com/aler9/rtsp-simple-server/metrics" "github.com/aler9/rtsp-simple-server/pathman" @@ -23,6 +25,7 @@ import ( var Version = "v0.0.0" type program struct { + confPath string conf *conf.Conf stats *stats.Stats logHandler *loghandler.LogHandler @@ -33,6 +36,7 @@ type program struct { serverTcp *servertcp.Server pathMan *pathman.PathManager clientMan *clientman.ClientManager + confWatcher *confwatcher.ConfWatcher terminate chan struct{} done chan struct{} @@ -52,77 +56,33 @@ func newProgram(args []string) (*program, error) { os.Exit(0) } - conf, err := conf.Load(*argConfPath) - if err != nil { - return nil, err - } - p := &program{ - conf: conf, + confPath: *argConfPath, terminate: make(chan struct{}), done: make(chan struct{}), } - p.stats = stats.New() - - p.logHandler, err = loghandler.New(conf.LogDestinationsParsed, conf.LogFile) + var err error + p.conf, err = conf.Load(p.confPath) if err != nil { - p.closeResources() return nil, err } - p.Log("rtsp-simple-server %s", Version) - - if conf.Metrics { - p.metrics, err = metrics.New(p.stats, p) - if err != nil { - p.closeResources() - return nil, err - } - } - - if conf.Pprof { - p.pprof, err = pprof.New(p) - if err != nil { - p.closeResources() - return nil, err - } - } - - if _, ok := conf.ProtocolsParsed[gortsplib.StreamProtocolUDP]; ok { - p.serverUdpRtp, err = serverudp.New(p.conf.WriteTimeout, - conf.RtpPort, gortsplib.StreamTypeRtp, p) - if err != nil { - p.closeResources() - return nil, err - } - - p.serverUdpRtcp, err = serverudp.New(p.conf.WriteTimeout, - conf.RtcpPort, gortsplib.StreamTypeRtcp, p) - if err != nil { - p.closeResources() - return nil, err - } - } - - p.serverTcp, err = servertcp.New(conf.RtspPort, p) + err = p.createResources(true) if err != nil { p.closeResources() return nil, err } - p.pathMan = pathman.New(p.stats, p.serverUdpRtp, p.serverUdpRtcp, - p.conf.ReadTimeout, p.conf.WriteTimeout, p.conf.AuthMethodsParsed, - conf.Paths, p) - - p.clientMan = clientman.New(p.stats, p.serverUdpRtp, p.serverUdpRtcp, - p.conf.ReadTimeout, p.conf.WriteTimeout, p.conf.RunOnConnect, - p.conf.ProtocolsParsed, p.pathMan, p.serverTcp, p) - go p.run() return p, nil } +func (p *program) close() { + close(p.terminate) + <-p.done +} + func (p *program) Log(format string, args ...interface{}) { countClients := atomic.LoadInt64(p.stats.CountClients) countPublishers := atomic.LoadInt64(p.stats.CountPublishers) @@ -138,6 +98,13 @@ func (p *program) run() { outer: for { select { + case <-p.confWatcher.Watch(): + err := p.reloadConf() + if err != nil { + p.Log("ERR: %s", err) + break outer + } + case <-p.terminate: break outer } @@ -146,7 +113,93 @@ outer: p.closeResources() } +func (p *program) createResources(initial bool) error { + var err error + + if p.stats == nil { + p.stats = stats.New() + } + + if p.logHandler == nil { + p.logHandler, err = loghandler.New(p.conf.LogDestinationsParsed, p.conf.LogFile) + if err != nil { + return err + } + } + + if initial { + p.Log("rtsp-simple-server %s", Version) + } + + if p.conf.Metrics { + if p.metrics == nil { + p.metrics, err = metrics.New(p.stats, p) + if err != nil { + return err + } + } + } + + if p.conf.Pprof { + if p.pprof == nil { + p.pprof, err = pprof.New(p) + if err != nil { + return err + } + } + } + + if _, ok := p.conf.ProtocolsParsed[gortsplib.StreamProtocolUDP]; ok { + if p.serverUdpRtp == nil { + p.serverUdpRtp, err = serverudp.New(p.conf.WriteTimeout, + p.conf.RtpPort, gortsplib.StreamTypeRtp, p) + if err != nil { + return err + } + } + + if p.serverUdpRtcp == nil { + p.serverUdpRtcp, err = serverudp.New(p.conf.WriteTimeout, + p.conf.RtcpPort, gortsplib.StreamTypeRtcp, p) + if err != nil { + return err + } + } + } + + if p.serverTcp == nil { + p.serverTcp, err = servertcp.New(p.conf.RtspPort, p) + if err != nil { + return err + } + } + + if p.pathMan == nil { + p.pathMan = pathman.New(p.conf.ReadTimeout, p.conf.WriteTimeout, + p.conf.AuthMethodsParsed, p.conf.Paths, p.stats, p) + } + + if p.clientMan == nil { + p.clientMan = clientman.New(p.conf.ReadTimeout, p.conf.WriteTimeout, + p.conf.RunOnConnect, p.conf.ProtocolsParsed, p.stats, + p.serverUdpRtp, p.serverUdpRtcp, p.pathMan, p.serverTcp, p) + } + + if p.confWatcher == nil { + p.confWatcher, err = confwatcher.New(p.confPath) + if err != nil { + return err + } + } + + return nil +} + func (p *program) closeResources() { + if p.confWatcher != nil { + p.confWatcher.Close() + } + if p.clientMan != nil { p.clientMan.Close() } @@ -180,16 +233,121 @@ func (p *program) closeResources() { } } -func (p *program) close() { - close(p.terminate) - <-p.done +func (p *program) reloadConf() error { + p.Log("reloading configuration") + + conf, err := conf.Load(p.confPath) + if err != nil { + return err + } + + // always recreate confWatcher to avoid reloading twice + p.confWatcher.Close() + p.confWatcher = nil + + closeLogHandler := false + if !reflect.DeepEqual(conf.LogDestinationsParsed, p.conf.LogDestinationsParsed) || + conf.LogFile != p.conf.LogFile { + closeLogHandler = true + } + + closeMetrics := false + if conf.Metrics != p.conf.Metrics { + closeMetrics = true + } + + closePprof := false + if conf.Pprof != p.conf.Pprof { + closePprof = true + } + + closeServerUdpRtp := false + if conf.WriteTimeout != p.conf.WriteTimeout || + conf.RtpPort != p.conf.RtpPort { + closeServerUdpRtp = true + } + + closeServerUdpRtcp := false + if conf.WriteTimeout != p.conf.WriteTimeout || + conf.RtcpPort != p.conf.RtcpPort { + closeServerUdpRtcp = true + } + + closeServerTcp := false + if conf.RtspPort != p.conf.RtspPort { + closeServerTcp = true + } + + closePathMan := false + if conf.ReadTimeout != p.conf.ReadTimeout || + conf.WriteTimeout != p.conf.WriteTimeout || + !reflect.DeepEqual(conf.AuthMethodsParsed, p.conf.AuthMethodsParsed) { + closePathMan = true + } else if !reflect.DeepEqual(conf.Paths, p.conf.Paths) { + p.pathMan.OnProgramConfReload(conf.Paths) + } + + closeClientMan := false + if closeServerUdpRtp || + closeServerUdpRtcp || + closeServerTcp || + closePathMan || + conf.ReadTimeout != p.conf.ReadTimeout || + conf.WriteTimeout != p.conf.WriteTimeout || + conf.RunOnConnect != p.conf.RunOnConnect || + !reflect.DeepEqual(conf.ProtocolsParsed, p.conf.ProtocolsParsed) { + closeClientMan = true + } + + if closeClientMan { + p.clientMan.Close() + p.clientMan = nil + } + + if closePathMan { + p.pathMan.Close() + p.pathMan = nil + } + + if closeServerTcp { + p.serverTcp.Close() + p.serverTcp = nil + } + + if closeServerUdpRtcp && p.serverUdpRtcp != nil { + p.serverUdpRtcp.Close() + p.serverUdpRtcp = nil + } + + if closeServerUdpRtp && p.serverUdpRtp != nil { + p.serverUdpRtp.Close() + p.serverUdpRtp = nil + } + + if closePprof && p.pprof != nil { + p.pprof.Close() + p.pprof = nil + } + + if closeMetrics && p.metrics != nil { + p.metrics.Close() + p.metrics = nil + } + + if closeLogHandler { + p.logHandler.Close() + p.logHandler = nil + } + + p.conf = conf + return p.createResources(false) } func main() { - _, err := newProgram(os.Args[1:]) + p, err := newProgram(os.Args[1:]) if err != nil { log.Fatal("ERR: ", err) } - select {} + <-p.done } diff --git a/main_test.go b/main_test.go index 441c25ea..fa32c107 100644 --- a/main_test.go +++ b/main_test.go @@ -5,6 +5,7 @@ import ( "net" "os" "os/exec" + "path/filepath" "regexp" "strconv" "testing" @@ -660,3 +661,69 @@ func TestRunOnDemand(t *testing.T) { code := cnt1.wait() require.Equal(t, 0, code) } + +func TestHotReloading(t *testing.T) { + confPath := filepath.Join(os.TempDir(), "rtsp-conf") + + err := ioutil.WriteFile(confPath, []byte("paths:\n"+ + " test1:\n"+ + " runOnDemand: ffmpeg -hide_banner -loglevel error -re -i testimages/ffmpeg/emptyvideo.ts -c copy -f rtsp rtsp://localhost:8554/$RTSP_SERVER_PATH\n"), + 0644) + require.NoError(t, err) + + p, err := newProgram([]string{confPath}) + require.NoError(t, err) + defer p.close() + + time.Sleep(1 * time.Second) + + func() { + cnt1, err := newContainer("ffmpeg", "dest", []string{ + "-i", "rtsp://" + ownDockerIp + ":8554/test1", + "-vframes", "1", + "-f", "image2", + "-y", "/dev/null", + }) + require.NoError(t, err) + defer cnt1.close() + + code := cnt1.wait() + require.Equal(t, 0, code) + }() + + err = ioutil.WriteFile(confPath, []byte("paths:\n"+ + " test2:\n"+ + " runOnDemand: ffmpeg -hide_banner -loglevel error -re -i testimages/ffmpeg/emptyvideo.ts -c copy -f rtsp rtsp://localhost:8554/$RTSP_SERVER_PATH\n"), + 0644) + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + func() { + cnt1, err := newContainer("ffmpeg", "dest", []string{ + "-i", "rtsp://" + ownDockerIp + ":8554/test1", + "-vframes", "1", + "-f", "image2", + "-y", "/dev/null", + }) + require.NoError(t, err) + defer cnt1.close() + + code := cnt1.wait() + require.Equal(t, 1, code) + }() + + func() { + cnt1, err := newContainer("ffmpeg", "dest", []string{ + "-i", "rtsp://" + ownDockerIp + ":8554/test2", + "-vframes", "1", + "-f", "image2", + "-y", "/dev/null", + }) + require.NoError(t, err) + defer cnt1.close() + + code := cnt1.wait() + require.Equal(t, 0, code) + }() +} diff --git a/path/path.go b/path/path.go index 2ece79f4..af246bb5 100644 --- a/path/path.go +++ b/path/path.go @@ -13,7 +13,6 @@ import ( "github.com/aler9/rtsp-simple-server/client" "github.com/aler9/rtsp-simple-server/conf" "github.com/aler9/rtsp-simple-server/externalcmd" - "github.com/aler9/rtsp-simple-server/serverudp" "github.com/aler9/rtsp-simple-server/sourcertmp" "github.com/aler9/rtsp-simple-server/sourcertsp" "github.com/aler9/rtsp-simple-server/stats" @@ -98,20 +97,21 @@ const ( clientStatePlay clientStatePreRecord clientStateRecord + clientStatePreRemove ) type Path struct { - wg *sync.WaitGroup - stats *stats.Stats - serverUdpRtp *serverudp.Server - serverUdpRtcp *serverudp.Server - readTimeout time.Duration - writeTimeout time.Duration - name string - conf *conf.PathConf - parent Parent + readTimeout time.Duration + writeTimeout time.Duration + confName string + conf *conf.PathConf + name string + wg *sync.WaitGroup + stats *stats.Stats + parent Parent clients map[*client.Client]clientState + clientsWg sync.WaitGroup source source sourceReady bool sourceTrackCount int @@ -135,25 +135,23 @@ type Path struct { } func New( - wg *sync.WaitGroup, - stats *stats.Stats, - serverUdpRtp *serverudp.Server, - serverUdpRtcp *serverudp.Server, readTimeout time.Duration, writeTimeout time.Duration, - name string, + confName string, conf *conf.PathConf, + name string, + wg *sync.WaitGroup, + stats *stats.Stats, parent Parent) *Path { pa := &Path{ - wg: wg, - stats: stats, - serverUdpRtp: serverUdpRtp, - serverUdpRtcp: serverUdpRtcp, readTimeout: readTimeout, writeTimeout: writeTimeout, - name: name, + confName: confName, conf: conf, + name: name, + wg: wg, + stats: stats, parent: parent, clients: make(map[*client.Client]clientState), readers: newReadersMap(), @@ -241,6 +239,7 @@ outer: case <-tickerCheck.C: ok := pa.onCheck() if !ok { + pa.exhaustChannels() pa.parent.OnPathClose(pa) <-pa.terminate break outer @@ -253,8 +252,14 @@ outer: pa.onSourceSetNotReady() case req := <-pa.clientDescribe: + if _, ok := pa.clients[req.Client]; ok { + req.Res <- ClientDescribeRes{nil, fmt.Errorf("already subscribed")} + continue + } + // reply immediately req.Res <- ClientDescribeRes{pa, nil} + pa.onClientDescribe(req.Client) case req := <-pa.clientSetupPlay: @@ -266,9 +271,7 @@ outer: req.Res <- ClientSetupPlayRes{pa, nil} case req := <-pa.clientPlay: - if _, ok := pa.clients[req.client]; ok { - pa.onClientPlay(req.client) - } + pa.onClientPlay(req.client) close(req.res) case req := <-pa.clientAnnounce: @@ -280,22 +283,74 @@ outer: req.Res <- ClientAnnounceRes{pa, nil} case req := <-pa.clientRecord: - if _, ok := pa.clients[req.client]; ok { - pa.onClientRecord(req.client) - } + pa.onClientRecord(req.client) close(req.res) case req := <-pa.clientRemove: - if _, ok := pa.clients[req.client]; ok { - pa.onClientRemove(req.client) + if _, ok := pa.clients[req.client]; !ok { + close(req.res) + continue } + + if pa.clients[req.client] != clientStatePreRemove { + pa.onClientPreRemove(req.client) + } + + delete(pa.clients, req.client) + pa.clientsWg.Done() + close(req.res) case <-pa.terminate: + pa.exhaustChannels() break outer } } + if pa.onInitCmd != nil { + pa.Log("stopping on init command (closing)") + pa.onInitCmd.Close() + } + + if source, ok := pa.source.(*sourcertsp.Source); ok { + source.Close() + + } else if source, ok := pa.source.(*sourcertmp.Source); ok { + source.Close() + } + + if pa.onDemandCmd != nil { + pa.Log("stopping on demand command (closing)") + pa.onDemandCmd.Close() + } + + for c, state := range pa.clients { + if state != clientStatePreRemove { + switch state { + case clientStatePlay: + atomic.AddInt64(pa.stats.CountReaders, -1) + pa.readers.remove(c) + + case clientStateRecord: + atomic.AddInt64(pa.stats.CountPublishers, -1) + } + + pa.parent.OnPathClientClose(c) + } + } + pa.clientsWg.Wait() + + close(pa.sourceSetReady) + close(pa.sourceSetNotReady) + close(pa.clientDescribe) + close(pa.clientAnnounce) + close(pa.clientSetupPlay) + close(pa.clientPlay) + close(pa.clientRecord) + close(pa.clientRemove) +} + +func (pa *Path) exhaustChannels() { go func() { for { select { @@ -343,50 +398,27 @@ outer: if !ok { return } - close(req.res) - } - } - }() - if pa.onInitCmd != nil { - pa.Log("stopping on init command (closing)") - pa.onInitCmd.Close() - } - - if source, ok := pa.source.(*sourcertsp.Source); ok { - source.Close() - - } else if source, ok := pa.source.(*sourcertmp.Source); ok { - source.Close() - } + if _, ok := pa.clients[req.client]; !ok { + close(req.res) + continue + } - if pa.onDemandCmd != nil { - pa.Log("stopping on demand command (closing)") - pa.onDemandCmd.Close() - } + pa.clientsWg.Done() - for c, state := range pa.clients { - if state == clientStateWaitingDescribe { - delete(pa.clients, c) - c.OnPathDescribeData(nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)) - } else { - pa.onClientRemove(c) - pa.parent.OnPathClientClose(c) + close(req.res) + } } - } - - close(pa.sourceSetReady) - close(pa.sourceSetNotReady) - close(pa.clientDescribe) - close(pa.clientAnnounce) - close(pa.clientSetupPlay) - close(pa.clientPlay) - close(pa.clientRecord) - close(pa.clientRemove) + }() } func (pa *Path) hasClients() bool { - return len(pa.clients) > 0 + for _, state := range pa.clients { + if state != clientStatePreRemove { + return true + } + } + return false } func (pa *Path) hasClientsWaitingDescribe() bool { @@ -399,8 +431,8 @@ func (pa *Path) hasClientsWaitingDescribe() bool { } func (pa *Path) hasClientReadersOrWaitingDescribe() bool { - for c := range pa.clients { - if c != pa.source { + for c, state := range pa.clients { + if state != clientStatePreRemove && c != pa.source { return true } } @@ -412,8 +444,8 @@ func (pa *Path) onCheck() bool { if pa.hasClientsWaitingDescribe() && time.Since(pa.lastDescribeActivation) >= describeTimeout { for c, state := range pa.clients { - if state == clientStateWaitingDescribe { - delete(pa.clients, c) + if state != clientStatePreRemove && state == clientStateWaitingDescribe { + pa.clients[c] = clientStatePreRemove c.OnPathDescribeData(nil, fmt.Errorf("publisher of path '%s' has timed out", pa.name)) } } @@ -451,9 +483,10 @@ func (pa *Path) onCheck() bool { pa.onDemandCmd = nil } - // remove path if is regexp and has no clients + // remove path if is regexp, has no source, has no on-demand command and has no clients if pa.conf.Regexp != nil && pa.source == nil && + pa.onDemandCmd == nil && !pa.hasClients() { return false } @@ -467,7 +500,7 @@ func (pa *Path) onSourceSetReady() { // reply to all clients that are waiting for a description for c, state := range pa.clients { if state == clientStateWaitingDescribe { - delete(pa.clients, c) + pa.clients[c] = clientStatePreRemove c.OnPathDescribeData(pa.sourceSdp, nil) } } @@ -478,8 +511,8 @@ func (pa *Path) onSourceSetNotReady() { // close all clients that are reading or waiting to read for c, state := range pa.clients { - if state != clientStateWaitingDescribe && c != pa.source { - pa.onClientRemove(c) + if state != clientStatePreRemove && state != clientStateWaitingDescribe && c != pa.source { + pa.onClientPreRemove(c) pa.parent.OnPathClientClose(c) } } @@ -504,10 +537,14 @@ func (pa *Path) onClientDescribe(c *client.Client) { } pa.clients[c] = clientStateWaitingDescribe + pa.clientsWg.Add(1) // no on-demand: reply with 404 } else { - c.OnPathDescribeData(nil, fmt.Errorf("no one is publishing on path '%s'", pa.name)) + pa.clients[c] = clientStatePreRemove + pa.clientsWg.Add(1) + + c.OnPathDescribeData(nil, fmt.Errorf("no one is publishing to path '%s'", pa.name)) } // publisher was found but is not ready: put the client on hold @@ -532,38 +569,61 @@ func (pa *Path) onClientDescribe(c *client.Client) { } pa.clients[c] = clientStateWaitingDescribe + pa.clientsWg.Add(1) // publisher was found and is ready } else { + pa.clients[c] = clientStatePreRemove + pa.clientsWg.Add(1) + c.OnPathDescribeData(pa.sourceSdp, nil) } } func (pa *Path) onClientSetupPlay(c *client.Client, trackId int) error { if !pa.sourceReady { - return fmt.Errorf("no one is publishing on path '%s'", pa.name) + return fmt.Errorf("no one is publishing to path '%s'", pa.name) } if trackId >= pa.sourceTrackCount { return fmt.Errorf("track %d does not exist", trackId) } - pa.clients[c] = clientStatePrePlay + if _, ok := pa.clients[c]; !ok { + pa.clients[c] = clientStatePrePlay + pa.clientsWg.Add(1) + } + return nil } func (pa *Path) onClientPlay(c *client.Client) { + state, ok := pa.clients[c] + if !ok { + return + } + + if state != clientStatePrePlay { + return + } + atomic.AddInt64(pa.stats.CountReaders, 1) pa.clients[c] = clientStatePlay pa.readers.add(c) } func (pa *Path) onClientAnnounce(c *client.Client, tracks gortsplib.Tracks) error { + if _, ok := pa.clients[c]; ok { + return fmt.Errorf("already subscribed") + } + if pa.source != nil { - return fmt.Errorf("someone is already publishing on path '%s'", pa.name) + return fmt.Errorf("someone is already publishing to path '%s'", pa.name) } pa.clients[c] = clientStatePreRecord + pa.clientsWg.Add(1) + pa.source = c pa.sourceTrackCount = len(tracks) pa.sourceSdp = tracks.Write() @@ -571,14 +631,24 @@ func (pa *Path) onClientAnnounce(c *client.Client, tracks gortsplib.Tracks) erro } func (pa *Path) onClientRecord(c *client.Client) { + state, ok := pa.clients[c] + if !ok { + return + } + + if state != clientStatePreRecord { + return + } + atomic.AddInt64(pa.stats.CountPublishers, 1) pa.clients[c] = clientStateRecord + pa.onSourceSetReady() } -func (pa *Path) onClientRemove(c *client.Client) { +func (pa *Path) onClientPreRemove(c *client.Client) { state := pa.clients[c] - delete(pa.clients, c) + pa.clients[c] = clientStatePreRemove switch state { case clientStatePlay: @@ -595,8 +665,8 @@ func (pa *Path) onClientRemove(c *client.Client) { // close all clients that are reading or waiting to read for oc, state := range pa.clients { - if state != clientStateWaitingDescribe && oc != pa.source { - pa.onClientRemove(oc) + if state != clientStatePreRemove && state != clientStateWaitingDescribe && oc != pa.source { + pa.onClientPreRemove(oc) pa.parent.OnPathClientClose(oc) } } @@ -613,6 +683,14 @@ func (pa *Path) OnSourceNotReady() { pa.sourceSetNotReady <- struct{}{} } +func (pa *Path) ConfName() string { + return pa.confName +} + +func (pa *Path) Conf() *conf.PathConf { + return pa.conf +} + func (pa *Path) Name() string { return pa.name } @@ -621,10 +699,6 @@ func (pa *Path) SourceTrackCount() int { return pa.sourceTrackCount } -func (pa *Path) Conf() *conf.PathConf { - return pa.conf -} - func (pa *Path) OnPathManDescribe(req ClientDescribeReq) { pa.clientDescribe <- req } diff --git a/pathman/pathman.go b/pathman/pathman.go index 7e0ecb75..d4317f84 100644 --- a/pathman/pathman.go +++ b/pathman/pathman.go @@ -12,7 +12,6 @@ import ( "github.com/aler9/rtsp-simple-server/client" "github.com/aler9/rtsp-simple-server/conf" "github.com/aler9/rtsp-simple-server/path" - "github.com/aler9/rtsp-simple-server/serverudp" "github.com/aler9/rtsp-simple-server/stats" ) @@ -21,19 +20,18 @@ type Parent interface { } type PathManager struct { - stats *stats.Stats - serverUdpRtp *serverudp.Server - serverUdpRtcp *serverudp.Server - readTimeout time.Duration - writeTimeout time.Duration - authMethods []headers.AuthMethod - confPaths map[string]*conf.PathConf - parent Parent + readTimeout time.Duration + writeTimeout time.Duration + authMethods []headers.AuthMethod + pathConfs map[string]*conf.PathConf + stats *stats.Stats + parent Parent paths map[string]*path.Path wg sync.WaitGroup // in + confReload chan map[string]*conf.PathConf pathClose chan *path.Path clientDescribe chan path.ClientDescribeReq clientAnnounce chan path.ClientAnnounceReq @@ -45,25 +43,23 @@ type PathManager struct { done chan struct{} } -func New(stats *stats.Stats, - serverUdpRtp *serverudp.Server, - serverUdpRtcp *serverudp.Server, +func New( readTimeout time.Duration, writeTimeout time.Duration, authMethods []headers.AuthMethod, - confPaths map[string]*conf.PathConf, + pathConfs map[string]*conf.PathConf, + stats *stats.Stats, parent Parent) *PathManager { pm := &PathManager{ - stats: stats, - serverUdpRtp: serverUdpRtp, - serverUdpRtcp: serverUdpRtcp, readTimeout: readTimeout, writeTimeout: writeTimeout, authMethods: authMethods, - confPaths: confPaths, + pathConfs: pathConfs, + stats: stats, parent: parent, paths: make(map[string]*path.Path), + confReload: make(chan map[string]*conf.PathConf), pathClose: make(chan *path.Path), clientDescribe: make(chan path.ClientDescribeReq), clientAnnounce: make(chan path.ClientAnnounceReq), @@ -73,13 +69,7 @@ func New(stats *stats.Stats, done: make(chan struct{}), } - for name, pathConf := range confPaths { - if pathConf.Regexp == nil { - pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, - pm.readTimeout, pm.writeTimeout, name, pathConf, pm) - pm.paths[name] = pa - } - } + pm.createPaths() go pm.run() return pm @@ -104,12 +94,53 @@ func (pm *PathManager) run() { outer: for { select { + case pathConfs := <-pm.confReload: + // remove confs + for pathName := range pm.pathConfs { + if _, ok := pathConfs[pathName]; !ok { + delete(pm.pathConfs, pathName) + } + } + + // update confs + for pathName, oldConf := range pm.pathConfs { + if !oldConf.Equal(pathConfs[pathName]) { + pm.pathConfs[pathName] = pathConfs[pathName] + } + } + + // add confs + for pathName, pathConf := range pathConfs { + if _, ok := pm.pathConfs[pathName]; !ok { + pm.pathConfs[pathName] = pathConf + } + } + + // remove paths associated with a conf which doesn't exist anymore + // or has changed + for _, pa := range pm.paths { + if pathConf, ok := pm.pathConfs[pa.ConfName()]; !ok { + delete(pm.paths, pa.Name()) + pa.Close() + + } else if pathConf != pa.Conf() { + delete(pm.paths, pa.Name()) + pa.Close() + } + } + + // add paths + pm.createPaths() + case pa := <-pm.pathClose: + if _, ok := pm.paths[pa.Name()]; !ok { + continue + } delete(pm.paths, pa.Name()) pa.Close() case req := <-pm.clientDescribe: - pathConf, err := pm.findPathConf(req.PathName) + pathName, pathConf, err := pm.findPathConf(req.PathName) if err != nil { req.Res <- path.ClientDescribeRes{nil, err} continue @@ -124,15 +155,16 @@ outer: // create path if it doesn't exist if _, ok := pm.paths[req.PathName]; !ok { - pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, - pm.readTimeout, pm.writeTimeout, req.PathName, pathConf, pm) + pa := path.New( + pm.readTimeout, pm.writeTimeout, pathName, pathConf, req.PathName, + &pm.wg, pm.stats, pm) pm.paths[req.PathName] = pa } pm.paths[req.PathName].OnPathManDescribe(req) case req := <-pm.clientAnnounce: - pathConf, err := pm.findPathConf(req.PathName) + pathName, pathConf, err := pm.findPathConf(req.PathName) if err != nil { req.Res <- path.ClientAnnounceRes{nil, err} continue @@ -147,8 +179,9 @@ outer: // create path if it doesn't exist if _, ok := pm.paths[req.PathName]; !ok { - pa := path.New(&pm.wg, pm.stats, pm.serverUdpRtp, pm.serverUdpRtcp, - pm.readTimeout, pm.writeTimeout, req.PathName, pathConf, pm) + pa := path.New( + pm.readTimeout, pm.writeTimeout, pathName, pathConf, req.PathName, + &pm.wg, pm.stats, pm) pm.paths[req.PathName] = pa } @@ -156,11 +189,11 @@ outer: case req := <-pm.clientSetupPlay: if _, ok := pm.paths[req.PathName]; !ok { - req.Res <- path.ClientSetupPlayRes{nil, fmt.Errorf("no one is publishing on path '%s'", req.PathName)} + req.Res <- path.ClientSetupPlayRes{nil, fmt.Errorf("no one is publishing to path '%s'", req.PathName)} continue } - pathConf, err := pm.findPathConf(req.PathName) + _, pathConf, err := pm.findPathConf(req.PathName) if err != nil { req.Res <- path.ClientSetupPlayRes{nil, err} continue @@ -183,6 +216,11 @@ outer: go func() { for { select { + case _, ok := <-pm.confReload: + if !ok { + return + } + case _, ok := <-pm.pathClose: if !ok { return @@ -205,6 +243,7 @@ outer: } pm.wg.Wait() + close(pm.confReload) close(pm.clientClose) close(pm.pathClose) close(pm.clientDescribe) @@ -212,25 +251,40 @@ outer: close(pm.clientSetupPlay) } -func (pm *PathManager) findPathConf(name string) (*conf.PathConf, error) { +func (pm *PathManager) createPaths() { + for pathName, pathConf := range pm.pathConfs { + if pathConf.Regexp == nil { + pa := path.New( + pm.readTimeout, pm.writeTimeout, pathName, pathConf, pathName, + &pm.wg, pm.stats, pm) + pm.paths[pathName] = pa + } + } +} + +func (pm *PathManager) findPathConf(name string) (string, *conf.PathConf, error) { err := conf.CheckPathName(name) if err != nil { - return nil, fmt.Errorf("invalid path name: %s (%s)", err, name) + return "", nil, fmt.Errorf("invalid path name: %s (%s)", err, name) } // normal path - if pathConf, ok := pm.confPaths[name]; ok { - return pathConf, nil + if pathConf, ok := pm.pathConfs[name]; ok { + return name, pathConf, nil } // regular expression path - for _, pathConf := range pm.confPaths { + for pathName, pathConf := range pm.pathConfs { if pathConf.Regexp != nil && pathConf.Regexp.MatchString(name) { - return pathConf, nil + return pathName, pathConf, nil } } - return nil, fmt.Errorf("unable to find a valid configuration for path '%s'", name) + return "", nil, fmt.Errorf("unable to find a valid configuration for path '%s'", name) +} + +func (pm *PathManager) OnProgramConfReload(pathConfs map[string]*conf.PathConf) { + pm.confReload <- pathConfs } func (pm *PathManager) OnPathClose(pa *path.Path) { diff --git a/serverudp/server.go b/serverudp/server.go index 29fefb51..bc6bf0f7 100644 --- a/serverudp/server.go +++ b/serverudp/server.go @@ -122,12 +122,20 @@ func (s *Server) run() { break } - pub := s.getPublisher(addr.IP, addr.Port) - if pub == nil { - continue - } - - pub.publisher.OnUdpPublisherFrame(pub.trackId, s.streamType, buf[:n]) + func() { + s.publishersMutex.RLock() + defer s.publishersMutex.RUnlock() + + // find publisher data + var pubAddr publisherAddr + pubAddr.fill(addr.IP, addr.Port) + pubData, ok := s.publishers[pubAddr] + if !ok { + return + } + + pubData.publisher.OnUdpPublisherFrame(pubData.trackId, s.streamType, buf[:n]) + }() } close(s.write) @@ -164,17 +172,3 @@ func (s *Server) RemovePublisher(ip net.IP, port int, publisher Publisher) { delete(s.publishers, addr) } - -func (s *Server) getPublisher(ip net.IP, port int) *publisherData { - s.publishersMutex.RLock() - defer s.publishersMutex.RUnlock() - - var addr publisherAddr - addr.fill(ip, port) - - el, ok := s.publishers[addr] - if !ok { - return nil - } - return el -}