Browse Source

rtmp: fix parsing error caused by extended timestamps (#2393) (#2556) (#2384) (#1550) (#2564) (#2808)

pull/2814/head
Alessandro Ros 1 year ago committed by GitHub
parent
commit
89560c19a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      internal/protocols/rtmp/chunk/chunk.go
  2. 74
      internal/protocols/rtmp/chunk/chunk0.go
  3. 35
      internal/protocols/rtmp/chunk/chunk0_test.go
  4. 62
      internal/protocols/rtmp/chunk/chunk1.go
  5. 34
      internal/protocols/rtmp/chunk/chunk1_test.go
  6. 46
      internal/protocols/rtmp/chunk/chunk2.go
  7. 31
      internal/protocols/rtmp/chunk/chunk2_test.go
  8. 35
      internal/protocols/rtmp/chunk/chunk3.go
  9. 30
      internal/protocols/rtmp/chunk/chunk3_test.go
  10. 158
      internal/protocols/rtmp/chunk/chunk_test.go
  11. 18
      internal/protocols/rtmp/rawmessage/reader.go
  12. 57
      internal/protocols/rtmp/rawmessage/reader_test.go
  13. 39
      internal/protocols/rtmp/rawmessage/writer.go
  14. 35
      internal/protocols/rtmp/rawmessage/writer_test.go

4
internal/protocols/rtmp/chunk/chunk.go

@ -7,6 +7,6 @@ import ( @@ -7,6 +7,6 @@ import (
// Chunk is a chunk.
type Chunk interface {
Read(io.Reader, uint32) error
Marshal() ([]byte, error)
Read(r io.Reader, bodyLen uint32, hasExtendedTimestamp bool) error
Marshal(hasExtendedTimestamp bool) ([]byte, error)
}

74
internal/protocols/rtmp/chunk/chunk0.go

@ -11,14 +11,14 @@ import ( @@ -11,14 +11,14 @@ import (
type Chunk0 struct {
ChunkStreamID byte
Timestamp uint32
BodyLen uint32
Type uint8
MessageStreamID uint32
BodyLen uint32
Body []byte
}
// Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
func (c *Chunk0) Read(r io.Reader, maxBodyLen uint32, _ bool) error {
header := make([]byte, 12)
_, err := io.ReadFull(r, header)
if err != nil {
@ -31,9 +31,18 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -31,9 +31,18 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
c.Type = header[7]
c.MessageStreamID = uint32(header[8])<<24 | uint32(header[9])<<16 | uint32(header[10])<<8 | uint32(header[11])
if c.Timestamp >= 0xFFFFFF {
_, err := io.ReadFull(r, header[:4])
if err != nil {
return err
}
c.Timestamp = uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
}
chunkBodyLen := c.BodyLen
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
if chunkBodyLen > maxBodyLen {
chunkBodyLen = maxBodyLen
}
c.Body = make([]byte, chunkBodyLen)
@ -41,21 +50,50 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -41,21 +50,50 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
return err
}
func (c Chunk0) marshalSize() int {
n := 12 + len(c.Body)
if c.Timestamp >= 0xFFFFFF {
n += 4
}
return n
}
// Marshal writes the chunk.
func (c Chunk0) Marshal() ([]byte, error) {
buf := make([]byte, 12+len(c.Body))
func (c Chunk0) Marshal(_ bool) ([]byte, error) {
buf := make([]byte, c.marshalSize())
buf[0] = c.ChunkStreamID
buf[1] = byte(c.Timestamp >> 16)
buf[2] = byte(c.Timestamp >> 8)
buf[3] = byte(c.Timestamp)
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
buf[8] = byte(c.MessageStreamID >> 24)
buf[9] = byte(c.MessageStreamID >> 16)
buf[10] = byte(c.MessageStreamID >> 8)
buf[11] = byte(c.MessageStreamID)
copy(buf[12:], c.Body)
if c.Timestamp >= 0xFFFFFF {
buf[1] = 0xFF
buf[2] = 0xFF
buf[3] = 0xFF
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
buf[8] = byte(c.MessageStreamID >> 24)
buf[9] = byte(c.MessageStreamID >> 16)
buf[10] = byte(c.MessageStreamID >> 8)
buf[11] = byte(c.MessageStreamID)
buf[12] = byte(c.Timestamp >> 24)
buf[13] = byte(c.Timestamp >> 16)
buf[14] = byte(c.Timestamp >> 8)
buf[15] = byte(c.Timestamp)
copy(buf[16:], c.Body)
} else {
buf[1] = byte(c.Timestamp >> 16)
buf[2] = byte(c.Timestamp >> 8)
buf[3] = byte(c.Timestamp)
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
buf[8] = byte(c.MessageStreamID >> 24)
buf[9] = byte(c.MessageStreamID >> 16)
buf[10] = byte(c.MessageStreamID >> 8)
buf[11] = byte(c.MessageStreamID)
copy(buf[12:], c.Body)
}
return buf, nil
}

35
internal/protocols/rtmp/chunk/chunk0_test.go

@ -1,35 +0,0 @@ @@ -1,35 +0,0 @@
package chunk
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var chunk0enc = []byte{
0x19, 0xb1, 0xa1, 0x91, 0x0, 0x0, 0x14, 0x14,
0x3, 0x5d, 0x17, 0x3d, 0x1, 0x2, 0x3, 0x4,
}
var chunk0dec = Chunk0{
ChunkStreamID: 25,
Timestamp: 11641233,
Type: 20,
MessageStreamID: 56432445,
BodyLen: 20,
Body: []byte{0x01, 0x02, 0x03, 0x04},
}
func TestChunk0Read(t *testing.T) {
var chunk0 Chunk0
err := chunk0.Read(bytes.NewReader(chunk0enc), 4)
require.NoError(t, err)
require.Equal(t, chunk0dec, chunk0)
}
func TestChunk0Marshal(t *testing.T) {
buf, err := chunk0dec.Marshal()
require.NoError(t, err)
require.Equal(t, chunk0enc, buf)
}

62
internal/protocols/rtmp/chunk/chunk1.go

@ -13,13 +13,13 @@ import ( @@ -13,13 +13,13 @@ import (
type Chunk1 struct {
ChunkStreamID byte
TimestampDelta uint32
Type uint8
BodyLen uint32
Type uint8
Body []byte
}
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
func (c *Chunk1) Read(r io.Reader, maxBodyLen uint32, _ bool) error {
header := make([]byte, 8)
_, err := io.ReadFull(r, header)
if err != nil {
@ -31,9 +31,18 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -31,9 +31,18 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
c.BodyLen = uint32(header[4])<<16 | uint32(header[5])<<8 | uint32(header[6])
c.Type = header[7]
if c.TimestampDelta >= 0xFFFFFF {
_, err = io.ReadFull(r, header[:4])
if err != nil {
return err
}
c.TimestampDelta = uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
}
chunkBodyLen := (c.BodyLen)
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
if chunkBodyLen > maxBodyLen {
chunkBodyLen = maxBodyLen
}
c.Body = make([]byte, chunkBodyLen)
@ -41,17 +50,42 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -41,17 +50,42 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
return err
}
func (c Chunk1) marshalSize() int {
n := 8 + len(c.Body)
if c.TimestampDelta >= 0xFFFFFF {
n += 4
}
return n
}
// Marshal writes the chunk.
func (c Chunk1) Marshal() ([]byte, error) {
buf := make([]byte, 8+len(c.Body))
func (c Chunk1) Marshal(_ bool) ([]byte, error) {
buf := make([]byte, c.marshalSize())
buf[0] = 1<<6 | c.ChunkStreamID
buf[1] = byte(c.TimestampDelta >> 16)
buf[2] = byte(c.TimestampDelta >> 8)
buf[3] = byte(c.TimestampDelta)
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
copy(buf[8:], c.Body)
if c.TimestampDelta >= 0xFFFFFF {
buf[1] = 0xFF
buf[2] = 0xFF
buf[3] = 0xFF
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
buf[8] = byte(c.TimestampDelta >> 24)
buf[9] = byte(c.TimestampDelta >> 16)
buf[10] = byte(c.TimestampDelta >> 8)
buf[11] = byte(c.TimestampDelta)
copy(buf[12:], c.Body)
} else {
buf[1] = byte(c.TimestampDelta >> 16)
buf[2] = byte(c.TimestampDelta >> 8)
buf[3] = byte(c.TimestampDelta)
buf[4] = byte(c.BodyLen >> 16)
buf[5] = byte(c.BodyLen >> 8)
buf[6] = byte(c.BodyLen)
buf[7] = c.Type
copy(buf[8:], c.Body)
}
return buf, nil
}

34
internal/protocols/rtmp/chunk/chunk1_test.go

@ -1,34 +0,0 @@ @@ -1,34 +0,0 @@
package chunk
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var chunk1enc = []byte{
0x59, 0xb1, 0xa1, 0x91, 0x0, 0x0, 0x14, 0x14,
0x1, 0x2, 0x3, 0x4,
}
var chunk1dec = Chunk1{
ChunkStreamID: 25,
TimestampDelta: 11641233,
Type: 20,
BodyLen: 20,
Body: []byte{0x01, 0x02, 0x03, 0x04},
}
func TestChunk1Read(t *testing.T) {
var chunk1 Chunk1
err := chunk1.Read(bytes.NewReader(chunk1enc), 4)
require.NoError(t, err)
require.Equal(t, chunk1dec, chunk1)
}
func TestChunk1Marshal(t *testing.T) {
buf, err := chunk1dec.Marshal()
require.NoError(t, err)
require.Equal(t, chunk1enc, buf)
}

46
internal/protocols/rtmp/chunk/chunk2.go

@ -15,7 +15,7 @@ type Chunk2 struct { @@ -15,7 +15,7 @@ type Chunk2 struct {
}
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
func (c *Chunk2) Read(r io.Reader, bodyLen uint32, _ bool) error {
header := make([]byte, 4)
_, err := io.ReadFull(r, header)
if err != nil {
@ -25,18 +25,48 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error { @@ -25,18 +25,48 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
c.ChunkStreamID = header[0] & 0x3F
c.TimestampDelta = uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
c.Body = make([]byte, chunkBodyLen)
if c.TimestampDelta >= 0xFFFFFF {
_, err = io.ReadFull(r, header[:4])
if err != nil {
return err
}
c.TimestampDelta = uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
}
c.Body = make([]byte, bodyLen)
_, err = io.ReadFull(r, c.Body)
return err
}
func (c Chunk2) marshalSize() int {
n := 4 + len(c.Body)
if c.TimestampDelta >= 0xFFFFFF {
n += 4
}
return n
}
// Marshal writes the chunk.
func (c Chunk2) Marshal() ([]byte, error) {
buf := make([]byte, 4+len(c.Body))
func (c Chunk2) Marshal(_ bool) ([]byte, error) {
buf := make([]byte, c.marshalSize())
buf[0] = 2<<6 | c.ChunkStreamID
buf[1] = byte(c.TimestampDelta >> 16)
buf[2] = byte(c.TimestampDelta >> 8)
buf[3] = byte(c.TimestampDelta)
copy(buf[4:], c.Body)
if c.TimestampDelta >= 0xFFFFFF {
buf[1] = 0xFF
buf[2] = 0xFF
buf[3] = 0xFF
buf[4] = byte(c.TimestampDelta >> 24)
buf[5] = byte(c.TimestampDelta >> 16)
buf[6] = byte(c.TimestampDelta >> 8)
buf[7] = byte(c.TimestampDelta)
copy(buf[8:], c.Body)
} else {
buf[1] = byte(c.TimestampDelta >> 16)
buf[2] = byte(c.TimestampDelta >> 8)
buf[3] = byte(c.TimestampDelta)
copy(buf[4:], c.Body)
}
return buf, nil
}

31
internal/protocols/rtmp/chunk/chunk2_test.go

@ -1,31 +0,0 @@ @@ -1,31 +0,0 @@
package chunk
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var chunk2enc = []byte{
0x99, 0xb1, 0xa1, 0x91, 0x1, 0x2, 0x3, 0x4,
}
var chunk2dec = Chunk2{
ChunkStreamID: 25,
TimestampDelta: 11641233,
Body: []byte{0x01, 0x02, 0x03, 0x04},
}
func TestChunk2Read(t *testing.T) {
var chunk2 Chunk2
err := chunk2.Read(bytes.NewReader(chunk2enc), 4)
require.NoError(t, err)
require.Equal(t, chunk2dec, chunk2)
}
func TestChunk2Marshal(t *testing.T) {
buf, err := chunk2dec.Marshal()
require.NoError(t, err)
require.Equal(t, chunk2enc, buf)
}

35
internal/protocols/rtmp/chunk/chunk3.go

@ -16,24 +16,45 @@ type Chunk3 struct { @@ -16,24 +16,45 @@ type Chunk3 struct {
}
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1)
_, err := io.ReadFull(r, header)
func (c *Chunk3) Read(r io.Reader, bodyLen uint32, hasExtendedTimestamp bool) error {
header := make([]byte, 4)
_, err := io.ReadFull(r, header[:1])
if err != nil {
return err
}
c.ChunkStreamID = header[0] & 0x3F
c.Body = make([]byte, chunkBodyLen)
if hasExtendedTimestamp {
_, err := io.ReadFull(r, header[:4])
if err != nil {
return err
}
}
c.Body = make([]byte, bodyLen)
_, err = io.ReadFull(r, c.Body)
return err
}
func (c Chunk3) marshalSize(hasExtendedTimestamp bool) int {
n := 1 + len(c.Body)
if hasExtendedTimestamp {
n += 4
}
return n
}
// Marshal writes the chunk.
func (c Chunk3) Marshal() ([]byte, error) {
buf := make([]byte, 1+len(c.Body))
func (c Chunk3) Marshal(hasExtendedTimestamp bool) ([]byte, error) {
buf := make([]byte, c.marshalSize(hasExtendedTimestamp))
buf[0] = 3<<6 | c.ChunkStreamID
copy(buf[1:], c.Body)
if hasExtendedTimestamp {
copy(buf[5:], c.Body)
} else {
copy(buf[1:], c.Body)
}
return buf, nil
}

30
internal/protocols/rtmp/chunk/chunk3_test.go

@ -1,30 +0,0 @@ @@ -1,30 +0,0 @@
package chunk
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var chunk3enc = []byte{
0xd9, 0x1, 0x2, 0x3, 0x4,
}
var chunk3dec = Chunk3{
ChunkStreamID: 25,
Body: []byte{0x01, 0x02, 0x03, 0x04},
}
func TestChunk3Read(t *testing.T) {
var chunk3 Chunk3
err := chunk3.Read(bytes.NewReader(chunk3enc), 4)
require.NoError(t, err)
require.Equal(t, chunk3dec, chunk3)
}
func TestChunk3Marshal(t *testing.T) {
buf, err := chunk3dec.Marshal()
require.NoError(t, err)
require.Equal(t, chunk3enc, buf)
}

158
internal/protocols/rtmp/chunk/chunk_test.go

@ -0,0 +1,158 @@ @@ -0,0 +1,158 @@
package chunk
import (
"bytes"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
var cases = []struct {
name string
enc []byte
bodyLen uint32
hasExtendedTimestamp bool
dec Chunk
}{
{
"chunk0 standard",
[]byte{
0x19, 0xb1, 0xa1, 0x91, 0x0, 0x0, 0x14, 0x14,
0x3, 0x5d, 0x17, 0x3d, 0x1, 0x2, 0x3, 0x4,
},
4,
false,
&Chunk0{
ChunkStreamID: 25,
Timestamp: 11641233,
Type: 20,
MessageStreamID: 56432445,
BodyLen: 20,
Body: []byte{1, 2, 3, 4},
},
},
{
"chunk0 extended timestamp",
[]byte{
0x19, 0xff, 0xff, 0xff, 0x00, 0x00, 0x14, 0x0f,
0x00, 0x31, 0x84, 0xb2, 0xff, 0x34, 0x86, 0xa2,
0x05, 0x06, 0x07, 0x08,
},
4,
false,
&Chunk0{
ChunkStreamID: 25,
Timestamp: 0xFF3486a2,
Type: 15,
MessageStreamID: 3245234,
BodyLen: 20,
Body: []byte{5, 6, 7, 8},
},
},
{
"chunk1 standard",
[]byte{
0x59, 0xb1, 0xa1, 0x91, 0x0, 0x0, 0x14, 0x14,
0x1, 0x2, 0x3, 0x4,
},
4,
false,
&Chunk1{
ChunkStreamID: 25,
TimestampDelta: 11641233,
Type: 20,
BodyLen: 20,
Body: []byte{1, 2, 3, 4},
},
},
{
"chunk1 extended timestamp",
[]byte{
0x59, 0xff, 0xff, 0xff, 0x00, 0x00, 0x14, 0x14,
0xff, 0x88, 0x4b, 0x6c, 0x05, 0x06, 0x07, 0x08,
},
4,
false,
&Chunk1{
ChunkStreamID: 25,
TimestampDelta: 0xFF884B6C,
Type: 20,
BodyLen: 20,
Body: []byte{5, 6, 7, 8},
},
},
{
"chunk2 standard",
[]byte{
0x99, 0xb1, 0xa1, 0x91, 0x1, 0x2, 0x3, 0x4,
},
4,
false,
&Chunk2{
ChunkStreamID: 25,
TimestampDelta: 11641233,
Body: []byte{1, 2, 3, 4},
},
},
{
"chunk2 extended timestamp",
[]byte{
0x99, 0xff, 0xff, 0xff, 0xff, 0xaa, 0xbb, 0xcc,
0x05, 0x06, 0x07, 0x08,
},
4,
false,
&Chunk2{
ChunkStreamID: 25,
TimestampDelta: 0xFFAABBCC,
Body: []byte{5, 6, 7, 8},
},
},
{
"chunk3 standard",
[]byte{
0xd9, 0x1, 0x2, 0x3, 0x4,
},
4,
false,
&Chunk3{
ChunkStreamID: 25,
Body: []byte{1, 2, 3, 4},
},
},
{
"chunk3 extended timestamp",
[]byte{
0xd9, 0x00, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07,
0x08,
},
4,
true,
&Chunk3{
ChunkStreamID: 25,
Body: []byte{5, 6, 7, 8},
},
},
}
func TestChunkRead(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
chunk := reflect.New(reflect.TypeOf(ca.dec).Elem()).Interface().(Chunk)
err := chunk.Read(bytes.NewReader(ca.enc), ca.bodyLen, ca.hasExtendedTimestamp)
require.NoError(t, err)
require.Equal(t, ca.dec, chunk)
})
}
}
func TestChunkMarshal(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
buf, err := ca.dec.Marshal(ca.hasExtendedTimestamp)
require.NoError(t, err)
require.Equal(t, ca.enc, buf)
})
}
}

18
internal/protocols/rtmp/rawmessage/reader.go

@ -37,10 +37,11 @@ type readerChunkStream struct { @@ -37,10 +37,11 @@ type readerChunkStream struct {
curBodyRecv uint32
curTimestampDelta uint32
curTimestampDeltaAvailable bool
hasExtendedTimestamp bool
}
func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) error {
err := c.Read(rc.mr.br, chunkBodySize)
func (rc *readerChunkStream) readChunk(c chunk.Chunk, bodySize uint32, hasExtendedTimestamp bool) error {
err := c.Read(rc.mr.br, bodySize, hasExtendedTimestamp)
if err != nil {
return err
}
@ -70,7 +71,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -70,7 +71,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 0 chunk but expected type 3 chunk")
}
err := rc.readChunk(&rc.mr.c0, rc.mr.chunkSize)
err := rc.readChunk(&rc.mr.c0, rc.mr.chunkSize, false)
if err != nil {
return nil, err
}
@ -81,6 +82,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -81,6 +82,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestampAvailable = true
rc.curTimestampDeltaAvailable = false
rc.curBodyLen = rc.mr.c0.BodyLen
rc.hasExtendedTimestamp = rc.mr.c0.Timestamp >= 0xFFFFFF
if rc.curBodyLen > maxBodySize {
return nil, fmt.Errorf("body size (%d) exceeds maximum (%d)", rc.curBodyLen, maxBodySize)
@ -109,7 +111,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -109,7 +111,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 1 chunk but expected type 3 chunk")
}
err := rc.readChunk(&rc.mr.c1, rc.mr.chunkSize)
err := rc.readChunk(&rc.mr.c1, rc.mr.chunkSize, false)
if err != nil {
return nil, err
}
@ -119,6 +121,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -119,6 +121,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestampDelta = rc.mr.c1.TimestampDelta
rc.curTimestampDeltaAvailable = true
rc.curBodyLen = rc.mr.c1.BodyLen
rc.hasExtendedTimestamp = rc.mr.c1.TimestampDelta >= 0xFFFFFF
if rc.curBodyLen > maxBodySize {
return nil, fmt.Errorf("body size (%d) exceeds maximum (%d)", rc.curBodyLen, maxBodySize)
@ -152,7 +155,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -152,7 +155,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
chunkBodyLen = rc.mr.chunkSize
}
err := rc.readChunk(&rc.mr.c2, chunkBodyLen)
err := rc.readChunk(&rc.mr.c2, chunkBodyLen, false)
if err != nil {
return nil, err
}
@ -160,6 +163,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -160,6 +163,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestamp += rc.mr.c2.TimestampDelta
rc.curTimestampDelta = rc.mr.c2.TimestampDelta
rc.curTimestampDeltaAvailable = true
rc.hasExtendedTimestamp = rc.mr.c2.TimestampDelta >= 0xFFFFFF
le := uint32(len(rc.mr.c2.Body))
@ -182,7 +186,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -182,7 +186,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
chunkBodyLen = rc.mr.chunkSize
}
err := rc.readChunk(&rc.mr.c3, chunkBodyLen)
err := rc.readChunk(&rc.mr.c3, chunkBodyLen, rc.hasExtendedTimestamp)
if err != nil {
return nil, err
}
@ -212,7 +216,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -212,7 +216,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
chunkBodyLen = rc.mr.chunkSize
}
err := rc.readChunk(&rc.mr.c3, chunkBodyLen)
err := rc.readChunk(&rc.mr.c3, chunkBodyLen, rc.hasExtendedTimestamp)
if err != nil {
return nil, err
}

57
internal/protocols/rtmp/rawmessage/reader_test.go

@ -12,10 +12,9 @@ import ( @@ -12,10 +12,9 @@ import (
)
var cases = []struct {
name string
messages []*Message
chunks []chunk.Chunk
chunkSizes []uint32
name string
messages []*Message
chunks []chunk.Chunk
}{
{
"(chunk0) + (chunk1)",
@ -52,10 +51,6 @@ var cases = []struct { @@ -52,10 +51,6 @@ var cases = []struct {
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
128,
},
},
{
"(chunk0) + (chunk2) + (chunk3)",
@ -101,11 +96,6 @@ var cases = []struct { @@ -101,11 +96,6 @@ var cases = []struct {
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]uint32{
128,
64,
64,
},
},
{
"(chunk0 + chunk3) + (chunk1 + chunk3) + (chunk2 + chunk3) + (chunk3 + chunk3)",
@ -181,15 +171,31 @@ var cases = []struct { @@ -181,15 +171,31 @@ var cases = []struct {
Body: bytes.Repeat([]byte{0x06}, 64),
},
},
[]uint32{
128,
62,
128,
64,
128,
64,
128,
64,
},
{
"(chunk0 + chunk3 with extended timestamp)",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 0xFF123456 * time.Millisecond,
Type: 6,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{5}, 160),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 4279383126,
Type: 6,
MessageStreamID: 3123,
BodyLen: 160,
Body: bytes.Repeat([]byte{5}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{5}, 32),
},
},
},
}
@ -203,10 +209,13 @@ func TestReader(t *testing.T) { @@ -203,10 +209,13 @@ func TestReader(t *testing.T) {
return nil
})
hasExtendedTimestamp := false
for _, cach := range ca.chunks {
buf2, err := cach.Marshal()
buf2, err := cach.Marshal(hasExtendedTimestamp)
require.NoError(t, err)
buf.Write(buf2)
hasExtendedTimestamp = chunkHasExtendedTimestamp(cach)
}
for _, camsg := range ca.messages {
@ -247,7 +256,7 @@ func TestReaderAcknowledge(t *testing.T) { @@ -247,7 +256,7 @@ func TestReaderAcknowledge(t *testing.T) {
MessageStreamID: 3123,
BodyLen: 200,
Body: bytes.Repeat([]byte{0x03}, 200),
}.Marshal()
}.Marshal(false)
require.NoError(t, err)
buf.Write(buf2)

39
internal/protocols/rtmp/rawmessage/writer.go

@ -11,15 +11,16 @@ import ( @@ -11,15 +11,16 @@ import (
)
type writerChunkStream struct {
mw *Writer
lastMessageStreamID *uint32
lastType *uint8
lastBodyLen *uint32
lastTimestamp *int64
lastTimestampDelta *int64
mw *Writer
lastMessageStreamID *uint32
lastType *uint8
lastBodyLen *uint32
lastTimestamp *int64
lastTimestampDelta *int64
hasExtendedTimestamp bool
}
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
func (wc *writerChunkStream) writeChunk(c chunk.Chunk, hasExtendedTimestamp bool) error {
// check if we received an acknowledge
if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
diff := uint32(wc.mw.bcw.Count()) - wc.mw.ackValue
@ -29,7 +30,7 @@ func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error { @@ -29,7 +30,7 @@ func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
}
}
buf, err := c.Marshal()
buf, err := c.Marshal(hasExtendedTimestamp)
if err != nil {
return err
}
@ -72,45 +73,51 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -72,45 +73,51 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
switch {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
ts := uint32(timestamp)
err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID,
Timestamp: uint32(timestamp),
Timestamp: ts,
Type: msg.Type,
MessageStreamID: msg.MessageStreamID,
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
})
}, false)
if err != nil {
return err
}
wc.hasExtendedTimestamp = ts >= 0xFFFFFF
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
ts := uint32(*timestampDelta)
err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: uint32(*timestampDelta),
TimestampDelta: ts,
Type: msg.Type,
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
})
}, false)
if err != nil {
return err
}
wc.hasExtendedTimestamp = ts >= 0xFFFFFF
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
ts := uint32(*timestampDelta)
err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: uint32(*timestampDelta),
TimestampDelta: ts,
Body: msg.Body[pos : pos+chunkBodyLen],
})
}, false)
if err != nil {
return err
}
wc.hasExtendedTimestamp = ts >= 0xFFFFFF
default:
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
})
}, wc.hasExtendedTimestamp)
if err != nil {
return err
}
@ -133,7 +140,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -133,7 +140,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
})
}, wc.hasExtendedTimestamp)
if err != nil {
return err
}

35
internal/protocols/rtmp/rawmessage/writer_test.go

@ -11,6 +11,34 @@ import ( @@ -11,6 +11,34 @@ import (
"github.com/stretchr/testify/require"
)
func chunkBodySize(ch chunk.Chunk) uint32 {
switch ch := ch.(type) {
case *chunk.Chunk0:
return uint32(len(ch.Body))
case *chunk.Chunk1:
return uint32(len(ch.Body))
case *chunk.Chunk2:
return uint32(len(ch.Body))
case *chunk.Chunk3:
return uint32(len(ch.Body))
}
return 0
}
func chunkHasExtendedTimestamp(ch chunk.Chunk) bool {
switch ch := ch.(type) {
case *chunk.Chunk0:
return ch.Timestamp >= 0xFFFFFF
case *chunk.Chunk1:
return ch.TimestampDelta >= 0xFFFFFF
case *chunk.Chunk2:
return ch.TimestampDelta >= 0xFFFFFF
case *chunk.Chunk3:
return false
}
return false
}
func TestWriter(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
@ -23,11 +51,14 @@ func TestWriter(t *testing.T) { @@ -23,11 +51,14 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
}
for i, cach := range ca.chunks {
hasExtendedTimestamp := false
for _, cach := range ca.chunks {
ch := reflect.New(reflect.TypeOf(cach).Elem()).Interface().(chunk.Chunk)
err := ch.Read(&buf, ca.chunkSizes[i])
err := ch.Read(&buf, chunkBodySize(cach), hasExtendedTimestamp)
require.NoError(t, err)
require.Equal(t, cach, ch)
hasExtendedTimestamp = chunkHasExtendedTimestamp(cach)
}
})
}

Loading…
Cancel
Save