Browse Source

rtmp: make chunk writes atomic

pull/1003/head
aler9 3 years ago
parent
commit
ee2908081e
  1. 37
      internal/rtmp/chunk/chunk0.go
  2. 5
      internal/rtmp/chunk/chunk0_test.go
  3. 29
      internal/rtmp/chunk/chunk1.go
  4. 5
      internal/rtmp/chunk/chunk1_test.go
  5. 21
      internal/rtmp/chunk/chunk2.go
  6. 5
      internal/rtmp/chunk/chunk2_test.go
  7. 15
      internal/rtmp/chunk/chunk3.go
  8. 5
      internal/rtmp/chunk/chunk3_test.go
  9. 16
      internal/rtmp/rawmessage/reader_test.go
  10. 45
      internal/rtmp/rawmessage/writer.go

37
internal/rtmp/chunk/chunk0.go

@ -42,25 +42,20 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
} }
// Write writes the chunk. // Write writes the chunk.
func (c Chunk0) Write(w io.Writer) error { func (c Chunk0) Write() ([]byte, error) {
header := make([]byte, 12) buf := make([]byte, 12+len(c.Body))
header[0] = c.ChunkStreamID buf[0] = c.ChunkStreamID
header[1] = byte(c.Timestamp >> 16) buf[1] = byte(c.Timestamp >> 16)
header[2] = byte(c.Timestamp >> 8) buf[2] = byte(c.Timestamp >> 8)
header[3] = byte(c.Timestamp) buf[3] = byte(c.Timestamp)
header[4] = byte(c.BodyLen >> 16) buf[4] = byte(c.BodyLen >> 16)
header[5] = byte(c.BodyLen >> 8) buf[5] = byte(c.BodyLen >> 8)
header[6] = byte(c.BodyLen) buf[6] = byte(c.BodyLen)
header[7] = byte(c.Type) buf[7] = byte(c.Type)
header[8] = byte(c.MessageStreamID >> 24) buf[8] = byte(c.MessageStreamID >> 24)
header[9] = byte(c.MessageStreamID >> 16) buf[9] = byte(c.MessageStreamID >> 16)
header[10] = byte(c.MessageStreamID >> 8) buf[10] = byte(c.MessageStreamID >> 8)
header[11] = byte(c.MessageStreamID) buf[11] = byte(c.MessageStreamID)
_, err := w.Write(header) copy(buf[12:], c.Body)
if err != nil { return buf, nil
return err
}
_, err = w.Write(c.Body)
return err
} }

5
internal/rtmp/chunk/chunk0_test.go

@ -29,8 +29,7 @@ func TestChunk0Read(t *testing.T) {
} }
func TestChunk0Write(t *testing.T) { func TestChunk0Write(t *testing.T) {
var buf bytes.Buffer buf, err := chunk0dec.Write()
err := chunk0dec.Write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, chunk0enc, buf.Bytes()) require.Equal(t, chunk0enc, buf)
} }

29
internal/rtmp/chunk/chunk1.go

@ -42,21 +42,16 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
} }
// Write writes the chunk. // Write writes the chunk.
func (c Chunk1) Write(w io.Writer) error { func (c Chunk1) Write() ([]byte, error) {
header := make([]byte, 8) buf := make([]byte, 8+len(c.Body))
header[0] = 1<<6 | c.ChunkStreamID buf[0] = 1<<6 | c.ChunkStreamID
header[1] = byte(c.TimestampDelta >> 16) buf[1] = byte(c.TimestampDelta >> 16)
header[2] = byte(c.TimestampDelta >> 8) buf[2] = byte(c.TimestampDelta >> 8)
header[3] = byte(c.TimestampDelta) buf[3] = byte(c.TimestampDelta)
header[4] = byte(c.BodyLen >> 16) buf[4] = byte(c.BodyLen >> 16)
header[5] = byte(c.BodyLen >> 8) buf[5] = byte(c.BodyLen >> 8)
header[6] = byte(c.BodyLen) buf[6] = byte(c.BodyLen)
header[7] = byte(c.Type) buf[7] = byte(c.Type)
_, err := w.Write(header) copy(buf[8:], c.Body)
if err != nil { return buf, nil
return err
}
_, err = w.Write(c.Body)
return err
} }

5
internal/rtmp/chunk/chunk1_test.go

@ -28,8 +28,7 @@ func TestChunk1Read(t *testing.T) {
} }
func TestChunk1Write(t *testing.T) { func TestChunk1Write(t *testing.T) {
var buf bytes.Buffer buf, err := chunk1dec.Write()
err := chunk1dec.Write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, chunk1enc, buf.Bytes()) require.Equal(t, chunk1enc, buf)
} }

21
internal/rtmp/chunk/chunk2.go

@ -31,17 +31,12 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
} }
// Write writes the chunk. // Write writes the chunk.
func (c Chunk2) Write(w io.Writer) error { func (c Chunk2) Write() ([]byte, error) {
header := make([]byte, 4) buf := make([]byte, 4+len(c.Body))
header[0] = 2<<6 | c.ChunkStreamID buf[0] = 2<<6 | c.ChunkStreamID
header[1] = byte(c.TimestampDelta >> 16) buf[1] = byte(c.TimestampDelta >> 16)
header[2] = byte(c.TimestampDelta >> 8) buf[2] = byte(c.TimestampDelta >> 8)
header[3] = byte(c.TimestampDelta) buf[3] = byte(c.TimestampDelta)
_, err := w.Write(header) copy(buf[4:], c.Body)
if err != nil { return buf, nil
return err
}
_, err = w.Write(c.Body)
return err
} }

5
internal/rtmp/chunk/chunk2_test.go

@ -25,8 +25,7 @@ func TestChunk2Read(t *testing.T) {
} }
func TestChunk2Write(t *testing.T) { func TestChunk2Write(t *testing.T) {
var buf bytes.Buffer buf, err := chunk2dec.Write()
err := chunk2dec.Write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, chunk2enc, buf.Bytes()) require.Equal(t, chunk2enc, buf)
} }

15
internal/rtmp/chunk/chunk3.go

@ -31,14 +31,9 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
} }
// Write writes the chunk. // Write writes the chunk.
func (c Chunk3) Write(w io.Writer) error { func (c Chunk3) Write() ([]byte, error) {
header := make([]byte, 1) buf := make([]byte, 1+len(c.Body))
header[0] = 3<<6 | c.ChunkStreamID buf[0] = 3<<6 | c.ChunkStreamID
_, err := w.Write(header) copy(buf[1:], c.Body)
if err != nil { return buf, nil
return err
}
_, err = w.Write(c.Body)
return err
} }

5
internal/rtmp/chunk/chunk3_test.go

@ -24,8 +24,7 @@ func TestChunk3Read(t *testing.T) {
} }
func TestChunk3Write(t *testing.T) { func TestChunk3Write(t *testing.T) {
var buf bytes.Buffer buf, err := chunk3dec.Write()
err := chunk3dec.Write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, chunk3enc, buf.Bytes()) require.Equal(t, chunk3enc, buf)
} }

16
internal/rtmp/rawmessage/reader_test.go

@ -3,7 +3,6 @@ package rawmessage
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"testing" "testing"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -11,7 +10,7 @@ import (
) )
type writableChunk interface { type writableChunk interface {
Write(w io.Writer) error Write() ([]byte, error)
} }
type sequenceEntry struct { type sequenceEntry struct {
@ -25,8 +24,9 @@ func TestReader(t *testing.T) {
r := NewReader(bufio.NewReader(&buf)) r := NewReader(bufio.NewReader(&buf))
for _, entry := range seq { for _, entry := range seq {
err := entry.chunk.Write(&buf) buf2, err := entry.chunk.Write()
require.NoError(t, err) require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read() msg, err := r.Read()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, entry.msg, msg) require.Equal(t, entry.msg, msg)
@ -124,21 +124,23 @@ func TestReader(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
r := NewReader(bufio.NewReader(&buf)) r := NewReader(bufio.NewReader(&buf))
err := chunk.Chunk0{ buf2, err := chunk.Chunk0{
ChunkStreamID: 27, ChunkStreamID: 27,
Timestamp: 18576, Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth, Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123, MessageStreamID: 3123,
BodyLen: 192, BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128), Body: bytes.Repeat([]byte{0x03}, 128),
}.Write(&buf) }.Write()
require.NoError(t, err) require.NoError(t, err)
buf.Write(buf2)
err = chunk.Chunk3{ buf2, err = chunk.Chunk3{
ChunkStreamID: 27, ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64), Body: bytes.Repeat([]byte{0x03}, 64),
}.Write(&buf) }.Write()
require.NoError(t, err) require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read() msg, err := r.Read()
require.NoError(t, err) require.NoError(t, err)

45
internal/rtmp/rawmessage/writer.go

@ -42,45 +42,65 @@ func (wc *writerChunkStream) write(msg *Message) error {
switch { switch {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID: case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
err := chunk.Chunk0{ buf, err := chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp, Timestamp: msg.Timestamp,
Type: msg.Type, Type: msg.Type,
MessageStreamID: msg.MessageStreamID, MessageStreamID: msg.MessageStreamID,
BodyLen: uint32(bodyLen), BodyLen: uint32(bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}.Write(wc.mw.w) }.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil { if err != nil {
return err return err
} }
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen: case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
err := chunk.Chunk1{ buf, err := chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta, TimestampDelta: *timestampDelta,
Type: msg.Type, Type: msg.Type,
BodyLen: uint32(bodyLen), BodyLen: uint32(bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}.Write(wc.mw.w) }.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil { if err != nil {
return err return err
} }
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta: case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
err := chunk.Chunk2{ buf, err := chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta, TimestampDelta: *timestampDelta,
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}.Write(wc.mw.w) }.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil { if err != nil {
return err return err
} }
default: default:
err := chunk.Chunk3{ buf, err := chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}.Write(wc.mw.w) }.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil { if err != nil {
return err return err
} }
@ -100,10 +120,15 @@ func (wc *writerChunkStream) write(msg *Message) error {
wc.lastTimestampDelta = &v5 wc.lastTimestampDelta = &v5
} }
} else { } else {
err := chunk.Chunk3{ buf, err := chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID, ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen], Body: msg.Body[pos : pos+chunkBodyLen],
}.Write(wc.mw.w) }.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil { if err != nil {
return err return err
} }

Loading…
Cancel
Save