diff --git a/internal/rtmp/chunk/chunk0.go b/internal/rtmp/chunk/chunk0.go index 88612dd4..f51802f7 100644 --- a/internal/rtmp/chunk/chunk0.go +++ b/internal/rtmp/chunk/chunk0.go @@ -42,25 +42,20 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error { } // Write writes the chunk. -func (c Chunk0) Write(w io.Writer) error { - header := make([]byte, 12) - header[0] = c.ChunkStreamID - header[1] = byte(c.Timestamp >> 16) - header[2] = byte(c.Timestamp >> 8) - header[3] = byte(c.Timestamp) - header[4] = byte(c.BodyLen >> 16) - header[5] = byte(c.BodyLen >> 8) - header[6] = byte(c.BodyLen) - header[7] = byte(c.Type) - header[8] = byte(c.MessageStreamID >> 24) - header[9] = byte(c.MessageStreamID >> 16) - header[10] = byte(c.MessageStreamID >> 8) - header[11] = byte(c.MessageStreamID) - _, err := w.Write(header) - if err != nil { - return err - } - - _, err = w.Write(c.Body) - return err +func (c Chunk0) Write() ([]byte, error) { + buf := make([]byte, 12+len(c.Body)) + buf[0] = c.ChunkStreamID + buf[1] = byte(c.Timestamp >> 16) + buf[2] = byte(c.Timestamp >> 8) + buf[3] = byte(c.Timestamp) + buf[4] = byte(c.BodyLen >> 16) + buf[5] = byte(c.BodyLen >> 8) + buf[6] = byte(c.BodyLen) + buf[7] = byte(c.Type) + buf[8] = byte(c.MessageStreamID >> 24) + buf[9] = byte(c.MessageStreamID >> 16) + buf[10] = byte(c.MessageStreamID >> 8) + buf[11] = byte(c.MessageStreamID) + copy(buf[12:], c.Body) + return buf, nil } diff --git a/internal/rtmp/chunk/chunk0_test.go b/internal/rtmp/chunk/chunk0_test.go index 2021a2ae..ed5a3643 100644 --- a/internal/rtmp/chunk/chunk0_test.go +++ b/internal/rtmp/chunk/chunk0_test.go @@ -29,8 +29,7 @@ func TestChunk0Read(t *testing.T) { } func TestChunk0Write(t *testing.T) { - var buf bytes.Buffer - err := chunk0dec.Write(&buf) + buf, err := chunk0dec.Write() require.NoError(t, err) - require.Equal(t, chunk0enc, buf.Bytes()) + require.Equal(t, chunk0enc, buf) } diff --git a/internal/rtmp/chunk/chunk1.go b/internal/rtmp/chunk/chunk1.go index b594c92e..d316e336 100644 --- a/internal/rtmp/chunk/chunk1.go +++ b/internal/rtmp/chunk/chunk1.go @@ -42,21 +42,16 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error { } // Write writes the chunk. -func (c Chunk1) Write(w io.Writer) error { - header := make([]byte, 8) - header[0] = 1<<6 | c.ChunkStreamID - header[1] = byte(c.TimestampDelta >> 16) - header[2] = byte(c.TimestampDelta >> 8) - header[3] = byte(c.TimestampDelta) - header[4] = byte(c.BodyLen >> 16) - header[5] = byte(c.BodyLen >> 8) - header[6] = byte(c.BodyLen) - header[7] = byte(c.Type) - _, err := w.Write(header) - if err != nil { - return err - } - - _, err = w.Write(c.Body) - return err +func (c Chunk1) Write() ([]byte, error) { + buf := make([]byte, 8+len(c.Body)) + buf[0] = 1<<6 | c.ChunkStreamID + buf[1] = byte(c.TimestampDelta >> 16) + buf[2] = byte(c.TimestampDelta >> 8) + buf[3] = byte(c.TimestampDelta) + buf[4] = byte(c.BodyLen >> 16) + buf[5] = byte(c.BodyLen >> 8) + buf[6] = byte(c.BodyLen) + buf[7] = byte(c.Type) + copy(buf[8:], c.Body) + return buf, nil } diff --git a/internal/rtmp/chunk/chunk1_test.go b/internal/rtmp/chunk/chunk1_test.go index 903a33de..5339740c 100644 --- a/internal/rtmp/chunk/chunk1_test.go +++ b/internal/rtmp/chunk/chunk1_test.go @@ -28,8 +28,7 @@ func TestChunk1Read(t *testing.T) { } func TestChunk1Write(t *testing.T) { - var buf bytes.Buffer - err := chunk1dec.Write(&buf) + buf, err := chunk1dec.Write() require.NoError(t, err) - require.Equal(t, chunk1enc, buf.Bytes()) + require.Equal(t, chunk1enc, buf) } diff --git a/internal/rtmp/chunk/chunk2.go b/internal/rtmp/chunk/chunk2.go index 5d552c4f..56ca9e13 100644 --- a/internal/rtmp/chunk/chunk2.go +++ b/internal/rtmp/chunk/chunk2.go @@ -31,17 +31,12 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error { } // Write writes the chunk. -func (c Chunk2) Write(w io.Writer) error { - header := make([]byte, 4) - header[0] = 2<<6 | c.ChunkStreamID - header[1] = byte(c.TimestampDelta >> 16) - header[2] = byte(c.TimestampDelta >> 8) - header[3] = byte(c.TimestampDelta) - _, err := w.Write(header) - if err != nil { - return err - } - - _, err = w.Write(c.Body) - return err +func (c Chunk2) Write() ([]byte, error) { + buf := make([]byte, 4+len(c.Body)) + buf[0] = 2<<6 | c.ChunkStreamID + buf[1] = byte(c.TimestampDelta >> 16) + buf[2] = byte(c.TimestampDelta >> 8) + buf[3] = byte(c.TimestampDelta) + copy(buf[4:], c.Body) + return buf, nil } diff --git a/internal/rtmp/chunk/chunk2_test.go b/internal/rtmp/chunk/chunk2_test.go index 99bfac68..b7b9a14d 100644 --- a/internal/rtmp/chunk/chunk2_test.go +++ b/internal/rtmp/chunk/chunk2_test.go @@ -25,8 +25,7 @@ func TestChunk2Read(t *testing.T) { } func TestChunk2Write(t *testing.T) { - var buf bytes.Buffer - err := chunk2dec.Write(&buf) + buf, err := chunk2dec.Write() require.NoError(t, err) - require.Equal(t, chunk2enc, buf.Bytes()) + require.Equal(t, chunk2enc, buf) } diff --git a/internal/rtmp/chunk/chunk3.go b/internal/rtmp/chunk/chunk3.go index e8008abd..1146c7a2 100644 --- a/internal/rtmp/chunk/chunk3.go +++ b/internal/rtmp/chunk/chunk3.go @@ -31,14 +31,9 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error { } // Write writes the chunk. -func (c Chunk3) Write(w io.Writer) error { - header := make([]byte, 1) - header[0] = 3<<6 | c.ChunkStreamID - _, err := w.Write(header) - if err != nil { - return err - } - - _, err = w.Write(c.Body) - return err +func (c Chunk3) Write() ([]byte, error) { + buf := make([]byte, 1+len(c.Body)) + buf[0] = 3<<6 | c.ChunkStreamID + copy(buf[1:], c.Body) + return buf, nil } diff --git a/internal/rtmp/chunk/chunk3_test.go b/internal/rtmp/chunk/chunk3_test.go index f6a1e874..4ed25956 100644 --- a/internal/rtmp/chunk/chunk3_test.go +++ b/internal/rtmp/chunk/chunk3_test.go @@ -24,8 +24,7 @@ func TestChunk3Read(t *testing.T) { } func TestChunk3Write(t *testing.T) { - var buf bytes.Buffer - err := chunk3dec.Write(&buf) + buf, err := chunk3dec.Write() require.NoError(t, err) - require.Equal(t, chunk3enc, buf.Bytes()) + require.Equal(t, chunk3enc, buf) } diff --git a/internal/rtmp/rawmessage/reader_test.go b/internal/rtmp/rawmessage/reader_test.go index 4b7cddd0..b78201e4 100644 --- a/internal/rtmp/rawmessage/reader_test.go +++ b/internal/rtmp/rawmessage/reader_test.go @@ -3,7 +3,6 @@ package rawmessage import ( "bufio" "bytes" - "io" "testing" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" @@ -11,7 +10,7 @@ import ( ) type writableChunk interface { - Write(w io.Writer) error + Write() ([]byte, error) } type sequenceEntry struct { @@ -25,8 +24,9 @@ func TestReader(t *testing.T) { r := NewReader(bufio.NewReader(&buf)) for _, entry := range seq { - err := entry.chunk.Write(&buf) + buf2, err := entry.chunk.Write() require.NoError(t, err) + buf.Write(buf2) msg, err := r.Read() require.NoError(t, err) require.Equal(t, entry.msg, msg) @@ -124,21 +124,23 @@ func TestReader(t *testing.T) { var buf bytes.Buffer r := NewReader(bufio.NewReader(&buf)) - err := chunk.Chunk0{ + buf2, err := chunk.Chunk0{ ChunkStreamID: 27, Timestamp: 18576, Type: chunk.MessageTypeSetPeerBandwidth, MessageStreamID: 3123, BodyLen: 192, Body: bytes.Repeat([]byte{0x03}, 128), - }.Write(&buf) + }.Write() require.NoError(t, err) + buf.Write(buf2) - err = chunk.Chunk3{ + buf2, err = chunk.Chunk3{ ChunkStreamID: 27, Body: bytes.Repeat([]byte{0x03}, 64), - }.Write(&buf) + }.Write() require.NoError(t, err) + buf.Write(buf2) msg, err := r.Read() require.NoError(t, err) diff --git a/internal/rtmp/rawmessage/writer.go b/internal/rtmp/rawmessage/writer.go index 20b069c4..b49696aa 100644 --- a/internal/rtmp/rawmessage/writer.go +++ b/internal/rtmp/rawmessage/writer.go @@ -42,45 +42,65 @@ func (wc *writerChunkStream) write(msg *Message) error { switch { case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID: - err := chunk.Chunk0{ + buf, err := chunk.Chunk0{ ChunkStreamID: msg.ChunkStreamID, Timestamp: msg.Timestamp, Type: msg.Type, MessageStreamID: msg.MessageStreamID, BodyLen: uint32(bodyLen), Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(wc.mw.w) + }.Write() + if err != nil { + return err + } + + _, err = wc.mw.w.Write(buf) if err != nil { return err } case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen: - err := chunk.Chunk1{ + buf, err := chunk.Chunk1{ ChunkStreamID: msg.ChunkStreamID, TimestampDelta: *timestampDelta, Type: msg.Type, BodyLen: uint32(bodyLen), Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(wc.mw.w) + }.Write() + if err != nil { + return err + } + + _, err = wc.mw.w.Write(buf) if err != nil { return err } case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta: - err := chunk.Chunk2{ + buf, err := chunk.Chunk2{ ChunkStreamID: msg.ChunkStreamID, TimestampDelta: *timestampDelta, Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(wc.mw.w) + }.Write() + if err != nil { + return err + } + + _, err = wc.mw.w.Write(buf) if err != nil { return err } default: - err := chunk.Chunk3{ + buf, err := chunk.Chunk3{ ChunkStreamID: msg.ChunkStreamID, Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(wc.mw.w) + }.Write() + if err != nil { + return err + } + + _, err = wc.mw.w.Write(buf) if err != nil { return err } @@ -100,10 +120,15 @@ func (wc *writerChunkStream) write(msg *Message) error { wc.lastTimestampDelta = &v5 } } else { - err := chunk.Chunk3{ + buf, err := chunk.Chunk3{ ChunkStreamID: msg.ChunkStreamID, Body: msg.Body[pos : pos+chunkBodyLen], - }.Write(wc.mw.w) + }.Write() + if err != nil { + return err + } + + _, err = wc.mw.w.Write(buf) if err != nil { return err }