Browse Source

rtmp: add MessageReader

pull/956/head
aler9 3 years ago
parent
commit
959b78586a
  1. 28
      internal/rtmp/base/chunk1.go
  2. 21
      internal/rtmp/base/chunk2.go
  3. 20
      internal/rtmp/base/chunk3.go
  4. 212
      internal/rtmp/base/messagereader.go
  5. 10
      internal/rtmp/base/messagewriter.go
  6. 161
      internal/rtmp/conn_test.go

28
internal/rtmp/base/chunk1.go

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
package base
import (
"fmt"
"io"
)
@ -18,6 +19,33 @@ type Chunk1 struct { @@ -18,6 +19,33 @@ type Chunk1 struct {
Body []byte
}
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
header := make([]byte, 8)
_, err := r.Read(header)
if err != nil {
return err
}
if header[0]>>6 != 1 {
return fmt.Errorf("wrong chunk header type")
}
c.ChunkStreamID = header[0] & 0x3F
c.TimestampDelta = uint32(header[3])<<16 | uint32(header[2])<<8 | uint32(header[1])
c.BodyLen = uint32(header[4])<<16 | uint32(header[5])<<8 | uint32(header[6])
c.Type = MessageType(header[7])
chunkBodyLen := int(c.BodyLen)
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
}
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
return err
}
// Write writes the chunk.
func (c Chunk1) Write(w io.Writer) error {
header := make([]byte, 8)

21
internal/rtmp/base/chunk2.go

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
package base
import (
"fmt"
"io"
)
@ -14,6 +15,26 @@ type Chunk2 struct { @@ -14,6 +15,26 @@ type Chunk2 struct {
Body []byte
}
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
header := make([]byte, 4)
_, err := r.Read(header)
if err != nil {
return err
}
if header[0]>>6 != 2 {
return fmt.Errorf("wrong chunk header type")
}
c.ChunkStreamID = header[0] & 0x3F
c.TimestampDelta = uint32(header[3])<<16 | uint32(header[2])<<8 | uint32(header[1])
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
return err
}
// Write writes the chunk.
func (c Chunk2) Write(w io.Writer) error {
header := make([]byte, 4)

20
internal/rtmp/base/chunk3.go

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
package base
import (
"fmt"
"io"
)
@ -15,6 +16,25 @@ type Chunk3 struct { @@ -15,6 +16,25 @@ type Chunk3 struct {
Body []byte
}
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
header := make([]byte, 1)
_, err := r.Read(header)
if err != nil {
return err
}
if header[0]>>6 != 2 {
return fmt.Errorf("wrong chunk header type")
}
c.ChunkStreamID = header[0] & 0x3F
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
return err
}
// Write writes the chunk.
func (c Chunk3) Write(w io.Writer) error {
header := make([]byte, 1)

212
internal/rtmp/base/messagereader.go

@ -0,0 +1,212 @@ @@ -0,0 +1,212 @@
package base
import (
"bufio"
"errors"
"fmt"
)
var errMoreChunksNeeded = errors.New("more chunks are needed")
type messageReaderChunkStream struct {
mr *MessageReader
curTimestamp *uint32
curType *MessageType
curMessageStreamID *uint32
curBodyLen *uint32
curBody *[]byte
}
func (rc *messageReaderChunkStream) read(typ byte) (*Message, error) {
switch typ {
case 0:
if rc.curBody != nil {
return nil, fmt.Errorf("received type 0 chunk but expected type 3 chunk")
}
var c0 Chunk0
err := c0.Read(rc.mr.r, rc.mr.chunkSize)
if err != nil {
return nil, err
}
v1 := c0.MessageStreamID
rc.curMessageStreamID = &v1
v2 := c0.Type
rc.curType = &v2
v3 := c0.Timestamp
rc.curTimestamp = &v3
v4 := c0.BodyLen
rc.curBodyLen = &v4
if c0.BodyLen != uint32(len(c0.Body)) {
rc.curBody = &c0.Body
return nil, errMoreChunksNeeded
}
return &Message{
Timestamp: c0.Timestamp,
Type: c0.Type,
MessageStreamID: c0.MessageStreamID,
Body: c0.Body,
}, nil
case 1:
if rc.curTimestamp == nil {
return nil, fmt.Errorf("received type 1 chunk without previous chunk")
}
if rc.curBody != nil {
return nil, fmt.Errorf("received type 1 chunk but expected type 3 chunk")
}
var c1 Chunk1
err := c1.Read(rc.mr.r, rc.mr.chunkSize)
if err != nil {
return nil, err
}
v2 := c1.Type
rc.curType = &v2
v3 := *rc.curTimestamp + c1.TimestampDelta
rc.curTimestamp = &v3
v4 := c1.BodyLen
rc.curBodyLen = &v4
if c1.BodyLen != uint32(len(c1.Body)) {
rc.curBody = &c1.Body
return nil, errMoreChunksNeeded
}
return &Message{
Timestamp: *rc.curTimestamp + c1.TimestampDelta,
Type: c1.Type,
MessageStreamID: *rc.curMessageStreamID,
Body: c1.Body,
}, nil
case 2:
if rc.curTimestamp == nil {
return nil, fmt.Errorf("received type 2 chunk without previous chunk")
}
if rc.curBody != nil {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
}
chunkBodyLen := int(*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c2 Chunk2
err := c2.Read(rc.mr.r, chunkBodyLen)
if err != nil {
return nil, err
}
v3 := *rc.curTimestamp + c2.TimestampDelta
rc.curTimestamp = &v3
if chunkBodyLen != len(c2.Body) {
rc.curBody = &c2.Body
return nil, errMoreChunksNeeded
}
return &Message{
Timestamp: *rc.curTimestamp + c2.TimestampDelta,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: c2.Body,
}, nil
default: // 3
if rc.curTimestamp == nil {
return nil, fmt.Errorf("received type 3 chunk without previous chunk")
}
if rc.curBody == nil {
return nil, fmt.Errorf("unsupported")
}
chunkBodyLen := int(*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c3 Chunk3
err := c3.Read(rc.mr.r, chunkBodyLen)
if err != nil {
return nil, err
}
*rc.curBody = append(*rc.curBody, c3.Body...)
if *rc.curBodyLen != uint32(len(*rc.curBody)) {
return nil, errMoreChunksNeeded
}
body := *rc.curBody
rc.curBody = nil
return &Message{
Timestamp: *rc.curTimestamp,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: body,
}, nil
}
}
// MessageReader is a message reader.
type MessageReader struct {
r *bufio.Reader
chunkSize int
chunkStreams map[byte]*messageReaderChunkStream
}
// NewMessageReader allocates a MessageReader.
func NewMessageReader(r *bufio.Reader) *MessageReader {
return &MessageReader{
r: r,
chunkSize: 128,
chunkStreams: make(map[byte]*messageReaderChunkStream),
}
}
// SetChunkSize sets the maximum chunk size.
func (mr *MessageReader) SetChunkSize(v int) {
mr.chunkSize = v
}
func (mr *MessageReader) Read() (*Message, error) {
for {
byt, err := mr.r.ReadByte()
if err != nil {
return nil, err
}
typ := byt >> 6
chunkStreamID := byt & 0x3F
rc, ok := mr.chunkStreams[chunkStreamID]
if !ok {
rc = &messageReaderChunkStream{mr: mr}
mr.chunkStreams[chunkStreamID] = rc
}
mr.r.UnreadByte()
msg, err := rc.read(typ)
if err != nil {
if err == errMoreChunksNeeded {
continue
}
return nil, err
}
msg.ChunkStreamID = chunkStreamID
return msg, err
}
}

10
internal/rtmp/base/messagewriter.go

@ -122,7 +122,7 @@ type MessageWriter struct { @@ -122,7 +122,7 @@ type MessageWriter struct {
chunkStreams map[byte]*messageWriterChunkStream
}
// NewMessageWriter instantiates a MessageWriter.
// NewMessageWriter allocates a MessageWriter.
func NewMessageWriter(w io.Writer) *MessageWriter {
return &MessageWriter{
w: w,
@ -138,11 +138,11 @@ func (mw *MessageWriter) SetChunkSize(v int) { @@ -138,11 +138,11 @@ func (mw *MessageWriter) SetChunkSize(v int) {
// Write writes a Message.
func (mw *MessageWriter) Write(msg *Message) error {
cs, ok := mw.chunkStreams[msg.ChunkStreamID]
wc, ok := mw.chunkStreams[msg.ChunkStreamID]
if !ok {
cs = &messageWriterChunkStream{mw: mw}
mw.chunkStreams[msg.ChunkStreamID] = cs
wc = &messageWriterChunkStream{mw: mw}
mw.chunkStreams[msg.ChunkStreamID] = wc
}
return cs.write(msg)
return wc.write(msg)
}

161
internal/rtmp/conn_test.go

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
package rtmp
import (
"bufio"
"net"
"net/url"
"strings"
@ -135,6 +136,7 @@ func TestReadTracks(t *testing.T) { @@ -135,6 +136,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
mw := base.NewMessageWriter(conn)
mr := base.NewMessageReader(bufio.NewReader(conn))
// C->S connect
byts := flvio.FillAMF0ValsMalloc([]interface{}{
@ -159,42 +161,40 @@ func TestReadTracks(t *testing.T) { @@ -159,42 +161,40 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
var c0 base.Chunk0
err = c0.Read(conn, 128)
msg, err := mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetWindowAckSize,
BodyLen: 4,
Body: []byte{0x00, 38, 37, 160},
}, c0)
}, msg)
// S->C set peer bandwidth
err = c0.Read(conn, 128)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetPeerBandwidth,
BodyLen: 5,
Body: []byte{0x00, 0x26, 0x25, 0xa0, 0x02},
}, c0)
}, msg)
// S->C set chunk size
err = c0.Read(conn, 128)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetChunkSize,
BodyLen: 4,
Body: []byte{0x00, 0x01, 0x00, 0x00},
}, c0)
}, msg)
mr.SetChunkSize(65536)
// S->C result
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(3), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err := flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(3), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err := flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"_result",
@ -260,11 +260,11 @@ func TestReadTracks(t *testing.T) { @@ -260,11 +260,11 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(3), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(3), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"_result",
@ -289,11 +289,11 @@ func TestReadTracks(t *testing.T) { @@ -289,11 +289,11 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C onStatus
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(5), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(5), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onStatus",
@ -514,6 +514,7 @@ func TestWriteTracks(t *testing.T) { @@ -514,6 +514,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
mw := base.NewMessageWriter(conn)
mr := base.NewMessageReader(bufio.NewReader(conn))
// C->S connect
byts := flvio.FillAMF0ValsMalloc([]interface{}{
@ -538,42 +539,40 @@ func TestWriteTracks(t *testing.T) { @@ -538,42 +539,40 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
var c0 base.Chunk0
err = c0.Read(conn, 128)
msg, err := mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetWindowAckSize,
BodyLen: 4,
Body: []byte{0x00, 38, 37, 160},
}, c0)
}, msg)
// S->C set peer bandwidth
err = c0.Read(conn, 128)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetPeerBandwidth,
BodyLen: 5,
Body: []byte{0x00, 0x26, 0x25, 0xa0, 0x02},
}, c0)
}, msg)
// S->C set chunk size
err = c0.Read(conn, 128)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeSetChunkSize,
BodyLen: 4,
Body: []byte{0x00, 0x01, 0x00, 0x00},
}, c0)
}, msg)
mr.SetChunkSize(65536)
// S->C result
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(3), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err := flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(3), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err := flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"_result",
@ -621,11 +620,11 @@ func TestWriteTracks(t *testing.T) { @@ -621,11 +620,11 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(3), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(3), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"_result",
@ -663,31 +662,29 @@ func TestWriteTracks(t *testing.T) { @@ -663,31 +662,29 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C event "stream is recorded"
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeUserControl,
BodyLen: 6,
Body: []byte{0x00, 0x04, 0x00, 0x00, 0x00, 0x01},
}, c0)
}, msg)
// S->C event "stream begin 1"
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, base.Chunk0{
require.Equal(t, &base.Message{
ChunkStreamID: base.ControlChunkStreamID,
Type: base.MessageTypeUserControl,
BodyLen: 6,
Body: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
}, c0)
}, msg)
// S->C onStatus
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(5), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(5), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onStatus",
@ -701,11 +698,11 @@ func TestWriteTracks(t *testing.T) { @@ -701,11 +698,11 @@ func TestWriteTracks(t *testing.T) {
}, arr)
// S->C onStatus
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(5), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(5), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onStatus",
@ -719,11 +716,11 @@ func TestWriteTracks(t *testing.T) { @@ -719,11 +716,11 @@ func TestWriteTracks(t *testing.T) {
}, arr)
// S->C onStatus
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(5), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(5), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onStatus",
@ -737,11 +734,11 @@ func TestWriteTracks(t *testing.T) { @@ -737,11 +734,11 @@ func TestWriteTracks(t *testing.T) {
}, arr)
// S->C onStatus
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(5), c0.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(5), msg.ChunkStreamID)
require.Equal(t, base.MessageTypeCommandAMF0, msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onStatus",
@ -755,11 +752,11 @@ func TestWriteTracks(t *testing.T) { @@ -755,11 +752,11 @@ func TestWriteTracks(t *testing.T) {
}, arr)
// S->C onMetadata
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(4), c0.ChunkStreamID)
require.Equal(t, base.MessageType(0x12), c0.Type)
arr, err = flvio.ParseAMFVals(c0.Body, false)
require.Equal(t, uint8(4), msg.ChunkStreamID)
require.Equal(t, base.MessageType(0x12), msg.Type)
arr, err = flvio.ParseAMFVals(msg.Body, false)
require.NoError(t, err)
require.Equal(t, []interface{}{
"onMetaData",
@ -772,10 +769,10 @@ func TestWriteTracks(t *testing.T) { @@ -772,10 +769,10 @@ func TestWriteTracks(t *testing.T) {
}, arr)
// S->C H264 decoder config
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(6), c0.ChunkStreamID)
require.Equal(t, base.MessageType(0x09), c0.Type)
require.Equal(t, uint8(6), msg.ChunkStreamID)
require.Equal(t, base.MessageType(0x09), msg.Type)
require.Equal(t, []byte{
0x17, 0x0, 0x0, 0x0, 0x0, 0x1, 0x64, 0x0,
0xc, 0xff, 0xe1, 0x0, 0x15, 0x67, 0x64, 0x0,
@ -783,12 +780,12 @@ func TestWriteTracks(t *testing.T) { @@ -783,12 +780,12 @@ func TestWriteTracks(t *testing.T) {
0x0, 0x3, 0x0, 0x2, 0x0, 0x0, 0x3, 0x0,
0x3d, 0x8, 0x1, 0x0, 0x4, 0x68, 0xee, 0x3c,
0x80,
}, c0.Body)
}, msg.Body)
// S->C AAC decoder config
err = c0.Read(conn, 65536)
msg, err = mr.Read()
require.NoError(t, err)
require.Equal(t, uint8(4), c0.ChunkStreamID)
require.Equal(t, base.MessageType(0x08), c0.Type)
require.Equal(t, []byte{0xae, 0x0, 0x12, 0x10}, c0.Body)
require.Equal(t, uint8(4), msg.ChunkStreamID)
require.Equal(t, base.MessageType(0x08), msg.Type)
require.Equal(t, []byte{0xae, 0x0, 0x12, 0x10}, msg.Body)
}

Loading…
Cancel
Save