From cbda813e1a577c10b49ccef6b6e3009bc639191b Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 11 Apr 2021 18:44:02 +0200 Subject: [PATCH] RTMP client: add and use DTS instead of clock --- internal/clientrtmp/client.go | 43 ++++++++++++---------- internal/clientrtsp/client.go | 4 +-- internal/h264/annexb_test.go | 6 ++-- internal/h264/dtsestimator.go | 58 ++++++++++++++++++++++++++++++ internal/h264/dtsestimator_test.go | 17 +++++++++ internal/h264/nalutype.go | 20 +---------- 6 files changed, 106 insertions(+), 42 deletions(-) create mode 100644 internal/h264/dtsestimator.go create mode 100644 internal/h264/dtsestimator_test.go diff --git a/internal/clientrtmp/client.go b/internal/clientrtmp/client.go index 1ba40c28..86c49189 100644 --- a/internal/clientrtmp/client.go +++ b/internal/clientrtmp/client.go @@ -30,6 +30,11 @@ import ( const ( pauseAfterAuthError = 2 * time.Second + + // an offset is needed to + // - avoid negative PTS values + // - avoid PTS < DTS during startup + ptsOffset = 2 * time.Second ) func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { @@ -284,10 +289,8 @@ func (c *Client) runRead() { writerDone := make(chan error) go func() { writerDone <- func() error { - videoInitialized := false var videoBuf [][]byte - var videoStartDTS time.Time - var videoLastDTS time.Duration + videoDTSEst := h264.NewDTSEstimator() for { data, ok := c.ringBuffer.Pull() @@ -297,7 +300,7 @@ func (c *Client) runRead() { pair := data.(trackIDPayloadPair) if videoTrack != nil && pair.trackID == videoTrack.ID { - nalus, _, err := h264Decoder.Decode(pair.buf) + nalus, pts, err := h264Decoder.Decode(pair.buf) if err != nil { if err != rtph264.ErrMorePacketsNeeded { c.log(logger.Warn, "unable to decode video track: %v", err) @@ -305,11 +308,6 @@ func (c *Client) runRead() { continue } - if !videoInitialized { - videoInitialized = true - videoStartDTS = time.Now() - } - for _, nalu := range nalus { // remove SPS, PPS and AUD, not needed by RTMP typ := h264.NALUType(nalu[0] & 0x1F) @@ -325,20 +323,23 @@ func (c *Client) runRead() { // send them together. marker := (pair.buf[1] >> 7 & 0x1) > 0 if marker { - dts := time.Since(videoStartDTS) - - // avoid duplicate DTS - // (RTMP has a resolution of 1ms) - if int64(dts/time.Millisecond) <= (int64(videoLastDTS / time.Millisecond)) { - dts = videoLastDTS + time.Millisecond + data, err := h264.EncodeAVCC(videoBuf) + if err != nil { + return err } - videoLastDTS = dts + dts := videoDTSEst.Feed(pts + ptsOffset) c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) - err := c.conn.WriteH264(videoBuf, dts) + err = c.conn.WritePacket(av.Packet{ + Type: av.H264, + Data: data, + Time: dts, + CTime: pts + ptsOffset - dts, + }) if err != nil { return err } + videoBuf = nil } @@ -352,8 +353,14 @@ func (c *Client) runRead() { } for i, au := range aus { + auPTS := pts + ptsOffset + time.Duration(i)*1000*time.Second/time.Duration(audioClockRate) + c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout)) - err := c.conn.WriteAAC(au, pts+time.Duration(i)*1000*time.Second/time.Duration(audioClockRate)) + err := c.conn.WritePacket(av.Packet{ + Type: av.AAC, + Data: au, + Time: auPTS, + }) if err != nil { return err } diff --git a/internal/clientrtsp/client.go b/internal/clientrtsp/client.go index 80d1609f..843a79c6 100644 --- a/internal/clientrtsp/client.go +++ b/internal/clientrtsp/client.go @@ -77,11 +77,11 @@ type Client struct { authValidator *auth.Validator authFailures int - // read only + // read setuppedTracks map[int]*gortsplib.Track onReadCmd *externalcmd.Cmd - // publish only + // publish sp *streamproc.StreamProc onPublishCmd *externalcmd.Cmd diff --git a/internal/h264/annexb_test.go b/internal/h264/annexb_test.go index 86d87ffc..10d416b7 100644 --- a/internal/h264/annexb_test.go +++ b/internal/h264/annexb_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -var annexBCases = []struct { +var casesAnnexB = []struct { name string encin []byte encout []byte @@ -66,7 +66,7 @@ var annexBCases = []struct { } func TestAnnexBDecode(t *testing.T) { - for _, ca := range annexBCases { + for _, ca := range casesAnnexB { t.Run(ca.name, func(t *testing.T) { dec, err := DecodeAnnexB(ca.encin) require.NoError(t, err) @@ -76,7 +76,7 @@ func TestAnnexBDecode(t *testing.T) { } func TestAnnexBEncode(t *testing.T) { - for _, ca := range annexBCases { + for _, ca := range casesAnnexB { t.Run(ca.name, func(t *testing.T) { enc, err := EncodeAnnexB(ca.dec) require.NoError(t, err) diff --git a/internal/h264/dtsestimator.go b/internal/h264/dtsestimator.go new file mode 100644 index 00000000..9da4bcb3 --- /dev/null +++ b/internal/h264/dtsestimator.go @@ -0,0 +1,58 @@ +package h264 + +import ( + "time" +) + +// DTSEstimator is a DTS estimator. +type DTSEstimator struct { + initializing int + prevDTS time.Duration + prevPTS time.Duration + prevPrevPTS time.Duration +} + +// NewDTSEstimator allocates a DTSEstimator. +func NewDTSEstimator() *DTSEstimator { + return &DTSEstimator{ + initializing: 2, + } +} + +// Feed provides PTS to the estimator, and returns the estimated DTS. +func (d *DTSEstimator) Feed(pts time.Duration) time.Duration { + if d.initializing > 0 { + d.initializing-- + dts := d.prevDTS + time.Millisecond + d.prevPrevPTS = d.prevPTS + d.prevPTS = pts + d.prevDTS = dts + return dts + } + + dts := func() time.Duration { + // P or I frame + if pts > d.prevPTS { + // previous frame was B + // use the DTS of the previous frame + if d.prevPTS < d.prevPrevPTS { + return d.prevPTS + } + + // previous frame was P or I + // use two frames ago plus a small quantity + // to avoid non-monotonous DTS with B-frames + return d.prevPrevPTS + time.Millisecond + } + + // B Frame + // do not increase + return d.prevDTS + time.Millisecond + }() + + d.prevPrevPTS = d.prevPTS + d.prevPTS = pts + d.prevDTS = dts + + return dts +} diff --git a/internal/h264/dtsestimator_test.go b/internal/h264/dtsestimator_test.go new file mode 100644 index 00000000..a5c8733d --- /dev/null +++ b/internal/h264/dtsestimator_test.go @@ -0,0 +1,17 @@ +package h264 + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDTSEstimator(t *testing.T) { + est := NewDTSEstimator() + est.Feed(2 * time.Second) + est.Feed(2*time.Second - 200*time.Millisecond) + est.Feed(2*time.Second - 400*time.Millisecond) + dts := est.Feed(2*time.Second + 200*time.Millisecond) + require.Equal(t, 2*time.Second-400*time.Millisecond, dts) +} diff --git a/internal/h264/nalutype.go b/internal/h264/nalutype.go index 4430dbc9..6fb6ba06 100644 --- a/internal/h264/nalutype.go +++ b/internal/h264/nalutype.go @@ -32,12 +32,6 @@ const ( NALUTypeSliceExtensionDepth NALUType = 21 NALUTypeReserved22 NALUType = 22 NALUTypeReserved23 NALUType = 23 - NALUTypeSTAPA NALUType = 24 - NALUTypeSTAPB NALUType = 25 - NALUTypeMTAP16 NALUType = 26 - NALUTypeMTAP24 NALUType = 27 - NALUTypeFUA NALUType = 28 - NALUTypeFUB NALUType = 29 ) // String implements fmt.Stringer. @@ -54,7 +48,7 @@ func (nt NALUType) String() string { case NALUTypeIDR: return "IDR" case NALUTypeSEI: - return "Sei" + return "SEI" case NALUTypeSPS: return "SPS" case NALUTypePPS: @@ -89,18 +83,6 @@ func (nt NALUType) String() string { return "Reserved22" case NALUTypeReserved23: return "Reserved23" - case NALUTypeSTAPA: - return "STAPA" - case NALUTypeSTAPB: - return "STAPB" - case NALUTypeMTAP16: - return "MTAP16" - case NALUTypeMTAP24: - return "MTAP24" - case NALUTypeFUA: - return "FUA" - case NALUTypeFUB: - return "FUB" } return fmt.Sprintf("unknown (%d)", nt) }