diff --git a/main.go b/main.go index 47103468..37714b88 100644 --- a/main.go +++ b/main.go @@ -167,7 +167,7 @@ func (p *program) forwardTrack(path string, id int, flow trackFlow, frame []byte if c.path == path && c.state == _CLIENT_STATE_PLAY { if c.streamProtocol == _STREAM_PROTOCOL_UDP { if flow == _TRACK_FLOW_RTP { - p.rtpl.chanWrite <- &udpWrite{ + p.rtpl.write <- &udpWrite{ addr: &net.UDPAddr{ IP: c.ip(), Zone: c.zone(), @@ -176,7 +176,7 @@ func (p *program) forwardTrack(path string, id int, flow trackFlow, frame []byte buf: frame, } } else { - p.rtcpl.chanWrite <- &udpWrite{ + p.rtcpl.write <- &udpWrite{ addr: &net.UDPAddr{ IP: c.ip(), Zone: c.zone(), @@ -187,7 +187,7 @@ func (p *program) forwardTrack(path string, id int, flow trackFlow, frame []byte } } else { - c.chanWrite <- &gortsplib.InterleavedFrame{ + c.write <- &gortsplib.InterleavedFrame{ Channel: trackToInterleavedChannel(id, flow), Content: frame, } diff --git a/server-client.go b/server-client.go index cbd1addd..df24645a 100644 --- a/server-client.go +++ b/server-client.go @@ -113,7 +113,7 @@ type serverClient struct { streamSdpParsed *sdp.Message // filled only if publisher streamProtocol streamProtocol streamTracks []*track - chanWrite chan *gortsplib.InterleavedFrame + write chan *gortsplib.InterleavedFrame } func newServerClient(p *program, nconn net.Conn) *serverClient { @@ -124,8 +124,8 @@ func newServerClient(p *program, nconn net.Conn) *serverClient { ReadTimeout: _READ_TIMEOUT, WriteTimeout: _WRITE_TIMEOUT, }), - state: _CLIENT_STATE_STARTING, - chanWrite: make(chan *gortsplib.InterleavedFrame), + state: _CLIENT_STATE_STARTING, + write: make(chan *gortsplib.InterleavedFrame), } c.p.mutex.Lock() @@ -143,7 +143,7 @@ func (c *serverClient) close() error { delete(c.p.clients, c) c.conn.NetConn().Close() - close(c.chanWrite) + close(c.write) if c.path != "" { if pub, ok := c.p.publishers[c.path]; ok && pub == c { @@ -755,7 +755,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { if c.streamProtocol == _STREAM_PROTOCOL_TCP { // write RTP frames sequentially go func() { - for frame := range c.chanWrite { + for frame := range c.write { c.conn.WriteInterleavedFrame(frame) } }() diff --git a/server-udpl.go b/server-udpl.go index de79d872..1f1d470d 100644 --- a/server-udpl.go +++ b/server-udpl.go @@ -12,10 +12,11 @@ type udpWrite struct { } type serverUdpListener struct { - p *program - nconn *net.UDPConn - flow trackFlow - chanWrite chan *udpWrite + p *program + nconn *net.UDPConn + flow trackFlow + write chan *udpWrite + done chan struct{} } func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListener, error) { @@ -27,10 +28,11 @@ func newServerUdpListener(p *program, port int, flow trackFlow) (*serverUdpListe } l := &serverUdpListener{ - p: p, - nconn: nconn, - flow: flow, - chanWrite: make(chan *udpWrite), + p: p, + nconn: nconn, + flow: flow, + write: make(chan *udpWrite), + done: make(chan struct{}), } l.log("opened on :%d", port) @@ -49,56 +51,61 @@ func (l *serverUdpListener) log(format string, args ...interface{}) { func (l *serverUdpListener) run() { go func() { - for { - // create a buffer for each read. - // this is necessary since the buffer is propagated with channels - // so it must be unique. - buf := make([]byte, 2048) // UDP MTU is 1400 - n, addr, err := l.nconn.ReadFromUDP(buf) - if err != nil { - l.log("ERR: %s", err) - break - } + for w := range l.write { + l.nconn.SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT)) + l.nconn.WriteTo(w.buf, w.addr) + } + }() + + for { + // create a buffer for each read. + // this is necessary since the buffer is propagated with channels + // so it must be unique. + buf := make([]byte, 2048) // UDP MTU is 1400 + n, addr, err := l.nconn.ReadFromUDP(buf) + if err != nil { + break + } - func() { - l.p.mutex.RLock() - defer l.p.mutex.RUnlock() + func() { + l.p.mutex.RLock() + defer l.p.mutex.RUnlock() - // find path and track id from ip and port - path, trackId := func() (string, int) { - for _, pub := range l.p.publishers { - for i, t := range pub.streamTracks { - if !pub.ip().Equal(addr.IP) { - continue - } + // find path and track id from ip and port + path, trackId := func() (string, int) { + for _, pub := range l.p.publishers { + for i, t := range pub.streamTracks { + if !pub.ip().Equal(addr.IP) { + continue + } - if l.flow == _TRACK_FLOW_RTP { - if t.rtpPort == addr.Port { - return pub.path, i - } - } else { - if t.rtcpPort == addr.Port { - return pub.path, i - } + if l.flow == _TRACK_FLOW_RTP { + if t.rtpPort == addr.Port { + return pub.path, i + } + } else { + if t.rtcpPort == addr.Port { + return pub.path, i } } } - return "", -1 - }() - if path == "" { - return } - - l.p.forwardTrack(path, trackId, l.flow, buf[:n]) + return "", -1 }() - } - }() + if path == "" { + return + } - go func() { - for { - w := <-l.chanWrite - l.nconn.SetWriteDeadline(time.Now().Add(_WRITE_TIMEOUT)) - l.nconn.WriteTo(w.buf, w.addr) - } - }() + l.p.forwardTrack(path, trackId, l.flow, buf[:n]) + }() + } + + close(l.write) + + l.done <- struct{}{} +} + +func (l *serverUdpListener) close() { + l.nconn.Close() + <-l.done }