Browse Source

webrtc muxer: fix multiple race conditions

pull/1377/head
aler9 3 years ago
parent
commit
f3f55452e5
  1. 6
      internal/core/api.go
  2. 25
      internal/core/http_requestpool.go
  3. 6
      internal/core/metrics.go
  4. 4
      internal/core/pprof.go
  5. 19
      internal/core/webrtc_conn.go
  6. 50
      internal/core/webrtc_server.go
  7. 32
      internal/core/webrtc_server_test.go

6
internal/core/api.go

@ -3,6 +3,7 @@ package core
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log"
"net" "net"
"net/http" "net/http"
"reflect" "reflect"
@ -201,7 +202,10 @@ func newAPI(
group.POST("/v1/webrtcconns/kick/:id", a.onWebRTCConnsKick) group.POST("/v1/webrtcconns/kick/:id", a.onWebRTCConnsKick)
} }
a.s = &http.Server{Handler: router} a.s = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}
go a.s.Serve(ln) go a.s.Serve(ln)

25
internal/core/http_requestpool.go

@ -0,0 +1,25 @@
package core
import (
"sync"
"github.com/gin-gonic/gin"
)
type httpRequestPool struct {
wg sync.WaitGroup
}
func newHTTPRequestPool() *httpRequestPool {
return &httpRequestPool{}
}
func (rp *httpRequestPool) mw(ctx *gin.Context) {
rp.wg.Add(1)
ctx.Next()
rp.wg.Done()
}
func (rp *httpRequestPool) close() {
rp.wg.Wait()
}

6
internal/core/metrics.go

@ -3,6 +3,7 @@ package core
import ( import (
"context" "context"
"io" "io"
"log"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -53,7 +54,10 @@ func newMetrics(
router.SetTrustedProxies(nil) router.SetTrustedProxies(nil)
router.GET("/metrics", m.onMetrics) router.GET("/metrics", m.onMetrics)
m.server = &http.Server{Handler: router} m.server = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}
m.log(logger.Info, "listener opened on "+address) m.log(logger.Info, "listener opened on "+address)

4
internal/core/pprof.go

@ -2,6 +2,7 @@ package core
import ( import (
"context" "context"
"log"
"net" "net"
"net/http" "net/http"
@ -37,7 +38,8 @@ func newPPROF(
} }
pp.server = &http.Server{ pp.server = &http.Server{
Handler: http.DefaultServeMux, Handler: http.DefaultServeMux,
ErrorLog: log.New(&nilWriter{}, "", 0),
} }
pp.log(logger.Info, "listener opened on "+address) pp.log(logger.Info, "listener opened on "+address)

19
internal/core/webrtc_conn.go

@ -106,6 +106,8 @@ type webRTCConn struct {
created time.Time created time.Time
curPC *webrtc.PeerConnection curPC *webrtc.PeerConnection
mutex sync.RWMutex mutex sync.RWMutex
closed chan struct{}
} }
func newWebRTCConn( func newWebRTCConn(
@ -138,6 +140,7 @@ func newWebRTCConn(
iceUDPMux: iceUDPMux, iceUDPMux: iceUDPMux,
iceTCPMux: iceTCPMux, iceTCPMux: iceTCPMux,
iceHostNAT1To1IPs: iceHostNAT1To1IPs, iceHostNAT1To1IPs: iceHostNAT1To1IPs,
closed: make(chan struct{}),
} }
c.log(logger.Info, "opened") c.log(logger.Info, "opened")
@ -152,6 +155,10 @@ func (c *webRTCConn) close() {
c.ctxCancel() c.ctxCancel()
} }
func (c *webRTCConn) wait() {
<-c.closed
}
func (c *webRTCConn) remoteAddr() net.Addr { func (c *webRTCConn) remoteAddr() net.Addr {
return c.wsconn.RemoteAddr() return c.wsconn.RemoteAddr()
} }
@ -250,6 +257,7 @@ func (c *webRTCConn) log(level logger.Level, format string, args ...interface{})
} }
func (c *webRTCConn) run() { func (c *webRTCConn) run() {
defer close(c.closed)
defer c.wg.Done() defer c.wg.Done()
innerCtx, innerCtxCancel := context.WithCancel(c.ctx) innerCtx, innerCtxCancel := context.WithCancel(c.ctx)
@ -277,11 +285,6 @@ func (c *webRTCConn) run() {
} }
func (c *webRTCConn) runInner(ctx context.Context) error { func (c *webRTCConn) runInner(ctx context.Context) error {
go func() {
<-ctx.Done()
c.wsconn.Close()
}()
res := c.pathManager.readerAdd(pathReaderAddReq{ res := c.pathManager.readerAdd(pathReaderAddReq{
author: c, author: c,
pathName: c.pathName, pathName: c.pathName,
@ -348,6 +351,12 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
pcClosed := make(chan struct{}) pcClosed := make(chan struct{})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
select {
case <-pcClosed:
return
default:
}
c.log(logger.Debug, "peer connection state: "+state.String()) c.log(logger.Debug, "peer connection state: "+state.String())
switch state { switch state {

50
internal/core/webrtc_server.go

@ -66,6 +66,7 @@ type webRTCServerAPIConnsKickReq struct {
type webRTCConnNewReq struct { type webRTCConnNewReq struct {
pathName string pathName string
wsconn *websocket.Conn wsconn *websocket.Conn
res chan *webRTCConn
} }
type webRTCServerParent interface { type webRTCServerParent interface {
@ -84,7 +85,6 @@ type webRTCServer struct {
ctx context.Context ctx context.Context
ctxCancel func() ctxCancel func()
wg sync.WaitGroup
ln net.Listener ln net.Listener
udpMuxLn net.PacketConn udpMuxLn net.PacketConn
tcpMuxLn net.Listener tcpMuxLn net.Listener
@ -99,6 +99,9 @@ type webRTCServer struct {
chConnClose chan *webRTCConn chConnClose chan *webRTCConn
chAPIConnsList chan webRTCServerAPIConnsListReq chAPIConnsList chan webRTCServerAPIConnsListReq
chAPIConnsKick chan webRTCServerAPIConnsKickReq chAPIConnsKick chan webRTCServerAPIConnsKickReq
// out
done chan struct{}
} }
func newWebRTCServer( func newWebRTCServer(
@ -182,6 +185,7 @@ func newWebRTCServer(
chConnClose: make(chan *webRTCConn), chConnClose: make(chan *webRTCConn),
chAPIConnsList: make(chan webRTCServerAPIConnsListReq), chAPIConnsList: make(chan webRTCServerAPIConnsListReq),
chAPIConnsKick: make(chan webRTCServerAPIConnsKickReq), chAPIConnsKick: make(chan webRTCServerAPIConnsKickReq),
done: make(chan struct{}),
} }
str := "listener opened on " + address + " (HTTP)" str := "listener opened on " + address + " (HTTP)"
@ -197,7 +201,6 @@ func newWebRTCServer(
s.metrics.webRTCServerSet(s) s.metrics.webRTCServerSet(s)
} }
s.wg.Add(1)
go s.run() go s.run()
return s, nil return s, nil
@ -211,14 +214,17 @@ func (s *webRTCServer) log(level logger.Level, format string, args ...interface{
func (s *webRTCServer) close() { func (s *webRTCServer) close() {
s.log(logger.Info, "listener is closing") s.log(logger.Info, "listener is closing")
s.ctxCancel() s.ctxCancel()
s.wg.Wait() <-s.done
} }
func (s *webRTCServer) run() { func (s *webRTCServer) run() {
defer s.wg.Done() defer close(s.done)
rp := newHTTPRequestPool()
defer rp.close()
router := gin.New() router := gin.New()
router.NoRoute(httpLoggerMiddleware(s), s.onRequest) router.NoRoute(rp.mw, httpLoggerMiddleware(s), s.onRequest)
tmp := make([]string, len(s.trustedProxies)) tmp := make([]string, len(s.trustedProxies))
for i, entry := range s.trustedProxies { for i, entry := range s.trustedProxies {
@ -238,6 +244,8 @@ func (s *webRTCServer) run() {
go hs.Serve(s.ln) go hs.Serve(s.ln)
} }
var wg sync.WaitGroup
outer: outer:
for { for {
select { select {
@ -248,7 +256,7 @@ outer:
req.pathName, req.pathName,
req.wsconn, req.wsconn,
s.stunServers, s.stunServers,
&s.wg, &wg,
s.pathManager, s.pathManager,
s, s,
s.iceHostNAT1To1IPs, s.iceHostNAT1To1IPs,
@ -256,6 +264,7 @@ outer:
s.iceTCPMux, s.iceTCPMux,
) )
s.conns[c] = struct{}{} s.conns[c] = struct{}{}
req.res <- c
case conn := <-s.chConnClose: case conn := <-s.chConnClose:
delete(s.conns, conn) delete(s.conns, conn)
@ -306,6 +315,8 @@ outer:
hs.Shutdown(context.Background()) hs.Shutdown(context.Background())
s.ln.Close() // in case Shutdown() is called before Serve() s.ln.Close() // in case Shutdown() is called before Serve()
wg.Wait()
if s.udpMuxLn != nil { if s.udpMuxLn != nil {
s.udpMuxLn.Close() s.udpMuxLn.Close()
} }
@ -389,14 +400,29 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
if err != nil { if err != nil {
return return
} }
defer wsconn.Close()
select { c := s.newConn(dir, wsconn)
case s.connNew <- webRTCConnNewReq{ if c == nil {
pathName: dir, return
wsconn: wsconn,
}:
case <-s.ctx.Done():
} }
c.wait()
}
}
func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn {
req := webRTCConnNewReq{
pathName: dir,
wsconn: wsconn,
res: make(chan *webRTCConn),
}
select {
case s.connNew <- req:
return <-req.res
case <-s.ctx.Done():
return nil
} }
} }

32
internal/core/webrtc_server_test.go

@ -15,16 +15,18 @@ import (
) )
type webRTCTestClient struct { type webRTCTestClient struct {
wc *websocket.Conn wc *websocket.Conn
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
track chan *webrtc.TrackRemote track chan *webrtc.TrackRemote
closed chan struct{}
} }
func newWebRTCTestClient(addr string) (*webRTCTestClient, error) { func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
wc, _, err := websocket.DefaultDialer.Dial(addr, nil) //nolint:bodyclose wc, res, err := websocket.DefaultDialer.Dial(addr, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close()
_, msg, err := wc.ReadMessage() _, msg, err := wc.ReadMessage()
if err != nil { if err != nil {
@ -55,13 +57,25 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
}) })
connected := make(chan struct{}) connected := make(chan struct{})
closed := make(chan struct{})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected { switch state {
case webrtc.PeerConnectionStateConnected:
close(connected) close(connected)
case webrtc.PeerConnectionStateClosed:
select {
case <-closed:
return
default:
}
close(closed)
} }
}) })
track := make(chan *webrtc.TrackRemote, 1) track := make(chan *webrtc.TrackRemote, 1)
pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) { pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) {
track <- trak track <- trak
}) })
@ -143,15 +157,17 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
<-connected <-connected
return &webRTCTestClient{ return &webRTCTestClient{
wc: wc, wc: wc,
pc: pc, pc: pc,
track: track, track: track,
closed: closed,
}, nil }, nil
} }
func (c *webRTCTestClient) close() { func (c *webRTCTestClient) close() {
c.pc.Close() c.pc.Close()
c.wc.Close() c.wc.Close()
<-c.closed
} }
func TestWebRTCServer(t *testing.T) { func TestWebRTCServer(t *testing.T) {

Loading…
Cancel
Save