diff --git a/go.mod b/go.mod index 73b262b0..3b52193f 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5 github.com/abema/go-mp4 v0.8.0 - github.com/aler9/gortsplib v0.0.0-20221101102023-dbb6934a3c3e + github.com/aler9/gortsplib v0.0.0-20221102164639-d3c23a849c83 github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757 github.com/fsnotify/fsnotify v1.4.9 github.com/gin-gonic/gin v1.8.1 diff --git a/go.sum b/go.sum index f759da7c..52eb25d5 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/aler9/gortsplib v0.0.0-20221101102023-dbb6934a3c3e h1:x+EHN8/YHjG6NQM59WG+fdPmozyIarDZgJZymNbDmFE= -github.com/aler9/gortsplib v0.0.0-20221101102023-dbb6934a3c3e/go.mod h1:BOWNZ/QBkY/eVcRqUzJbPFEsRJshwxaxBT01K260Jeo= +github.com/aler9/gortsplib v0.0.0-20221102164639-d3c23a849c83 h1:Qn/TL5+Nm4g+IgQ1DODtu6oCve0plBiJsprbnLG3yfQ= +github.com/aler9/gortsplib v0.0.0-20221102164639-d3c23a849c83/go.mod h1:BOWNZ/QBkY/eVcRqUzJbPFEsRJshwxaxBT01K260Jeo= github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4= github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U= github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8= diff --git a/internal/core/data.go b/internal/core/data.go index 8653b5e9..cc29dd9d 100644 --- a/internal/core/data.go +++ b/internal/core/data.go @@ -7,20 +7,65 @@ import ( ) // data is the data unit routed across the server. -// it must contain one or more of the following: -// - a single RTP packet -// - a group of H264 NALUs (grouped by timestamp) -// - a single AAC AU -type data struct { - trackID int +type data interface { + getTrackID() int + getRTPPackets() []*rtp.Packet + getPTSEqualsDTS() bool +} + +type dataGeneric struct { + trackID int + rtpPackets []*rtp.Packet + ptsEqualsDTS bool +} - rtpPacket *rtp.Packet +func (d *dataGeneric) getTrackID() int { + return d.trackID +} - // timing +func (d *dataGeneric) getRTPPackets() []*rtp.Packet { + return d.rtpPackets +} + +func (d *dataGeneric) getPTSEqualsDTS() bool { + return d.ptsEqualsDTS +} + +type dataH264 struct { + trackID int + rtpPackets []*rtp.Packet ptsEqualsDTS bool pts time.Duration + nalus [][]byte +} - h264NALUs [][]byte +func (d *dataH264) getTrackID() int { + return d.trackID +} + +func (d *dataH264) getRTPPackets() []*rtp.Packet { + return d.rtpPackets +} + +func (d *dataH264) getPTSEqualsDTS() bool { + return d.ptsEqualsDTS +} + +type dataMPEG4Audio struct { + trackID int + rtpPackets []*rtp.Packet + pts time.Duration + aus [][]byte +} + +func (d *dataMPEG4Audio) getTrackID() int { + return d.trackID +} + +func (d *dataMPEG4Audio) getRTPPackets() []*rtp.Packet { + return d.rtpPackets +} - mpeg4AudioAU []byte +func (d *dataMPEG4Audio) getPTSEqualsDTS() bool { + return true } diff --git a/internal/core/hls_muxer.go b/internal/core/hls_muxer.go index 3775b5e6..4059a079 100644 --- a/internal/core/hls_muxer.go +++ b/internal/core/hls_muxer.go @@ -14,7 +14,6 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/mpeg4audio" "github.com/aler9/gortsplib/pkg/ringbuffer" - "github.com/aler9/gortsplib/pkg/rtpmpeg4audio" "github.com/gin-gonic/gin" "github.com/aler9/rtsp-simple-server/internal/conf" @@ -295,7 +294,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{}) videoTrackID := -1 var audioTrack *gortsplib.TrackMPEG4Audio audioTrackID := -1 - var aacDecoder *rtpmpeg4audio.Decoder for i, track := range res.stream.tracks() { switch tt := track.(type) { @@ -314,13 +312,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{}) audioTrack = tt audioTrackID = i - aacDecoder = &rtpmpeg4audio.Decoder{ - SampleRate: tt.Config.SampleRate, - SizeLength: tt.SizeLength, - IndexLength: tt.IndexLength, - IndexDeltaLength: tt.IndexDeltaLength, - } - aacDecoder.Init() } } @@ -362,53 +353,12 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{}) writerDone := make(chan error) go func() { - writerDone <- func() error { - var videoInitialPTS *time.Duration - - for { - item, ok := m.ringBuffer.Pull() - if !ok { - return fmt.Errorf("terminated") - } - data := item.(*data) - - if videoTrack != nil && data.trackID == videoTrackID { - if data.h264NALUs == nil { - continue - } - - if videoInitialPTS == nil { - v := data.pts - videoInitialPTS = &v - } - pts := data.pts - *videoInitialPTS - - err = m.muxer.WriteH264(time.Now(), pts, data.h264NALUs) - if err != nil { - return fmt.Errorf("muxer error: %v", err) - } - } else if audioTrack != nil && data.trackID == audioTrackID { - aus, pts, err := aacDecoder.Decode(data.rtpPacket) - if err != nil { - if err != rtpmpeg4audio.ErrMorePacketsNeeded { - m.log(logger.Warn, "unable to decode audio track: %v", err) - } - continue - } - - for i, au := range aus { - err = m.muxer.WriteAAC( - time.Now(), - pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit* - time.Second/time.Duration(audioTrack.ClockRate()), - au) - if err != nil { - return fmt.Errorf("muxer error: %v", err) - } - } - } - } - }() + writerDone <- m.runWriter( + videoTrack, + videoTrackID, + audioTrack, + audioTrackID, + ) }() closeCheckTicker := time.NewTicker(closeCheckPeriod) @@ -435,6 +385,68 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{}) } } +func (m *hlsMuxer) runWriter( + videoTrack *gortsplib.TrackH264, + videoTrackID int, + audioTrack *gortsplib.TrackMPEG4Audio, + audioTrackID int, +) error { + videoStartPTSFilled := false + var videoStartPTS time.Duration + audioStartPTSFilled := false + var audioStartPTS time.Duration + + for { + item, ok := m.ringBuffer.Pull() + if !ok { + return fmt.Errorf("terminated") + } + data := item.(data) + + if videoTrack != nil && data.getTrackID() == videoTrackID { + tdata := data.(*dataH264) + + if tdata.nalus == nil { + continue + } + + if !videoStartPTSFilled { + videoStartPTSFilled = true + videoStartPTS = tdata.pts + } + pts := tdata.pts - videoStartPTS + + err := m.muxer.WriteH264(time.Now(), pts, tdata.nalus) + if err != nil { + return fmt.Errorf("muxer error: %v", err) + } + } else if audioTrack != nil && data.getTrackID() == audioTrackID { + tdata := data.(*dataMPEG4Audio) + + if tdata.aus == nil { + continue + } + + if !audioStartPTSFilled { + audioStartPTSFilled = true + audioStartPTS = tdata.pts + } + pts := tdata.pts - audioStartPTS + + for i, au := range tdata.aus { + err := m.muxer.WriteAAC( + time.Now(), + pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit* + time.Second/time.Duration(audioTrack.ClockRate()), + au) + if err != nil { + return fmt.Errorf("muxer error: %v", err) + } + } + } + } +} + func (m *hlsMuxer) handleRequest(req *hlsMuxerRequest) func() *hls.MuxerFileResponse { atomic.StoreInt64(m.lastRequestTime, time.Now().UnixNano()) @@ -558,7 +570,7 @@ func (m *hlsMuxer) apiHLSMuxersList(req hlsServerAPIMuxersListSubReq) { } // onReaderData implements reader. -func (m *hlsMuxer) onReaderData(data *data) { +func (m *hlsMuxer) onReaderData(data data) { m.ringBuffer.Push(data) } diff --git a/internal/core/hls_source.go b/internal/core/hls_source.go index 20c6a3e3..f2c178c2 100644 --- a/internal/core/hls_source.go +++ b/internal/core/hls_source.go @@ -79,21 +79,26 @@ func (s *hlsSource) run(ctx context.Context) error { } onVideoData := func(pts time.Duration, nalus [][]byte) { - stream.writeData(&data{ + err := stream.writeData(&dataH264{ trackID: videoTrackID, ptsEqualsDTS: h264.IDRPresent(nalus), pts: pts, - h264NALUs: nalus, + nalus: nalus, }) + if err != nil { + s.Log(logger.Warn, "%v", err) + } } onAudioData := func(pts time.Duration, au []byte) { - stream.writeData(&data{ - trackID: audioTrackID, - ptsEqualsDTS: true, - pts: pts, - mpeg4AudioAU: au, + err := stream.writeData(&dataMPEG4Audio{ + trackID: audioTrackID, + pts: pts, + aus: [][]byte{au}, }) + if err != nil { + s.Log(logger.Warn, "%v", err) + } } c, err := hls.NewClient( diff --git a/internal/core/hls_source_test.go b/internal/core/hls_source_test.go index e979b5ce..b0a3735f 100644 --- a/internal/core/hls_source_test.go +++ b/internal/core/hls_source_test.go @@ -14,6 +14,7 @@ import ( "github.com/aler9/gortsplib/pkg/url" "github.com/asticode/go-astits" "github.com/gin-gonic/gin" + "github.com/pion/rtp" "github.com/stretchr/testify/require" ) @@ -136,11 +137,26 @@ func TestHLSSource(t *testing.T) { c := gortsplib.Client{ OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) { - require.Equal(t, [][]byte{ - {0x07, 0x01, 0x02, 0x03}, - {0x08}, - {0x05}, - }, ctx.H264NALUs) + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: ctx.Packet.SequenceNumber, + Timestamp: ctx.Packet.Timestamp, + SSRC: ctx.Packet.SSRC, + CSRC: []uint32{}, + }, + Payload: []byte{ + 0x18, + 0x00, 0x04, + 0x07, 0x01, 0x02, 0x03, // SPS + 0x00, 0x01, + 0x08, // PPS + 0x00, 0x01, + 0x05, // ODR + }, + }, ctx.Packet) close(frameRecv) }, } diff --git a/internal/core/reader.go b/internal/core/reader.go index 6a5dc2e5..ae12a68a 100644 --- a/internal/core/reader.go +++ b/internal/core/reader.go @@ -3,6 +3,6 @@ package core // reader is an entity that can read a stream. type reader interface { close() - onReaderData(*data) + onReaderData(data) apiReaderDescribe() interface{} } diff --git a/internal/core/rpicamera_source.go b/internal/core/rpicamera_source.go index 6926089c..d34b13f4 100644 --- a/internal/core/rpicamera_source.go +++ b/internal/core/rpicamera_source.go @@ -59,12 +59,15 @@ func (s *rpiCameraSource) run(ctx context.Context) error { stream = res.stream } - stream.writeData(&data{ + err := stream.writeData(&dataH264{ trackID: 0, ptsEqualsDTS: h264.IDRPresent(nalus), pts: dts, - h264NALUs: nalus, + nalus: nalus, }) + if err != nil { + s.Log(logger.Warn, "%v", err) + } } cam, err := rpicamera.New(s.params, onData) diff --git a/internal/core/rtmp_conn.go b/internal/core/rtmp_conn.go index 2be09848..8641a316 100644 --- a/internal/core/rtmp_conn.go +++ b/internal/core/rtmp_conn.go @@ -14,7 +14,6 @@ import ( "github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/mpeg4audio" "github.com/aler9/gortsplib/pkg/ringbuffer" - "github.com/aler9/gortsplib/pkg/rtpmpeg4audio" "github.com/notedit/rtmp/format/flv/flvio" "github.com/aler9/rtsp-simple-server/internal/conf" @@ -258,7 +257,6 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { videoTrackID := -1 var audioTrack *gortsplib.TrackMPEG4Audio audioTrackID := -1 - var aacDecoder *rtpmpeg4audio.Decoder for i, track := range res.stream.tracks() { switch tt := track.(type) { @@ -277,13 +275,6 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { audioTrack = tt audioTrackID = i - aacDecoder = &rtpmpeg4audio.Decoder{ - SampleRate: tt.Config.SampleRate, - SizeLength: tt.SizeLength, - IndexLength: tt.IndexLength, - IndexDeltaLength: tt.IndexDeltaLength, - } - aacDecoder.Init() } } @@ -336,7 +327,11 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { // disable read deadline c.nconn.SetReadDeadline(time.Time{}) - var videoInitialPTS *time.Duration + videoStartPTSFilled := false + var videoStartPTS time.Duration + audioStartPTSFilled := false + var audioStartPTS time.Duration + videoFirstIDRFound := false var videoStartDTS time.Duration var videoDTSExtractor *h264.DTSExtractor @@ -346,27 +341,25 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { if !ok { return fmt.Errorf("terminated") } - data := item.(*data) + data := item.(data) + + if videoTrack != nil && data.getTrackID() == videoTrackID { + tdata := data.(*dataH264) - if videoTrack != nil && data.trackID == videoTrackID { - if data.h264NALUs == nil { + if tdata.nalus == nil { continue } - // video is decoded in another routine, - // while audio is decoded in this routine: - // we have to sync their PTS. - if videoInitialPTS == nil { - v := data.pts - videoInitialPTS = &v + if !videoStartPTSFilled { + videoStartPTSFilled = true + videoStartPTS = tdata.pts } - - pts := data.pts - *videoInitialPTS + pts := tdata.pts - videoStartPTS idrPresent := false nonIDRPresent := false - for _, nalu := range data.h264NALUs { + for _, nalu := range tdata.nalus { typ := h264.NALUType(nalu[0] & 0x1F) switch typ { case h264.NALUTypeIDR: @@ -389,7 +382,7 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { videoDTSExtractor = h264.NewDTSExtractor() var err error - dts, err = videoDTSExtractor.Extract(data.h264NALUs, pts) + dts, err = videoDTSExtractor.Extract(tdata.nalus, pts) if err != nil { return err } @@ -403,7 +396,7 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { } var err error - dts, err = videoDTSExtractor.Extract(data.h264NALUs, pts) + dts, err = videoDTSExtractor.Extract(tdata.nalus, pts) if err != nil { return err } @@ -412,7 +405,7 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { pts -= videoStartDTS } - avcc, err := h264.AVCCMarshal(data.h264NALUs) + avcc, err := h264.AVCCMarshal(tdata.nalus) if err != nil { return err } @@ -430,25 +423,31 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error { if err != nil { return err } - } else if audioTrack != nil && data.trackID == audioTrackID { - aus, pts, err := aacDecoder.Decode(data.rtpPacket) - if err != nil { - if err != rtpmpeg4audio.ErrMorePacketsNeeded { - c.log(logger.Warn, "unable to decode audio track: %v", err) - } + } else if audioTrack != nil && data.getTrackID() == audioTrackID { + tdata := data.(*dataMPEG4Audio) + + if tdata.aus == nil { continue } - if videoTrack != nil && !videoFirstIDRFound { - continue + if !audioStartPTSFilled { + audioStartPTSFilled = true + audioStartPTS = tdata.pts } + pts := tdata.pts - audioStartPTS - pts -= videoStartDTS - if pts < 0 { - continue + if videoTrack != nil { + if !videoFirstIDRFound { + continue + } + + pts -= videoStartDTS + if pts < 0 { + continue + } } - for i, au := range aus { + for i, au := range tdata.aus { c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout))) err := c.conn.WriteMessage(&message.MsgAudio{ ChunkStreamID: message.MsgAudioChunkStreamID, @@ -559,12 +558,15 @@ func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error { conf.PPS, } - rres.stream.writeData(&data{ + err := rres.stream.writeData(&dataH264{ trackID: videoTrackID, ptsEqualsDTS: false, pts: tmsg.DTS + tmsg.PTSDelta, - h264NALUs: nalus, + nalus: nalus, }) + if err != nil { + c.log(logger.Warn, "%v", err) + } } else if tmsg.H264Type == flvio.AVC_NALU { if videoTrack == nil { return fmt.Errorf("received an H264 packet, but track is not set up") @@ -595,12 +597,15 @@ func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error { } } - rres.stream.writeData(&data{ + err = rres.stream.writeData(&dataH264{ trackID: videoTrackID, ptsEqualsDTS: h264.IDRPresent(validNALUs), pts: tmsg.DTS + tmsg.PTSDelta, - h264NALUs: validNALUs, + nalus: validNALUs, }) + if err != nil { + c.log(logger.Warn, "%v", err) + } } case *message.MsgAudio: @@ -609,12 +614,14 @@ func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error { return fmt.Errorf("received an AAC packet, but track is not set up") } - rres.stream.writeData(&data{ - trackID: audioTrackID, - ptsEqualsDTS: true, - pts: tmsg.DTS, - mpeg4AudioAU: tmsg.Payload, + err := rres.stream.writeData(&dataMPEG4Audio{ + trackID: audioTrackID, + pts: tmsg.DTS, + aus: [][]byte{tmsg.Payload}, }) + if err != nil { + c.log(logger.Warn, "%v", err) + } } } } @@ -667,7 +674,7 @@ func (c *rtmpConn) authenticate( } // onReaderData implements reader. -func (c *rtmpConn) onReaderData(data *data) { +func (c *rtmpConn) onReaderData(data data) { c.ringBuffer.Push(data) } diff --git a/internal/core/rtmp_source.go b/internal/core/rtmp_source.go index 7bdcde62..396d1efb 100644 --- a/internal/core/rtmp_source.go +++ b/internal/core/rtmp_source.go @@ -170,12 +170,15 @@ func (s *rtmpSource) run(ctx context.Context) error { return fmt.Errorf("unable to decode AVCC: %v", err) } - res.stream.writeData(&data{ + err = res.stream.writeData(&dataH264{ trackID: videoTrackID, ptsEqualsDTS: h264.IDRPresent(nalus), pts: tmsg.DTS + tmsg.PTSDelta, - h264NALUs: nalus, + nalus: nalus, }) + if err != nil { + s.Log(logger.Warn, "%v", err) + } } case *message.MsgAudio: @@ -184,12 +187,14 @@ func (s *rtmpSource) run(ctx context.Context) error { return fmt.Errorf("received an AAC packet, but track is not set up") } - res.stream.writeData(&data{ - trackID: audioTrackID, - ptsEqualsDTS: true, - pts: tmsg.DTS, - mpeg4AudioAU: tmsg.Payload, + err := res.stream.writeData(&dataMPEG4Audio{ + trackID: audioTrackID, + pts: tmsg.DTS, + aus: [][]byte{tmsg.Payload}, }) + if err != nil { + s.Log(logger.Warn, "%v", err) + } } } } diff --git a/internal/core/rtsp_session.go b/internal/core/rtsp_session.go index 5e549f5a..b0593b27 100644 --- a/internal/core/rtsp_session.go +++ b/internal/core/rtsp_session.go @@ -9,6 +9,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/base" + "github.com/pion/rtp" "github.com/aler9/rtsp-simple-server/internal/conf" "github.com/aler9/rtsp-simple-server/internal/externalcmd" @@ -342,7 +343,7 @@ func (s *rtspSession) onPause(ctx *gortsplib.ServerHandlerOnPauseCtx) (*base.Res } // onReaderData implements reader. -func (s *rtspSession) onReaderData(data *data) { +func (s *rtspSession) onReaderData(data data) { // packets are routed to the session by gortsplib.ServerStream. } @@ -378,21 +379,33 @@ func (s *rtspSession) apiSourceDescribe() interface{} { // onPacketRTP is called by rtspServer. func (s *rtspSession) onPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { - if ctx.H264NALUs != nil { - s.stream.writeData(&data{ + var err error + + switch s.announcedTracks[ctx.TrackID].(type) { + case *gortsplib.TrackH264: + err = s.stream.writeData(&dataH264{ trackID: ctx.TrackID, - rtpPacket: ctx.Packet, + rtpPackets: []*rtp.Packet{ctx.Packet}, ptsEqualsDTS: ctx.PTSEqualsDTS, - pts: ctx.H264PTS, - h264NALUs: ctx.H264NALUs, }) - } else { - s.stream.writeData(&data{ + + case *gortsplib.TrackMPEG4Audio: + err = s.stream.writeData(&dataMPEG4Audio{ + trackID: ctx.TrackID, + rtpPackets: []*rtp.Packet{ctx.Packet}, + }) + + default: + err = s.stream.writeData(&dataGeneric{ trackID: ctx.TrackID, - rtpPacket: ctx.Packet, + rtpPackets: []*rtp.Packet{ctx.Packet}, ptsEqualsDTS: ctx.PTSEqualsDTS, }) } + + if err != nil { + s.log(logger.Warn, "%v", err) + } } // onDecodeError is called by rtspServer. diff --git a/internal/core/rtsp_source.go b/internal/core/rtsp_source.go index ef8bd9c5..355f90c0 100644 --- a/internal/core/rtsp_source.go +++ b/internal/core/rtsp_source.go @@ -11,6 +11,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/base" + "github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/url" "github.com/aler9/rtsp-simple-server/internal/conf" @@ -143,21 +144,33 @@ func (s *rtspSource) run(ctx context.Context) error { }() c.OnPacketRTP = func(ctx *gortsplib.ClientOnPacketRTPCtx) { - if ctx.H264NALUs != nil { - res.stream.writeData(&data{ + var err error + + switch tracks[ctx.TrackID].(type) { + case *gortsplib.TrackH264: + err = res.stream.writeData(&dataH264{ trackID: ctx.TrackID, - rtpPacket: ctx.Packet, + rtpPackets: []*rtp.Packet{ctx.Packet}, ptsEqualsDTS: ctx.PTSEqualsDTS, - pts: ctx.H264PTS, - h264NALUs: ctx.H264NALUs, }) - } else { - res.stream.writeData(&data{ + + case *gortsplib.TrackMPEG4Audio: + err = res.stream.writeData(&dataMPEG4Audio{ + trackID: ctx.TrackID, + rtpPackets: []*rtp.Packet{ctx.Packet}, + }) + + default: + err = res.stream.writeData(&dataGeneric{ trackID: ctx.TrackID, - rtpPacket: ctx.Packet, + rtpPackets: []*rtp.Packet{ctx.Packet}, ptsEqualsDTS: ctx.PTSEqualsDTS, }) } + + if err != nil { + s.Log(logger.Warn, "%v", err) + } } _, err = c.Play(nil) diff --git a/internal/core/rtsp_source_test.go b/internal/core/rtsp_source_test.go index b465c52e..73b8c372 100644 --- a/internal/core/rtsp_source_test.go +++ b/internal/core/rtsp_source_test.go @@ -1,14 +1,19 @@ package core import ( + "bytes" "crypto/tls" + "net" "os" + "strings" "testing" "time" "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" + "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/url" "github.com/pion/rtp" @@ -237,11 +242,9 @@ func TestRTSPSourceNoPassword(t *testing.T) { } func TestRTSPSourceDynamicH264Params(t *testing.T) { - track := &gortsplib.TrackH264{ + stream := gortsplib.NewServerStream(gortsplib.Tracks{&gortsplib.TrackH264{ PayloadType: 96, - } - - stream := gortsplib.NewServerStream(gortsplib.Tracks{track}) + }}) defer stream.Close() s := gortsplib.Server{ @@ -340,3 +343,345 @@ func TestRTSPSourceDynamicH264Params(t *testing.T) { require.Equal(t, []byte{8, 1}, h264Track.SafePPS()) }() } + +func TestRTSPSourceRemovePadding(t *testing.T) { + stream := gortsplib.NewServerStream(gortsplib.Tracks{&gortsplib.TrackH264{ + PayloadType: 96, + }}) + defer stream.Close() + + s := gortsplib.Server{ + Handler: &testServer{ + onDescribe: func(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, *gortsplib.ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, *gortsplib.ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "127.0.0.1:8555", + } + err := s.Start() + require.NoError(t, err) + defer s.Wait() + defer s.Close() + + p, ok := newInstance("rtmpDisable: yes\n" + + "hlsDisable: yes\n" + + "paths:\n" + + " proxied:\n" + + " source: rtsp://127.0.0.1:8555/teststream\n") + require.Equal(t, true, ok) + defer p.Close() + + time.Sleep(1 * time.Second) + + packetRecv := make(chan struct{}) + + c := gortsplib.Client{ + OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) { + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + CSRC: []uint32{}, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, ctx.Packet) + close(packetRecv) + }, + } + + u, err := url.Parse("rtsp://127.0.0.1:8554/proxied") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + tracks, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAndPlay(tracks, baseURL) + require.NoError(t, err) + + stream.WritePacketRTP(0, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + Padding: true, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + PaddingSize: 20, + }, true) + + <-packetRecv +} + +func TestRTSPSourceOversizedPackets(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:8555") + require.NoError(t, err) + defer l.Close() + + connected := make(chan struct{}) + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err := conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + + tracks := gortsplib.Tracks{&gortsplib.TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + }} + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolTCP, + InterleavedIDs: inTH.InterleavedIDs, + }.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + <-connected + + byts, _ := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + Padding: true, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }.Marshal() + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: byts, + }, make([]byte, 1024)) + require.NoError(t, err) + + byts, _ = rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 124, + Timestamp: 45343, + SSRC: 563423, + Padding: true, + }, + Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 2000/4), + }.Marshal() + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: byts, + }, make([]byte, 2048)) + require.NoError(t, err) + + byts, _ = rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 125, + Timestamp: 45343, + SSRC: 563423, + Padding: true, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }.Marshal() + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: byts, + }, make([]byte, 1024)) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + }() + + p, ok := newInstance("rtmpDisable: yes\n" + + "hlsDisable: yes\n" + + "paths:\n" + + " proxied:\n" + + " source: rtsp://127.0.0.1:8555/teststream\n" + + " sourceProtocol: tcp\n") + require.Equal(t, true, ok) + defer p.Close() + + time.Sleep(1 * time.Second) + + packetRecv := make(chan struct{}) + i := 0 + + c := gortsplib.Client{ + OnPacketRTP: func(ctx *gortsplib.ClientOnPacketRTPCtx) { + switch i { + case 0: + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 123, + Timestamp: 45343, + SSRC: 563423, + CSRC: []uint32{}, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, ctx.Packet) + + case 1: + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: false, + PayloadType: 96, + SequenceNumber: 124, + Timestamp: 45343, + SSRC: 563423, + CSRC: []uint32{}, + }, + Payload: append( + append([]byte{0x1c, 0x81, 0x02, 0x03, 0x04}, bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 363)...), + []byte{0x01, 0x02, 0x03}..., + ), + }, ctx.Packet) + + case 2: + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 125, + Timestamp: 45343, + SSRC: 563423, + CSRC: []uint32{}, + }, + Payload: append( + []byte{0x1c, 0x41, 0x04}, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 135)..., + ), + }, ctx.Packet) + + case 3: + require.Equal(t, &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 126, + Timestamp: 45343, + SSRC: 563423, + CSRC: []uint32{}, + }, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, ctx.Packet) + close(packetRecv) + } + i++ + }, + } + + u, err := url.Parse("rtsp://127.0.0.1:8554/proxied") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + tracks, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAndPlay(tracks, baseURL) + require.NoError(t, err) + + close(connected) + <-packetRecv +} diff --git a/internal/core/stream.go b/internal/core/stream.go index ac5c4722..711c76be 100644 --- a/internal/core/stream.go +++ b/internal/core/stream.go @@ -35,7 +35,7 @@ func (m *streamNonRTSPReadersMap) remove(r reader) { delete(m.ma, r) } -func (m *streamNonRTSPReadersMap) writeData(data *data) { +func (m *streamNonRTSPReadersMap) writeData(data data) { m.mutex.RLock() defer m.mutex.RUnlock() @@ -44,6 +44,12 @@ func (m *streamNonRTSPReadersMap) writeData(data *data) { } } +func (m *streamNonRTSPReadersMap) hasReaders() bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.ma) > 0 +} + type stream struct { nonRTSPReaders *streamNonRTSPReadersMap rtspStream *gortsplib.ServerStream @@ -60,7 +66,7 @@ func newStream(tracks gortsplib.Tracks, generateRTPPackets bool) (*stream, error for i, track := range s.rtspStream.Tracks() { var err error - s.streamTracks[i], err = newStreamTrack(track, generateRTPPackets, s.writeDataInner) + s.streamTracks[i], err = newStreamTrack(track, generateRTPPackets) if err != nil { return nil, err } @@ -90,14 +96,19 @@ func (s *stream) readerRemove(r reader) { } } -func (s *stream) writeData(data *data) { - s.streamTracks[data.trackID].writeData(data) -} +func (s *stream) writeData(data data) error { + err := s.streamTracks[data.getTrackID()].onData(data, s.nonRTSPReaders.hasReaders()) + if err != nil { + return err + } -func (s *stream) writeDataInner(data *data) { - // forward to RTSP readers - s.rtspStream.WritePacketRTP(data.trackID, data.rtpPacket, data.ptsEqualsDTS) + // forward RTP packets to RTSP readers + for _, pkt := range data.getRTPPackets() { + s.rtspStream.WritePacketRTP(data.getTrackID(), pkt, data.getPTSEqualsDTS()) + } - // forward to non-RTSP readers + // forward data to non-RTSP readers s.nonRTSPReaders.writeData(data) + + return nil } diff --git a/internal/core/streamtrack.go b/internal/core/streamtrack.go index 1a02ed58..de96f5b4 100644 --- a/internal/core/streamtrack.go +++ b/internal/core/streamtrack.go @@ -7,21 +7,21 @@ import ( ) type streamTrack interface { - writeData(*data) + onData(data, bool) error } -func newStreamTrack(track gortsplib.Track, generateRTPPackets bool, writeDataInner func(*data)) (streamTrack, error) { +func newStreamTrack(track gortsplib.Track, generateRTPPackets bool) (streamTrack, error) { switch ttrack := track.(type) { case *gortsplib.TrackH264: - return newStreamTrackH264(ttrack, generateRTPPackets, writeDataInner), nil + return newStreamTrackH264(ttrack, generateRTPPackets), nil case *gortsplib.TrackMPEG4Audio: - return newStreamTrackMPEG4Audio(ttrack, generateRTPPackets, writeDataInner), nil + return newStreamTrackMPEG4Audio(ttrack, generateRTPPackets), nil default: if generateRTPPackets { return nil, fmt.Errorf("we don't know how to generate RTP packets of track %+v", track) } - return newStreamTrackGeneric(track, writeDataInner), nil + return newStreamTrackGeneric(), nil } } diff --git a/internal/core/streamtrack_generic.go b/internal/core/streamtrack_generic.go index 5d202ad7..671c15b7 100644 --- a/internal/core/streamtrack_generic.go +++ b/internal/core/streamtrack_generic.go @@ -1,19 +1,33 @@ package core import ( - "github.com/aler9/gortsplib" + "fmt" ) -type streamTrackGeneric struct { - writeDataInner func(*data) +const ( + // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header) + maxPacketSize = 1472 +) + +type streamTrackGeneric struct{} + +func newStreamTrackGeneric() *streamTrackGeneric { + return &streamTrackGeneric{} } -func newStreamTrackGeneric(track gortsplib.Track, writeDataInner func(*data)) *streamTrackGeneric { - return &streamTrackGeneric{ - writeDataInner: writeDataInner, +func (t *streamTrackGeneric) onData(dat data, hasNonRTSPReaders bool) error { + tdata := dat.(*dataGeneric) + + pkt := tdata.rtpPackets[0] + + // remove padding + pkt.Header.Padding = false + pkt.PaddingSize = 0 + + if pkt.MarshalSize() > maxPacketSize { + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", + pkt.MarshalSize(), maxPacketSize) } -} -func (t *streamTrackGeneric) writeData(data *data) { - t.writeDataInner(data) + return nil } diff --git a/internal/core/streamtrack_h264.go b/internal/core/streamtrack_h264.go index ee02166a..14d0bd49 100644 --- a/internal/core/streamtrack_h264.go +++ b/internal/core/streamtrack_h264.go @@ -6,34 +6,97 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/rtph264" + "github.com/pion/rtp" ) +func rtpH264ExtractSPSPPS(pkt *rtp.Packet) ([]byte, []byte) { + if len(pkt.Payload) == 0 { + return nil, nil + } + + typ := h264.NALUType(pkt.Payload[0] & 0x1F) + + switch typ { + case h264.NALUTypeSPS: + return pkt.Payload, nil + + case h264.NALUTypePPS: + return nil, pkt.Payload + + case 24: // STAP-A + payload := pkt.Payload[1:] + var sps []byte + var pps []byte + + for len(payload) > 0 { + if len(payload) < 2 { + break + } + + size := uint16(payload[0])<<8 | uint16(payload[1]) + payload = payload[2:] + + if size == 0 || int(size) > len(payload) { + break + } + + nalu := payload[:size] + payload = payload[size:] + + typ = h264.NALUType(nalu[0] & 0x1F) + + switch typ { + case h264.NALUTypeSPS: + sps = nalu + + case h264.NALUTypePPS: + pps = nalu + } + } + + return sps, pps + + default: + return nil, nil + } +} + type streamTrackH264 struct { - track *gortsplib.TrackH264 - writeDataInner func(*data) + track *gortsplib.TrackH264 - rtpEncoder *rtph264.Encoder + encoder *rtph264.Encoder + decoder *rtph264.Decoder } func newStreamTrackH264( track *gortsplib.TrackH264, generateRTPPackets bool, - writeDataInner func(*data), ) *streamTrackH264 { t := &streamTrackH264{ - track: track, - writeDataInner: writeDataInner, + track: track, } if generateRTPPackets { - t.rtpEncoder = &rtph264.Encoder{PayloadType: 96} - t.rtpEncoder.Init() + t.encoder = &rtph264.Encoder{PayloadType: 96} + t.encoder.Init() } return t } -func (t *streamTrackH264) updateTrackParameters(nalus [][]byte) { +func (t *streamTrackH264) updateTrackParametersFromRTPPacket(pkt *rtp.Packet) { + sps, pps := rtpH264ExtractSPSPPS(pkt) + + if sps != nil && !bytes.Equal(sps, t.track.SafeSPS()) { + t.track.SafeSetSPS(sps) + } + + if pps != nil && !bytes.Equal(pps, t.track.SafePPS()) { + t.track.SafeSetPPS(pps) + } +} + +func (t *streamTrackH264) updateTrackParametersFromNALUs(nalus [][]byte) { for _, nalu := range nalus { typ := h264.NALUType(nalu[0] & 0x1F) @@ -105,40 +168,69 @@ func (t *streamTrackH264) remuxNALUs(nalus [][]byte) [][]byte { return filteredNALUs } -func (t *streamTrackH264) generateRTPPackets(dat *data) { - pkts, err := t.rtpEncoder.Encode(dat.h264NALUs, dat.pts) +func (t *streamTrackH264) generateRTPPackets(tdata *dataH264) error { + pkts, err := t.encoder.Encode(tdata.nalus, tdata.pts) if err != nil { - return + return err } - lastPkt := len(pkts) - 1 - for i, pkt := range pkts { - if i != lastPkt { - t.writeDataInner(&data{ - trackID: dat.trackID, - rtpPacket: pkt, - }) - } else { - t.writeDataInner(&data{ - trackID: dat.trackID, - rtpPacket: pkt, - ptsEqualsDTS: dat.ptsEqualsDTS, - pts: dat.pts, - h264NALUs: dat.h264NALUs, - }) - } - } + tdata.rtpPackets = pkts + return nil } -func (t *streamTrackH264) writeData(dat *data) { - if dat.h264NALUs != nil { - t.updateTrackParameters(dat.h264NALUs) - dat.h264NALUs = t.remuxNALUs(dat.h264NALUs) - } +func (t *streamTrackH264) onData(dat data, hasNonRTSPReaders bool) error { + tdata := dat.(*dataH264) + + if tdata.rtpPackets != nil { + pkt := tdata.rtpPackets[0] + t.updateTrackParametersFromRTPPacket(pkt) + + if t.encoder == nil { + // remove padding + pkt.Header.Padding = false + pkt.PaddingSize = 0 + + // we need to re-encode since RTP packets exceed maximum size + if pkt.MarshalSize() > maxPacketSize { + v1 := pkt.SSRC + v2 := pkt.SequenceNumber + v3 := pkt.Timestamp + t.encoder = &rtph264.Encoder{ + PayloadType: pkt.PayloadType, + SSRC: &v1, + InitialSequenceNumber: &v2, + InitialTimestamp: &v3, + } + t.encoder.Init() + } + } + + // decode from RTP + if hasNonRTSPReaders || t.encoder != nil { + if t.decoder == nil { + t.decoder = &rtph264.Decoder{} + t.decoder.Init() + } + + nalus, pts, err := t.decoder.Decode(pkt) + if err != nil { + return err + } - if dat.rtpPacket != nil { - t.writeDataInner(dat) - } else if dat.h264NALUs != nil { - t.generateRTPPackets(dat) + tdata.nalus = nalus + tdata.pts = pts + + tdata.nalus = t.remuxNALUs(tdata.nalus) + } + + // route packet as is + if t.encoder == nil { + return nil + } + } else { + t.updateTrackParametersFromNALUs(tdata.nalus) + tdata.nalus = t.remuxNALUs(tdata.nalus) } + + return t.generateRTPPackets(tdata) } diff --git a/internal/core/streamtrack_mpeg4audio.go b/internal/core/streamtrack_mpeg4audio.go index d6f69c20..774d96cb 100644 --- a/internal/core/streamtrack_mpeg4audio.go +++ b/internal/core/streamtrack_mpeg4audio.go @@ -1,58 +1,89 @@ package core import ( + "fmt" + "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/rtpmpeg4audio" ) type streamTrackMPEG4Audio struct { - writeDataInner func(*data) - - rtpEncoder *rtpmpeg4audio.Encoder + track *gortsplib.TrackMPEG4Audio + encoder *rtpmpeg4audio.Encoder + decoder *rtpmpeg4audio.Decoder } func newStreamTrackMPEG4Audio( track *gortsplib.TrackMPEG4Audio, generateRTPPackets bool, - writeDataInner func(*data), ) *streamTrackMPEG4Audio { t := &streamTrackMPEG4Audio{ - writeDataInner: writeDataInner, + track: track, } if generateRTPPackets { - t.rtpEncoder = &rtpmpeg4audio.Encoder{ + t.encoder = &rtpmpeg4audio.Encoder{ PayloadType: 96, SampleRate: track.ClockRate(), SizeLength: 13, IndexLength: 3, IndexDeltaLength: 3, } - t.rtpEncoder.Init() + t.encoder.Init() } return t } -func (t *streamTrackMPEG4Audio) generateRTPPackets(dat *data) { - pkts, err := t.rtpEncoder.Encode([][]byte{dat.mpeg4AudioAU}, dat.pts) +func (t *streamTrackMPEG4Audio) generateRTPPackets(tdata *dataMPEG4Audio) error { + pkts, err := t.encoder.Encode(tdata.aus, tdata.pts) if err != nil { - return + return err } - for _, pkt := range pkts { - t.writeDataInner(&data{ - trackID: dat.trackID, - rtpPacket: pkt, - ptsEqualsDTS: true, - }) - } + tdata.rtpPackets = pkts + return nil } -func (t *streamTrackMPEG4Audio) writeData(dat *data) { - if dat.rtpPacket != nil { - t.writeDataInner(dat) - } else { - t.generateRTPPackets(dat) +func (t *streamTrackMPEG4Audio) onData(dat data, hasNonRTSPReaders bool) error { + tdata := dat.(*dataMPEG4Audio) + + if tdata.rtpPackets != nil { + pkt := tdata.rtpPackets[0] + + // remove padding + pkt.Header.Padding = false + pkt.PaddingSize = 0 + + if pkt.MarshalSize() > maxPacketSize { + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", + pkt.MarshalSize(), maxPacketSize) + } + + // decode from RTP + if hasNonRTSPReaders { + if t.decoder == nil { + t.decoder = &rtpmpeg4audio.Decoder{ + SampleRate: t.track.Config.SampleRate, + SizeLength: t.track.SizeLength, + IndexLength: t.track.IndexLength, + IndexDeltaLength: t.track.IndexDeltaLength, + } + t.decoder.Init() + } + + aus, pts, err := t.decoder.Decode(pkt) + if err != nil { + return err + } + + tdata.aus = aus + tdata.pts = pts + } + + // route packet as is + return nil } + + return t.generateRTPPackets(tdata) } diff --git a/internal/hls/muxer_variant_fmp4_segmenter.go b/internal/hls/muxer_variant_fmp4_segmenter.go index 0fe85d88..312fc7a5 100644 --- a/internal/hls/muxer_variant_fmp4_segmenter.go +++ b/internal/hls/muxer_variant_fmp4_segmenter.go @@ -213,7 +213,7 @@ func (m *muxerVariantFMP4Segmenter) writeH264Entry( } // put samples into a queue in order to - // - allow to compute sample duration + // - compute sample duration // - check if next sample is IDR sample, m.nextVideoSample = m.nextVideoSample, sample if sample == nil { @@ -290,6 +290,9 @@ func (m *muxerVariantFMP4Segmenter) writeAAC(now time.Time, dts time.Duration, a } dts -= m.startDTS + if dts < 0 { + return nil + } } sample := &augmentedAudioSample{ @@ -299,8 +302,7 @@ func (m *muxerVariantFMP4Segmenter) writeAAC(now time.Time, dts time.Duration, a dts: dts, } - // put samples into a queue in order to - // allow to compute the sample duration + // put samples into a queue in order to compute the sample duration sample, m.nextAudioSample = m.nextAudioSample, sample if sample == nil { return nil