Ready-to-use SRT / WebRTC / RTSP / RTMP / LL-HLS media server and media proxy that allows you to read, publish, proxy, record and play back video and audio streams.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

879 lines
19 KiB

package core
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/aler9/gortsplib/v2/pkg/format"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtph264"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtpvp8"
"github.com/aler9/gortsplib/v2/pkg/formatdecenc/rtpvp9"
"github.com/aler9/gortsplib/v2/pkg/media"
"github.com/aler9/gortsplib/v2/pkg/ringbuffer"
"github.com/google/uuid"
"github.com/pion/ice/v2"
"github.com/pion/interceptor"
"github.com/pion/webrtc/v3"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/formatprocessor"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/websocket"
)
// Tunables of the WebRTC connection handling.
const (
	// webrtcPayloadMaxSize is the maximum RTP payload size given to the RTP
	// encoders: 1200 bytes minus the 12-byte RTP header.
	webrtcPayloadMaxSize = 1200 - 12

	// webrtcHandshakeDeadline is the maximum time granted to the exchange of
	// ICE candidates; when it elapses before the peer connection is
	// connected, the session is aborted.
	webrtcHandshakeDeadline = 10 * time.Second

	// webrtcWsWriteDeadline is a WebSocket write timeout.
	// NOTE(review): not referenced in this part of the file; presumably
	// consumed by the WebSocket server code - confirm.
	webrtcWsWriteDeadline = 2 * time.Second
)
// newPeerConnection builds a PeerConnection that uses the default codecs and
// the default interceptors (see RegisterDefaultCodecs and
// RegisterDefaultInterceptors).
//
// It mirrors the constructor found in webrtc/peerconnection.go, with one
// difference: the caller may supply additional API options. The media engine
// and interceptor registry options are appended after the caller's options.
func newPeerConnection(configuration webrtc.Configuration,
	options ...func(*webrtc.API),
) (*webrtc.PeerConnection, error) {
	mediaEngine := &webrtc.MediaEngine{}
	err := mediaEngine.RegisterDefaultCodecs()
	if err != nil {
		return nil, err
	}

	registry := &interceptor.Registry{}
	err = webrtc.RegisterDefaultInterceptors(mediaEngine, registry)
	if err != nil {
		return nil, err
	}

	options = append(options,
		webrtc.WithMediaEngine(mediaEngine),
		webrtc.WithInterceptorRegistry(registry),
	)

	return webrtc.NewAPI(options...).NewPeerConnection(configuration)
}
// webRTCTrack couples one media of the source stream with the local WebRTC
// track that its data is written to.
type webRTCTrack struct {
// media is the source media; it is passed to the stream when the
// connection registers itself as a reader.
media *media.Media
// format is the format of the media that gets forwarded.
format format.Format
// webRTCTrack is the outgoing track that is added to the peer connection.
webRTCTrack *webrtc.TrackLocalStaticRTP
// cb processes one unit of the source stream. It receives the session
// context and the channel on which a fatal write error can be reported.
cb func(formatprocessor.Unit, context.Context, chan error)
}
// gatherMedias returns the source media of every track, in track order.
// The result is nil when tracks is empty.
func gatherMedias(tracks []*webRTCTrack) media.Medias {
	var medias media.Medias

	for i := range tracks {
		medias = append(medias, tracks[i].media)
	}

	return medias
}
// webRTCConnPathManager is the part of the path manager that a webRTCConn
// depends on.
type webRTCConnPathManager interface {
// readerAdd asks for the request's author to be added as a reader of the
// requested path. runInner reads the err, path and stream fields of the
// result.
readerAdd(req pathReaderAddReq) pathReaderSetupPlayRes
}
// webRTCConnParent is the owner of a webRTCConn.
type webRTCConnParent interface {
// log receives the log entries emitted by the connection.
log(logger.Level, string, ...interface{})
// connClose is called by the connection's run loop once the session has
// terminated.
connClose(*webRTCConn)
}
// webRTCConn is a WebRTC session that reads a stream from a path and sends
// it to a remote peer. Signaling is performed through a WebSocket connection.
type webRTCConn struct {
// readBufferCount is the size of the ring buffer that queues outgoing units.
readBufferCount int
// pathName is the name of the path the connection reads from.
pathName string
// wsconn is the WebSocket connection used for signaling.
wsconn *websocket.ServerConn
// iceServers holds the configured ICE servers, either as plain URLs or in
// the five-field colon-separated form parsed by genICEServers.
iceServers []string
// wg is incremented when the run loop starts and decremented when it ends.
wg *sync.WaitGroup
pathManager webRTCConnPathManager
parent webRTCConnParent
// iceUDPMux and iceTCPMux, when non-nil, are installed in the WebRTC
// setting engine.
iceUDPMux ice.UDPMux
iceTCPMux ice.TCPMux
// iceHostNAT1To1IPs, when non-empty, is installed in the setting engine
// as the NAT 1-to-1 IP list for host candidates.
iceHostNAT1To1IPs []string
// ctx is cancelled by close() and when the run loop terminates.
ctx context.Context
ctxCancel func()
uuid uuid.UUID
created time.Time
// curPC is the peer connection; it is set once the connection has been
// established and is nil before that. Access is guarded by mutex.
curPC *webrtc.PeerConnection
mutex sync.RWMutex
// closed is closed when the run loop returns.
closed chan struct{}
}
// newWebRTCConn allocates a webRTCConn and starts its run loop in a new
// goroutine.
//
// The connection gets its own cancellable context derived from parentCtx, a
// fresh UUID and the current time as creation time. wg is incremented before
// the goroutine is started; the run loop decrements it when it returns.
func newWebRTCConn(
	parentCtx context.Context,
	readBufferCount int,
	pathName string,
	wsconn *websocket.ServerConn,
	iceServers []string,
	wg *sync.WaitGroup,
	pathManager webRTCConnPathManager,
	parent webRTCConnParent,
	iceHostNAT1To1IPs []string,
	iceUDPMux ice.UDPMux,
	iceTCPMux ice.TCPMux,
) *webRTCConn {
	connCtx, connCtxCancel := context.WithCancel(parentCtx)

	c := &webRTCConn{
		readBufferCount:   readBufferCount,
		pathName:          pathName,
		wsconn:            wsconn,
		iceServers:        iceServers,
		wg:                wg,
		pathManager:       pathManager,
		parent:            parent,
		iceUDPMux:         iceUDPMux,
		iceTCPMux:         iceTCPMux,
		iceHostNAT1To1IPs: iceHostNAT1To1IPs,
		ctx:               connCtx,
		ctxCancel:         connCtxCancel,
		uuid:              uuid.New(),
		created:           time.Now(),
		closed:            make(chan struct{}),
	}

	c.log(logger.Info, "opened")

	wg.Add(1)
	go c.run()

	return c
}
// close requests the termination of the connection by cancelling its
// context. It does not wait for the run loop to exit.
func (c *webRTCConn) close() {
c.ctxCancel()
}
// wait blocks until the closed channel is closed, which happens when the
// run loop returns.
func (c *webRTCConn) wait() {
<-c.closed
}
// remoteAddr returns the remote address of the WebSocket connection used
// for signaling.
func (c *webRTCConn) remoteAddr() net.Addr {
return c.wsconn.RemoteAddr()
}
// peerConnectionEstablished reports whether the peer connection has been
// stored in the connection, which happens once it has been established.
func (c *webRTCConn) peerConnectionEstablished() bool {
	c.mutex.RLock()
	established := c.curPC != nil
	c.mutex.RUnlock()

	return established
}
// localCandidate describes the local ICE candidate of the nominated
// candidate pair as "type/protocol/ip/port".
//
// It returns an empty string when the peer connection is not established
// yet, when no candidate pair is nominated, or when the candidate referenced
// by the pair is not part of the statistics.
func (c *webRTCConn) localCandidate() string {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if c.curPC == nil {
		return ""
	}

	// Collect the statistics once: the pair and the candidate it refers to
	// must come from the same snapshot, and collecting twice doubles the
	// work performed while the lock is held.
	report := c.curPC.GetStats()

	var cid string
	for _, stats := range report {
		if pair, ok := stats.(webrtc.ICECandidatePairStats); ok && pair.Nominated {
			cid = pair.LocalCandidateID
			break
		}
	}
	if cid == "" {
		return ""
	}

	for _, stats := range report {
		if cand, ok := stats.(webrtc.ICECandidateStats); ok && cand.ID == cid {
			return cand.CandidateType.String() + "/" + cand.Protocol + "/" +
				cand.IP + "/" + strconv.FormatInt(int64(cand.Port), 10)
		}
	}

	return ""
}
// remoteCandidate describes the remote ICE candidate of the nominated
// candidate pair as "type/protocol/ip/port".
//
// It returns an empty string when the peer connection is not established
// yet, when no candidate pair is nominated, or when the candidate referenced
// by the pair is not part of the statistics.
func (c *webRTCConn) remoteCandidate() string {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if c.curPC == nil {
		return ""
	}

	// Collect the statistics once: the pair and the candidate it refers to
	// must come from the same snapshot, and collecting twice doubles the
	// work performed while the lock is held.
	report := c.curPC.GetStats()

	var cid string
	for _, stats := range report {
		if pair, ok := stats.(webrtc.ICECandidatePairStats); ok && pair.Nominated {
			cid = pair.RemoteCandidateID
			break
		}
	}
	if cid == "" {
		return ""
	}

	for _, stats := range report {
		if cand, ok := stats.(webrtc.ICECandidateStats); ok && cand.ID == cid {
			return cand.CandidateType.String() + "/" + cand.Protocol + "/" +
				cand.IP + "/" + strconv.FormatInt(int64(cand.Port), 10)
		}
	}

	return ""
}
// bytesReceived returns the BytesReceived counter of the transport
// statistics entry whose ID is "iceTransport", or 0 when the peer connection
// is not established or no such entry exists.
func (c *webRTCConn) bytesReceived() uint64 {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if c.curPC == nil {
		return 0
	}

	for _, entry := range c.curPC.GetStats() {
		transport, ok := entry.(webrtc.TransportStats)
		if ok && transport.ID == "iceTransport" {
			return transport.BytesReceived
		}
	}

	return 0
}
// bytesSent returns the BytesSent counter of the transport statistics entry
// whose ID is "iceTransport", or 0 when the peer connection is not
// established or no such entry exists.
func (c *webRTCConn) bytesSent() uint64 {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	if c.curPC == nil {
		return 0
	}

	for _, entry := range c.curPC.GetStats() {
		transport, ok := entry.(webrtc.TransportStats)
		if ok && transport.ID == "iceTransport" {
			return transport.BytesSent
		}
	}

	return 0
}
// log forwards a log entry to the parent, prefixing the message with the
// remote address of the WebSocket connection.
func (c *webRTCConn) log(level logger.Level, format string, args ...interface{}) {
	prefixed := make([]interface{}, 0, len(args)+1)
	prefixed = append(prefixed, c.wsconn.RemoteAddr())
	prefixed = append(prefixed, args...)

	c.parent.log(level, "[conn %v] "+format, prefixed...)
}
// run is the main loop of the connection. It executes runInner in a
// separate goroutine and waits until either runInner returns or the
// connection context is cancelled; in the latter case it still waits for
// runInner to return before proceeding. Afterwards it cancels the
// connection context, notifies the parent and logs the reason.
//
// On return, the wait group is released first, then the closed channel is
// closed.
func (c *webRTCConn) run() {
	defer close(c.closed)
	defer c.wg.Done()

	innerCtx, cancelInner := context.WithCancel(c.ctx)

	done := make(chan error)
	go func() {
		done <- c.runInner(innerCtx)
	}()

	var reason error

	select {
	case <-c.ctx.Done():
		cancelInner()
		<-done
		reason = errors.New("terminated")

	case reason = <-done:
		cancelInner()
	}

	c.ctxCancel()

	c.parent.connClose(c)

	c.log(logger.Info, "closed (%v)", reason)
}
// runInner serves a single reading session: it registers the connection as a
// reader of the path, negotiates a peer connection through the WebSocket
// (ICE servers, SDP offer/answer, ICE candidates) and then forwards the units
// of the stream to the peer until something fails or ctx is done.
//
// Every return path yields a non-nil error that describes why the session
// ended.
func (c *webRTCConn) runInner(ctx context.Context) error {
// Register as a reader of the requested path. The authenticate callback
// accepts unconditionally.
// NOTE(review): presumably authentication is enforced before the
// connection is created - confirm against the caller.
res := c.pathManager.readerAdd(pathReaderAddReq{
author: c,
pathName: c.pathName,
authenticate: func(
pathIPs []fmt.Stringer,
pathUser conf.Credential,
pathPass conf.Credential,
) error {
return nil
},
})
if res.err != nil {
return res.err
}
path := res.path
defer func() {
path.readerRemove(pathReaderRemoveReq{author: c})
}()
// Pick the medias that can be sent and create a local track for each.
tracks, err := c.allocateTracks(res.stream.medias())
if err != nil {
return err
}
// Signaling, step 1: send the ICE servers to the client, then wait for
// its SDP offer.
err = c.wsconn.WriteJSON(c.genICEServers())
if err != nil {
return err
}
offer, err := c.readOffer()
if err != nil {
return err
}
configuration := webrtc.Configuration{ICEServers: c.genICEServers()}
settingsEngine := webrtc.SettingEngine{}
if len(c.iceHostNAT1To1IPs) != 0 {
settingsEngine.SetNAT1To1IPs(c.iceHostNAT1To1IPs, webrtc.ICECandidateTypeHost)
}
if c.iceUDPMux != nil {
settingsEngine.SetICEUDPMux(c.iceUDPMux)
}
// When a TCP mux is configured, the network types are restricted to TCP4.
if c.iceTCPMux != nil {
settingsEngine.SetICETCPMux(c.iceTCPMux)
settingsEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4})
}
pc, err := newPeerConnection(configuration, webrtc.WithSettingEngine(settingsEngine))
if err != nil {
return err
}
// Each of these channels is closed by the state callback below when the
// corresponding connection state is reported.
pcConnected := make(chan struct{})
pcDisconnected := make(chan struct{})
pcClosed := make(chan struct{})
// stateChangeMutex serializes the invocations of the state callback.
var stateChangeMutex sync.Mutex
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
stateChangeMutex.Lock()
defer stateChangeMutex.Unlock()
// Ignore any state reported after the Closed state.
select {
case <-pcClosed:
return
default:
}
c.log(logger.Debug, "peer connection state: "+state.String())
// NOTE(review): closing pcConnected or pcDisconnected a second time
// would panic; this relies on each of these states being reported at
// most once before the connection is closed - confirm.
switch state {
case webrtc.PeerConnectionStateConnected:
close(pcConnected)
case webrtc.PeerConnectionStateDisconnected:
close(pcDisconnected)
case webrtc.PeerConnectionStateClosed:
close(pcClosed)
}
})
// On exit, close the peer connection and wait until the Closed state has
// been reported by the callback.
defer func() {
pc.Close()
<-pcClosed
}()
for _, track := range tracks {
rtpSender, err := pc.AddTrack(track.webRTCTrack)
if err != nil {
return err
}
// read incoming RTCP packets in order to make interceptors work
go func() {
buf := make([]byte, 1500)
for {
_, _, err := rtpSender.Read(buf)
if err != nil {
return
}
}
}()
}
// Local ICE candidates are handed over to the handshake loop below; the
// hand-over is abandoned once the connection is established or ctx is done.
localCandidate := make(chan *webrtc.ICECandidateInit)
pc.OnICECandidate(func(i *webrtc.ICECandidate) {
if i != nil {
v := i.ToJSON()
select {
case localCandidate <- &v:
case <-pcConnected:
case <-ctx.Done():
}
}
})
// Signaling, step 2: apply the offer, then create, apply and send the answer.
err = pc.SetRemoteDescription(*offer)
if err != nil {
return err
}
answer, err := pc.CreateAnswer(nil)
if err != nil {
return err
}
err = pc.SetLocalDescription(answer)
if err != nil {
return err
}
err = c.wsconn.WriteJSON(&answer)
if err != nil {
return err
}
// Read remote ICE candidates from the WebSocket. The goroutine keeps
// reading after the connection is established (candidates are then
// discarded), so that a later WebSocket read error is still reported
// through wsReadError. It exits on the first read error.
wsReadError := make(chan error)
remoteCandidate := make(chan *webrtc.ICECandidateInit)
go func() {
for {
candidate, err := c.readCandidate()
if err != nil {
select {
case wsReadError <- err:
case <-ctx.Done():
}
return
}
select {
case remoteCandidate <- candidate:
case <-pcConnected:
case <-ctx.Done():
}
}
}()
// Handshake loop: exchange candidates until the peer connection is
// connected, the handshake deadline elapses, or an error occurs.
t := time.NewTimer(webrtcHandshakeDeadline)
defer t.Stop()
outer:
for {
select {
case candidate := <-localCandidate:
c.log(logger.Debug, "local candidate: %+v", candidate.Candidate)
err := c.wsconn.WriteJSON(candidate)
if err != nil {
return err
}
case candidate := <-remoteCandidate:
c.log(logger.Debug, "remote candidate: %+v", candidate.Candidate)
err := pc.AddICECandidate(*candidate)
if err != nil {
return err
}
case err := <-wsReadError:
return err
case <-t.C:
return fmt.Errorf("deadline exceeded")
case <-pcConnected:
break outer
case <-ctx.Done():
return fmt.Errorf("terminated")
}
}
// Keep WebSocket connection open and use it to notify shutdowns.
// This is because pion/webrtc doesn't write yet a WebRTC shutdown
// message to clients (like a DTLS close alert or a RTCP BYE),
// therefore browsers do not properly detect shutdowns and do not
// attempt to restart the connection immediately.
// Publish the peer connection so that the statistics getters can use it.
c.mutex.Lock()
c.curPC = pc
c.mutex.Unlock()
c.log(logger.Info, "peer connection established, local candidate: %v, remote candidate: %v",
c.localCandidate(), c.remoteCandidate())
// Units are queued in a ring buffer and written by a dedicated goroutine.
// NOTE(review): the error returned by ringbuffer.New is discarded;
// presumably readBufferCount is validated elsewhere - confirm.
ringBuffer, _ := ringbuffer.New(uint64(c.readBufferCount))
defer ringBuffer.Close()
writeError := make(chan error)
for _, track := range tracks {
// Copy the loop variable, since it is captured by the closure below.
ctrack := track
res.stream.readerAdd(c, track.media, track.format, func(unit formatprocessor.Unit) {
ringBuffer.Push(func() {
ctrack.cb(unit, ctx, writeError)
})
})
}
defer res.stream.readerRemove(c)
c.log(logger.Info, "is reading from path '%s', %s",
path.name, sourceMediaInfo(gatherMedias(tracks)))
// Run the queued callbacks until Pull reports that no more items are
// available.
go func() {
for {
item, ok := ringBuffer.Pull()
if !ok {
return
}
item.(func())()
}
}()
// Block until the session has to be terminated.
select {
case <-pcDisconnected:
return fmt.Errorf("peer connection closed")
case err := <-wsReadError:
return fmt.Errorf("websocket error: %v", err)
case err := <-writeError:
return err
case <-ctx.Done():
return fmt.Errorf("terminated")
}
}
// allocateTracks selects the medias of a stream that can be sent through
// WebRTC and creates a local track for each of them.
//
// At most one video track and one audio track are allocated. Video formats
// are tried in the order VP9, VP8, H264; audio formats in the order Opus,
// G722, G711. Video is re-packetized with an RTP encoder, while audio RTP
// packets are forwarded unchanged.
//
// An error is returned when a track cannot be created or when the stream
// contains none of the supported formats.
func (c *webRTCConn) allocateTracks(medias media.Medias) ([]*webRTCTrack, error) {
	var ret []*webRTCTrack

	var vp9Format *format.VP9
	vp9Media := medias.FindFormat(&vp9Format)

	if vp9Format != nil {
		localTrack, err := newWebRTCStaticTrack(
			webrtc.MimeTypeVP9, uint32(vp9Format.ClockRate()), "vp9")
		if err != nil {
			return nil, err
		}

		encoder := &rtpvp9.Encoder{
			PayloadType:    96,
			PayloadMaxSize: webrtcPayloadMaxSize,
		}
		encoder.Init()

		ret = append(ret, &webRTCTrack{
			media:       vp9Media,
			format:      vp9Format,
			webRTCTrack: localTrack,
			cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
				tunit := unit.(*formatprocessor.UnitVP9)

				if tunit.Frame == nil {
					return
				}

				packets, err := encoder.Encode(tunit.Frame, tunit.PTS)
				if err != nil {
					return
				}

				// write errors are deliberately ignored (best effort)
				for _, pkt := range packets {
					localTrack.WriteRTP(pkt)
				}
			},
		})
	}

	var vp8Format *format.VP8

	if vp9Format == nil {
		vp8Media := medias.FindFormat(&vp8Format)

		if vp8Format != nil {
			localTrack, err := newWebRTCStaticTrack(
				webrtc.MimeTypeVP8, uint32(vp8Format.ClockRate()), "vp8")
			if err != nil {
				return nil, err
			}

			encoder := &rtpvp8.Encoder{
				PayloadType:    96,
				PayloadMaxSize: webrtcPayloadMaxSize,
			}
			encoder.Init()

			ret = append(ret, &webRTCTrack{
				media:       vp8Media,
				format:      vp8Format,
				webRTCTrack: localTrack,
				cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
					tunit := unit.(*formatprocessor.UnitVP8)

					if tunit.Frame == nil {
						return
					}

					packets, err := encoder.Encode(tunit.Frame, tunit.PTS)
					if err != nil {
						return
					}

					// write errors are deliberately ignored (best effort)
					for _, pkt := range packets {
						localTrack.WriteRTP(pkt)
					}
				},
			})
		}
	}

	if vp9Format == nil && vp8Format == nil {
		var h264Format *format.H264
		h264Media := medias.FindFormat(&h264Format)

		if h264Format != nil {
			localTrack, err := newWebRTCStaticTrack(
				webrtc.MimeTypeH264, uint32(h264Format.ClockRate()), "h264")
			if err != nil {
				return nil, err
			}

			encoder := &rtph264.Encoder{
				PayloadType:    96,
				PayloadMaxSize: webrtcPayloadMaxSize,
			}
			encoder.Init()

			var lastPTS time.Duration
			firstNALUReceived := false

			ret = append(ret, &webRTCTrack{
				media:       h264Media,
				format:      h264Format,
				webRTCTrack: localTrack,
				cb: func(unit formatprocessor.Unit, ctx context.Context, writeError chan error) {
					tunit := unit.(*formatprocessor.UnitH264)

					if tunit.AU == nil {
						return
					}

					// A PTS that goes backwards is treated as the sign of
					// B-frames: report it as a fatal error, unless the
					// session is already terminating.
					if !firstNALUReceived {
						firstNALUReceived = true
					} else if tunit.PTS < lastPTS {
						select {
						case writeError <- fmt.Errorf("WebRTC doesn't support H264 streams with B-frames"):
						case <-ctx.Done():
						}
						return
					}
					lastPTS = tunit.PTS

					packets, err := encoder.Encode(tunit.AU, tunit.PTS)
					if err != nil {
						return
					}

					// write errors are deliberately ignored (best effort)
					for _, pkt := range packets {
						localTrack.WriteRTP(pkt)
					}
				},
			})
		}
	}

	var opusFormat *format.Opus
	opusMedia := medias.FindFormat(&opusFormat)

	if opusFormat != nil {
		localTrack, err := newWebRTCStaticTrack(
			webrtc.MimeTypeOpus, uint32(opusFormat.ClockRate()), "opus")
		if err != nil {
			return nil, err
		}

		ret = append(ret, &webRTCTrack{
			media:       opusMedia,
			format:      opusFormat,
			webRTCTrack: localTrack,
			cb:          webRTCForwardRTPPackets(localTrack),
		})
	}

	var g722Format *format.G722

	if opusFormat == nil {
		g722Media := medias.FindFormat(&g722Format)

		if g722Format != nil {
			localTrack, err := newWebRTCStaticTrack(
				webrtc.MimeTypeG722, uint32(g722Format.ClockRate()), "g722")
			if err != nil {
				return nil, err
			}

			ret = append(ret, &webRTCTrack{
				media:       g722Media,
				format:      g722Format,
				webRTCTrack: localTrack,
				cb:          webRTCForwardRTPPackets(localTrack),
			})
		}
	}

	var g711Format *format.G711

	if opusFormat == nil && g722Format == nil {
		g711Media := medias.FindFormat(&g711Format)

		if g711Format != nil {
			var mtyp string
			if g711Format.MULaw {
				mtyp = webrtc.MimeTypePCMU
			} else {
				mtyp = webrtc.MimeTypePCMA
			}

			localTrack, err := newWebRTCStaticTrack(
				mtyp, uint32(g711Format.ClockRate()), "g711")
			if err != nil {
				return nil, err
			}

			ret = append(ret, &webRTCTrack{
				media:       g711Media,
				format:      g711Format,
				webRTCTrack: localTrack,
				cb:          webRTCForwardRTPPackets(localTrack),
			})
		}
	}

	if ret == nil {
		return nil, fmt.Errorf(
			"the stream doesn't contain any supported codec (which are currently VP9, VP8, H264, Opus, G722, G711)")
	}

	return ret, nil
}

// newWebRTCStaticTrack creates a local RTP track with the given MIME type,
// clock rate and track ID. The stream ID is always "rtspss".
func newWebRTCStaticTrack(
	mimeType string,
	clockRate uint32,
	id string,
) (*webrtc.TrackLocalStaticRTP, error) {
	return webrtc.NewTrackLocalStaticRTP(
		webrtc.RTPCodecCapability{
			MimeType:  mimeType,
			ClockRate: clockRate,
		},
		id,
		"rtspss",
	)
}

// webRTCForwardRTPPackets returns a track callback that writes the RTP
// packets carried by each unit to track, without re-encoding them.
func webRTCForwardRTPPackets(
	track *webrtc.TrackLocalStaticRTP,
) func(formatprocessor.Unit, context.Context, chan error) {
	return func(unit formatprocessor.Unit, _ context.Context, _ chan error) {
		// write errors are deliberately ignored (best effort)
		for _, pkt := range unit.GetRTPPackets() {
			track.WriteRTP(pkt)
		}
	}
}
// genICEServers converts the configured ICE server strings into
// webrtc.ICEServer values.
//
// An entry made of exactly five colon-separated fields is read as
// scheme:username:password:host:port and yields the URL scheme:host:port
// with the given credentials. When the username is the literal AUTH_SECRET,
// the password is used as a shared secret instead: the username becomes
// "<expiry unix time>:<20 random characters>", with an expiry 24 hours from
// now, and the credential is the base64-encoded HMAC-SHA1 of that username.
// Any other entry is used as-is as the URL, without credentials.
func (c *webRTCConn) genICEServers() []webrtc.ICEServer {
	servers := make([]webrtc.ICEServer, len(c.iceServers))

	for idx, raw := range c.iceServers {
		fields := strings.Split(raw, ":")

		if len(fields) != 5 {
			servers[idx] = webrtc.ICEServer{
				URLs: []string{raw},
			}
			continue
		}

		url := fields[0] + ":" + fields[3] + ":" + fields[4]

		if fields[1] != "AUTH_SECRET" {
			servers[idx] = webrtc.ICEServer{
				URLs:       []string{url},
				Username:   fields[1],
				Credential: fields[2],
			}
			continue
		}

		const charset = "abcdefghijklmnopqrstuvwxyz1234567890"
		randomUser := make([]byte, 20)
		for j := range randomUser {
			randomUser[j] = charset[rand.Intn(len(charset))]
		}

		expireDate := time.Now().Add(24 * 3600 * time.Second).Unix()
		username := strconv.FormatInt(expireDate, 10) + ":" + string(randomUser)

		mac := hmac.New(sha1.New, []byte(fields[2]))
		mac.Write([]byte(username))

		servers[idx] = webrtc.ICEServer{
			URLs:       []string{url},
			Username:   username,
			Credential: base64.StdEncoding.EncodeToString(mac.Sum(nil)),
		}
	}

	return servers
}
// readOffer reads a session description from the WebSocket connection and
// checks that it is an SDP offer. It returns the read error, or an error
// when the description has any other type.
func (c *webRTCConn) readOffer() (*webrtc.SessionDescription, error) {
	offer := &webrtc.SessionDescription{}

	if err := c.wsconn.ReadJSON(offer); err != nil {
		return nil, err
	}

	if offer.Type != webrtc.SDPTypeOffer {
		return nil, fmt.Errorf("received SDP is not an offer")
	}

	return offer, nil
}
// readCandidate reads one ICE candidate from the WebSocket connection.
func (c *webRTCConn) readCandidate() (*webrtc.ICECandidateInit, error) {
	candidate := new(webrtc.ICECandidateInit)

	if err := c.wsconn.ReadJSON(candidate); err != nil {
		return nil, err
	}

	return candidate, nil
}
// apiReaderDescribe implements reader.
//
// The description consists of the fixed type "webRTCConn" and the UUID of
// the connection, serialized as the JSON fields "type" and "id".
func (c *webRTCConn) apiReaderDescribe() interface{} {
	return struct {
		Type string `json:"type"`
		ID   string `json:"id"`
	}{
		Type: "webRTCConn",
		ID:   c.uuid.String(),
	}
}