Browse Source

RTMP client: add and use DTS instead of clock

pull/346/head
aler9 4 years ago
parent
commit
cbda813e1a
  1. 43
      internal/clientrtmp/client.go
  2. 4
      internal/clientrtsp/client.go
  3. 6
      internal/h264/annexb_test.go
  4. 58
      internal/h264/dtsestimator.go
  5. 17
      internal/h264/dtsestimator_test.go
  6. 20
      internal/h264/nalutype.go

43
internal/clientrtmp/client.go

@ -30,6 +30,11 @@ import ( @@ -30,6 +30,11 @@ import (
const (
pauseAfterAuthError = 2 * time.Second
// an offset is needed to
// - avoid negative PTS values
// - avoid PTS < DTS during startup
ptsOffset = 2 * time.Second
)
func ipEqualOrInRange(ip net.IP, ips []interface{}) bool {
@ -284,10 +289,8 @@ func (c *Client) runRead() { @@ -284,10 +289,8 @@ func (c *Client) runRead() {
writerDone := make(chan error)
go func() {
writerDone <- func() error {
videoInitialized := false
var videoBuf [][]byte
var videoStartDTS time.Time
var videoLastDTS time.Duration
videoDTSEst := h264.NewDTSEstimator()
for {
data, ok := c.ringBuffer.Pull()
@ -297,7 +300,7 @@ func (c *Client) runRead() { @@ -297,7 +300,7 @@ func (c *Client) runRead() {
pair := data.(trackIDPayloadPair)
if videoTrack != nil && pair.trackID == videoTrack.ID {
nalus, _, err := h264Decoder.Decode(pair.buf)
nalus, pts, err := h264Decoder.Decode(pair.buf)
if err != nil {
if err != rtph264.ErrMorePacketsNeeded {
c.log(logger.Warn, "unable to decode video track: %v", err)
@ -305,11 +308,6 @@ func (c *Client) runRead() { @@ -305,11 +308,6 @@ func (c *Client) runRead() {
continue
}
if !videoInitialized {
videoInitialized = true
videoStartDTS = time.Now()
}
for _, nalu := range nalus {
// remove SPS, PPS and AUD, not needed by RTMP
typ := h264.NALUType(nalu[0] & 0x1F)
@ -325,20 +323,23 @@ func (c *Client) runRead() { @@ -325,20 +323,23 @@ func (c *Client) runRead() {
// send them together.
marker := (pair.buf[1] >> 7 & 0x1) > 0
if marker {
dts := time.Since(videoStartDTS)
// avoid duplicate DTS
// (RTMP has a resolution of 1ms)
if int64(dts/time.Millisecond) <= (int64(videoLastDTS / time.Millisecond)) {
dts = videoLastDTS + time.Millisecond
data, err := h264.EncodeAVCC(videoBuf)
if err != nil {
return err
}
videoLastDTS = dts
dts := videoDTSEst.Feed(pts + ptsOffset)
c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout))
err := c.conn.WriteH264(videoBuf, dts)
err = c.conn.WritePacket(av.Packet{
Type: av.H264,
Data: data,
Time: dts,
CTime: pts + ptsOffset - dts,
})
if err != nil {
return err
}
videoBuf = nil
}
@ -352,8 +353,14 @@ func (c *Client) runRead() { @@ -352,8 +353,14 @@ func (c *Client) runRead() {
}
for i, au := range aus {
auPTS := pts + ptsOffset + time.Duration(i)*1000*time.Second/time.Duration(audioClockRate)
c.conn.NetConn().SetWriteDeadline(time.Now().Add(c.writeTimeout))
err := c.conn.WriteAAC(au, pts+time.Duration(i)*1000*time.Second/time.Duration(audioClockRate))
err := c.conn.WritePacket(av.Packet{
Type: av.AAC,
Data: au,
Time: auPTS,
})
if err != nil {
return err
}

4
internal/clientrtsp/client.go

@ -77,11 +77,11 @@ type Client struct { @@ -77,11 +77,11 @@ type Client struct {
authValidator *auth.Validator
authFailures int
// read only
// read
setuppedTracks map[int]*gortsplib.Track
onReadCmd *externalcmd.Cmd
// publish only
// publish
sp *streamproc.StreamProc
onPublishCmd *externalcmd.Cmd

6
internal/h264/annexb_test.go

@ -6,7 +6,7 @@ import ( @@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require"
)
var annexBCases = []struct {
var casesAnnexB = []struct {
name string
encin []byte
encout []byte
@ -66,7 +66,7 @@ var annexBCases = []struct { @@ -66,7 +66,7 @@ var annexBCases = []struct {
}
func TestAnnexBDecode(t *testing.T) {
for _, ca := range annexBCases {
for _, ca := range casesAnnexB {
t.Run(ca.name, func(t *testing.T) {
dec, err := DecodeAnnexB(ca.encin)
require.NoError(t, err)
@ -76,7 +76,7 @@ func TestAnnexBDecode(t *testing.T) { @@ -76,7 +76,7 @@ func TestAnnexBDecode(t *testing.T) {
}
func TestAnnexBEncode(t *testing.T) {
for _, ca := range annexBCases {
for _, ca := range casesAnnexB {
t.Run(ca.name, func(t *testing.T) {
enc, err := EncodeAnnexB(ca.dec)
require.NoError(t, err)

58
internal/h264/dtsestimator.go

@ -0,0 +1,58 @@ @@ -0,0 +1,58 @@
package h264
import (
"time"
)
// DTSEstimator is a DTS estimator.
type DTSEstimator struct {
initializing int
prevDTS time.Duration
prevPTS time.Duration
prevPrevPTS time.Duration
}
// NewDTSEstimator allocates a DTSEstimator.
func NewDTSEstimator() *DTSEstimator {
return &DTSEstimator{
initializing: 2,
}
}
// Feed provides PTS to the estimator, and returns the estimated DTS.
func (d *DTSEstimator) Feed(pts time.Duration) time.Duration {
if d.initializing > 0 {
d.initializing--
dts := d.prevDTS + time.Millisecond
d.prevPrevPTS = d.prevPTS
d.prevPTS = pts
d.prevDTS = dts
return dts
}
dts := func() time.Duration {
// P or I frame
if pts > d.prevPTS {
// previous frame was B
// use the DTS of the previous frame
if d.prevPTS < d.prevPrevPTS {
return d.prevPTS
}
// previous frame was P or I
// use two frames ago plus a small quantity
// to avoid non-monotonous DTS with B-frames
return d.prevPrevPTS + time.Millisecond
}
// B Frame
// do not increase
return d.prevDTS + time.Millisecond
}()
d.prevPrevPTS = d.prevPTS
d.prevPTS = pts
d.prevDTS = dts
return dts
}

17
internal/h264/dtsestimator_test.go

@ -0,0 +1,17 @@ @@ -0,0 +1,17 @@
package h264
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestDTSEstimator(t *testing.T) {
est := NewDTSEstimator()
est.Feed(2 * time.Second)
est.Feed(2*time.Second - 200*time.Millisecond)
est.Feed(2*time.Second - 400*time.Millisecond)
dts := est.Feed(2*time.Second + 200*time.Millisecond)
require.Equal(t, 2*time.Second-400*time.Millisecond, dts)
}

20
internal/h264/nalutype.go

@ -32,12 +32,6 @@ const ( @@ -32,12 +32,6 @@ const (
NALUTypeSliceExtensionDepth NALUType = 21
NALUTypeReserved22 NALUType = 22
NALUTypeReserved23 NALUType = 23
NALUTypeSTAPA NALUType = 24
NALUTypeSTAPB NALUType = 25
NALUTypeMTAP16 NALUType = 26
NALUTypeMTAP24 NALUType = 27
NALUTypeFUA NALUType = 28
NALUTypeFUB NALUType = 29
)
// String implements fmt.Stringer.
@ -54,7 +48,7 @@ func (nt NALUType) String() string { @@ -54,7 +48,7 @@ func (nt NALUType) String() string {
case NALUTypeIDR:
return "IDR"
case NALUTypeSEI:
return "Sei"
return "SEI"
case NALUTypeSPS:
return "SPS"
case NALUTypePPS:
@ -89,18 +83,6 @@ func (nt NALUType) String() string { @@ -89,18 +83,6 @@ func (nt NALUType) String() string {
return "Reserved22"
case NALUTypeReserved23:
return "Reserved23"
case NALUTypeSTAPA:
return "STAPA"
case NALUTypeSTAPB:
return "STAPB"
case NALUTypeMTAP16:
return "MTAP16"
case NALUTypeMTAP24:
return "MTAP24"
case NALUTypeFUA:
return "FUA"
case NALUTypeFUB:
return "FUB"
}
return fmt.Sprintf("unknown (%d)", nt)
}

Loading…
Cancel
Save