Browse Source

webrtc muxer: fix multiple race conditions

pull/1377/head
aler9 3 years ago
parent
commit
f3f55452e5
  1. 6
      internal/core/api.go
  2. 25
      internal/core/http_requestpool.go
  3. 6
      internal/core/metrics.go
  4. 4
      internal/core/pprof.go
  5. 19
      internal/core/webrtc_conn.go
  6. 50
      internal/core/webrtc_server.go
  7. 32
      internal/core/webrtc_server_test.go

6
internal/core/api.go

@ -3,6 +3,7 @@ package core @@ -3,6 +3,7 @@ package core
import (
"context"
"encoding/json"
"log"
"net"
"net/http"
"reflect"
@ -201,7 +202,10 @@ func newAPI( @@ -201,7 +202,10 @@ func newAPI(
group.POST("/v1/webrtcconns/kick/:id", a.onWebRTCConnsKick)
}
a.s = &http.Server{Handler: router}
a.s = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}
go a.s.Serve(ln)

25
internal/core/http_requestpool.go

@ -0,0 +1,25 @@ @@ -0,0 +1,25 @@
package core
import (
"sync"
"github.com/gin-gonic/gin"
)
type httpRequestPool struct {
wg sync.WaitGroup
}
func newHTTPRequestPool() *httpRequestPool {
return &httpRequestPool{}
}
func (rp *httpRequestPool) mw(ctx *gin.Context) {
rp.wg.Add(1)
ctx.Next()
rp.wg.Done()
}
func (rp *httpRequestPool) close() {
rp.wg.Wait()
}

6
internal/core/metrics.go

@ -3,6 +3,7 @@ package core @@ -3,6 +3,7 @@ package core
import (
"context"
"io"
"log"
"net"
"net/http"
"strconv"
@ -53,7 +54,10 @@ func newMetrics( @@ -53,7 +54,10 @@ func newMetrics(
router.SetTrustedProxies(nil)
router.GET("/metrics", m.onMetrics)
m.server = &http.Server{Handler: router}
m.server = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}
m.log(logger.Info, "listener opened on "+address)

4
internal/core/pprof.go

@ -2,6 +2,7 @@ package core @@ -2,6 +2,7 @@ package core
import (
"context"
"log"
"net"
"net/http"
@ -37,7 +38,8 @@ func newPPROF( @@ -37,7 +38,8 @@ func newPPROF(
}
pp.server = &http.Server{
Handler: http.DefaultServeMux,
Handler: http.DefaultServeMux,
ErrorLog: log.New(&nilWriter{}, "", 0),
}
pp.log(logger.Info, "listener opened on "+address)

19
internal/core/webrtc_conn.go

@ -106,6 +106,8 @@ type webRTCConn struct { @@ -106,6 +106,8 @@ type webRTCConn struct {
created time.Time
curPC *webrtc.PeerConnection
mutex sync.RWMutex
closed chan struct{}
}
func newWebRTCConn(
@ -138,6 +140,7 @@ func newWebRTCConn( @@ -138,6 +140,7 @@ func newWebRTCConn(
iceUDPMux: iceUDPMux,
iceTCPMux: iceTCPMux,
iceHostNAT1To1IPs: iceHostNAT1To1IPs,
closed: make(chan struct{}),
}
c.log(logger.Info, "opened")
@ -152,6 +155,10 @@ func (c *webRTCConn) close() { @@ -152,6 +155,10 @@ func (c *webRTCConn) close() {
c.ctxCancel()
}
func (c *webRTCConn) wait() {
<-c.closed
}
func (c *webRTCConn) remoteAddr() net.Addr {
return c.wsconn.RemoteAddr()
}
@ -250,6 +257,7 @@ func (c *webRTCConn) log(level logger.Level, format string, args ...interface{}) @@ -250,6 +257,7 @@ func (c *webRTCConn) log(level logger.Level, format string, args ...interface{})
}
func (c *webRTCConn) run() {
defer close(c.closed)
defer c.wg.Done()
innerCtx, innerCtxCancel := context.WithCancel(c.ctx)
@ -277,11 +285,6 @@ func (c *webRTCConn) run() { @@ -277,11 +285,6 @@ func (c *webRTCConn) run() {
}
func (c *webRTCConn) runInner(ctx context.Context) error {
go func() {
<-ctx.Done()
c.wsconn.Close()
}()
res := c.pathManager.readerAdd(pathReaderAddReq{
author: c,
pathName: c.pathName,
@ -348,6 +351,12 @@ func (c *webRTCConn) runInner(ctx context.Context) error { @@ -348,6 +351,12 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
pcClosed := make(chan struct{})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
select {
case <-pcClosed:
return
default:
}
c.log(logger.Debug, "peer connection state: "+state.String())
switch state {

50
internal/core/webrtc_server.go

@ -66,6 +66,7 @@ type webRTCServerAPIConnsKickReq struct { @@ -66,6 +66,7 @@ type webRTCServerAPIConnsKickReq struct {
type webRTCConnNewReq struct {
pathName string
wsconn *websocket.Conn
res chan *webRTCConn
}
type webRTCServerParent interface {
@ -84,7 +85,6 @@ type webRTCServer struct { @@ -84,7 +85,6 @@ type webRTCServer struct {
ctx context.Context
ctxCancel func()
wg sync.WaitGroup
ln net.Listener
udpMuxLn net.PacketConn
tcpMuxLn net.Listener
@ -99,6 +99,9 @@ type webRTCServer struct { @@ -99,6 +99,9 @@ type webRTCServer struct {
chConnClose chan *webRTCConn
chAPIConnsList chan webRTCServerAPIConnsListReq
chAPIConnsKick chan webRTCServerAPIConnsKickReq
// out
done chan struct{}
}
func newWebRTCServer(
@ -182,6 +185,7 @@ func newWebRTCServer( @@ -182,6 +185,7 @@ func newWebRTCServer(
chConnClose: make(chan *webRTCConn),
chAPIConnsList: make(chan webRTCServerAPIConnsListReq),
chAPIConnsKick: make(chan webRTCServerAPIConnsKickReq),
done: make(chan struct{}),
}
str := "listener opened on " + address + " (HTTP)"
@ -197,7 +201,6 @@ func newWebRTCServer( @@ -197,7 +201,6 @@ func newWebRTCServer(
s.metrics.webRTCServerSet(s)
}
s.wg.Add(1)
go s.run()
return s, nil
@ -211,14 +214,17 @@ func (s *webRTCServer) log(level logger.Level, format string, args ...interface{ @@ -211,14 +214,17 @@ func (s *webRTCServer) log(level logger.Level, format string, args ...interface{
func (s *webRTCServer) close() {
s.log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
<-s.done
}
func (s *webRTCServer) run() {
defer s.wg.Done()
defer close(s.done)
rp := newHTTPRequestPool()
defer rp.close()
router := gin.New()
router.NoRoute(httpLoggerMiddleware(s), s.onRequest)
router.NoRoute(rp.mw, httpLoggerMiddleware(s), s.onRequest)
tmp := make([]string, len(s.trustedProxies))
for i, entry := range s.trustedProxies {
@ -238,6 +244,8 @@ func (s *webRTCServer) run() { @@ -238,6 +244,8 @@ func (s *webRTCServer) run() {
go hs.Serve(s.ln)
}
var wg sync.WaitGroup
outer:
for {
select {
@ -248,7 +256,7 @@ outer: @@ -248,7 +256,7 @@ outer:
req.pathName,
req.wsconn,
s.stunServers,
&s.wg,
&wg,
s.pathManager,
s,
s.iceHostNAT1To1IPs,
@ -256,6 +264,7 @@ outer: @@ -256,6 +264,7 @@ outer:
s.iceTCPMux,
)
s.conns[c] = struct{}{}
req.res <- c
case conn := <-s.chConnClose:
delete(s.conns, conn)
@ -306,6 +315,8 @@ outer: @@ -306,6 +315,8 @@ outer:
hs.Shutdown(context.Background())
s.ln.Close() // in case Shutdown() is called before Serve()
wg.Wait()
if s.udpMuxLn != nil {
s.udpMuxLn.Close()
}
@ -389,14 +400,29 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) { @@ -389,14 +400,29 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
if err != nil {
return
}
defer wsconn.Close()
select {
case s.connNew <- webRTCConnNewReq{
pathName: dir,
wsconn: wsconn,
}:
case <-s.ctx.Done():
c := s.newConn(dir, wsconn)
if c == nil {
return
}
c.wait()
}
}
func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn {
req := webRTCConnNewReq{
pathName: dir,
wsconn: wsconn,
res: make(chan *webRTCConn),
}
select {
case s.connNew <- req:
return <-req.res
case <-s.ctx.Done():
return nil
}
}

32
internal/core/webrtc_server_test.go

@ -15,16 +15,18 @@ import ( @@ -15,16 +15,18 @@ import (
)
type webRTCTestClient struct {
wc *websocket.Conn
pc *webrtc.PeerConnection
track chan *webrtc.TrackRemote
wc *websocket.Conn
pc *webrtc.PeerConnection
track chan *webrtc.TrackRemote
closed chan struct{}
}
func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
wc, _, err := websocket.DefaultDialer.Dial(addr, nil) //nolint:bodyclose
wc, res, err := websocket.DefaultDialer.Dial(addr, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
_, msg, err := wc.ReadMessage()
if err != nil {
@ -55,13 +57,25 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) { @@ -55,13 +57,25 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
})
connected := make(chan struct{})
closed := make(chan struct{})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
close(connected)
case webrtc.PeerConnectionStateClosed:
select {
case <-closed:
return
default:
}
close(closed)
}
})
track := make(chan *webrtc.TrackRemote, 1)
pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) {
track <- trak
})
@ -143,15 +157,17 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) { @@ -143,15 +157,17 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
<-connected
return &webRTCTestClient{
wc: wc,
pc: pc,
track: track,
wc: wc,
pc: pc,
track: track,
closed: closed,
}, nil
}
func (c *webRTCTestClient) close() {
c.pc.Close()
c.wc.Close()
<-c.closed
}
func TestWebRTCServer(t *testing.T) {

Loading…
Cancel
Save