Browse Source

rtmp: implement acknowledge mechanism

pull/1003/head
aler9 3 years ago
parent
commit
2601ca5661
  1. 37
      internal/rtmp/bytecounter/reader.go
  2. 19
      internal/rtmp/bytecounter/readwriter.go
  3. 30
      internal/rtmp/bytecounter/writer.go
  4. 11
      internal/rtmp/chunk/chunk.go
  5. 4
      internal/rtmp/chunk/chunk0.go
  6. 4
      internal/rtmp/chunk/chunk1.go
  7. 2
      internal/rtmp/chunk/chunk2.go
  8. 2
      internal/rtmp/chunk/chunk3.go
  9. 108
      internal/rtmp/conn_test.go
  10. 3
      internal/rtmp/handshake/c0s0.go
  11. 3
      internal/rtmp/handshake/c0s0_test.go
  12. 3
      internal/rtmp/handshake/c1s1.go
  13. 3
      internal/rtmp/handshake/c1s1_test.go
  14. 3
      internal/rtmp/handshake/c2s2.go
  15. 3
      internal/rtmp/handshake/c2s2_test.go
  16. 40
      internal/rtmp/message/msg_acknowledge.go
  17. 24
      internal/rtmp/message/reader.go
  18. 46
      internal/rtmp/message/readwriter.go
  19. 28
      internal/rtmp/message/writer.go
  20. 69
      internal/rtmp/rawmessage/reader.go
  21. 48
      internal/rtmp/rawmessage/reader_test.go
  22. 108
      internal/rtmp/rawmessage/writer.go
  23. 52
      internal/rtmp/rawmessage/writer_test.go

37
internal/rtmp/bytecounter/reader.go

@ -0,0 +1,37 @@ @@ -0,0 +1,37 @@
package bytecounter
import (
"bufio"
"io"
)
type readerInner struct {
r io.Reader
count uint32
}
func (r *readerInner) Read(p []byte) (int, error) {
n, err := r.r.Read(p)
r.count += uint32(n)
return n, err
}
// Reader allows to count read bytes.
type Reader struct {
ri *readerInner
*bufio.Reader
}
// NewReader allocates a Reader.
func NewReader(r io.Reader) *Reader {
ri := &readerInner{r: r}
return &Reader{
ri: ri,
Reader: bufio.NewReader(ri),
}
}
// Count returns read bytes.
func (r Reader) Count() uint32 {
return r.ri.count
}

19
internal/rtmp/bytecounter/readwriter.go

@ -0,0 +1,19 @@ @@ -0,0 +1,19 @@
package bytecounter
import (
"io"
)
// ReadWriter allows to count read and written bytes.
type ReadWriter struct {
*Reader
*Writer
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(rw io.ReadWriter) *ReadWriter {
return &ReadWriter{
Reader: NewReader(rw),
Writer: NewWriter(rw),
}
}

30
internal/rtmp/bytecounter/writer.go

@ -0,0 +1,30 @@ @@ -0,0 +1,30 @@
package bytecounter
import (
"io"
)
// Writer allows to count written bytes.
type Writer struct {
w io.Writer
count uint32
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
}
}
// Write implements io.Writer.
func (w *Writer) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.count += uint32(n)
return n, err
}
// Count returns written bytes.
func (w Writer) Count() uint32 {
return w.count
}

11
internal/rtmp/chunk/chunk.go

@ -0,0 +1,11 @@ @@ -0,0 +1,11 @@
package chunk
import (
"io"
)
// Chunk is a chunk.
type Chunk interface {
Read(io.Reader, uint32) error
Write() ([]byte, error)
}

4
internal/rtmp/chunk/chunk0.go

@ -18,7 +18,7 @@ type Chunk0 struct { @@ -18,7 +18,7 @@ type Chunk0 struct {
}
// Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 12)
_, err := r.Read(header)
if err != nil {
@ -31,7 +31,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error { @@ -31,7 +31,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen int) error {
c.Type = MessageType(header[7])
c.MessageStreamID = uint32(header[8])<<24 | uint32(header[9])<<16 | uint32(header[10])<<8 | uint32(header[11])
chunkBodyLen := int(c.BodyLen)
chunkBodyLen := c.BodyLen
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
}

4
internal/rtmp/chunk/chunk1.go

@ -19,7 +19,7 @@ type Chunk1 struct { @@ -19,7 +19,7 @@ type Chunk1 struct {
}
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 8)
_, err := r.Read(header)
if err != nil {
@ -31,7 +31,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error { @@ -31,7 +31,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen int) error {
c.BodyLen = uint32(header[4])<<16 | uint32(header[5])<<8 | uint32(header[6])
c.Type = MessageType(header[7])
chunkBodyLen := int(c.BodyLen)
chunkBodyLen := (c.BodyLen)
if chunkBodyLen > chunkMaxBodyLen {
chunkBodyLen = chunkMaxBodyLen
}

2
internal/rtmp/chunk/chunk2.go

@ -15,7 +15,7 @@ type Chunk2 struct { @@ -15,7 +15,7 @@ type Chunk2 struct {
}
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen int) error {
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 4)
_, err := r.Read(header)
if err != nil {

2
internal/rtmp/chunk/chunk3.go

@ -16,7 +16,7 @@ type Chunk3 struct { @@ -16,7 +16,7 @@ type Chunk3 struct {
}
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen int) error {
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1)
_, err := r.Read(header)
if err != nil {

108
internal/rtmp/conn_test.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package rtmp
import (
"bufio"
"net"
"net/url"
"strings"
@ -13,6 +12,7 @@ import ( @@ -13,6 +12,7 @@ import (
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
@ -114,7 +114,7 @@ func TestReadTracks(t *testing.T) { @@ -114,7 +114,7 @@ func TestReadTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(conn)
@ -126,27 +126,26 @@ func TestReadTracks(t *testing.T) { @@ -126,27 +126,26 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(br)
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(br, false)
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)
mr := message.NewReader(br)
mrw := message.NewReadWriter(bc)
// C->S connect
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
@ -166,14 +165,14 @@ func TestReadTracks(t *testing.T) { @@ -166,14 +165,14 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
msg, err := mr.Read()
msg, err := mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetWindowAckSize{
Value: 2500000,
}, msg)
// S->C set peer bandwidth
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 2500000,
@ -181,16 +180,14 @@ func TestReadTracks(t *testing.T) { @@ -181,16 +180,14 @@ func TestReadTracks(t *testing.T) {
}, msg)
// S->C set chunk size
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetChunkSize{
Value: 65536,
}, msg)
mr.SetChunkSize(65536)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@ -211,15 +208,13 @@ func TestReadTracks(t *testing.T) { @@ -211,15 +208,13 @@ func TestReadTracks(t *testing.T) {
}, msg)
// C->S set chunk size
err = mw.Write(&message.MsgSetChunkSize{
err = mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
require.NoError(t, err)
mw.SetChunkSize(65536)
// C->S releaseStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
@ -231,7 +226,7 @@ func TestReadTracks(t *testing.T) { @@ -231,7 +226,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S FCPublish
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
@ -243,7 +238,7 @@ func TestReadTracks(t *testing.T) { @@ -243,7 +238,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S createStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
@ -254,7 +249,7 @@ func TestReadTracks(t *testing.T) { @@ -254,7 +249,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@ -267,7 +262,7 @@ func TestReadTracks(t *testing.T) { @@ -267,7 +262,7 @@ func TestReadTracks(t *testing.T) {
}, msg)
// C->S publish
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 1,
Payload: []interface{}{
@ -281,7 +276,7 @@ func TestReadTracks(t *testing.T) { @@ -281,7 +276,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@ -301,7 +296,7 @@ func TestReadTracks(t *testing.T) { @@ -301,7 +296,7 @@ func TestReadTracks(t *testing.T) {
switch ca {
case "standard":
// C->S metadata
err = mw.Write(&message.MsgDataAMF0{
err = mrw.Write(&message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 1,
Payload: []interface{}{
@ -341,7 +336,7 @@ func TestReadTracks(t *testing.T) { @@ -341,7 +336,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@ -357,7 +352,7 @@ func TestReadTracks(t *testing.T) { @@ -357,7 +352,7 @@ func TestReadTracks(t *testing.T) {
ChannelCount: 2,
}.Encode()
require.NoError(t, err)
err = mw.Write(&message.MsgAudio{
err = mrw.Write(&message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 1,
Rate: flvio.SOUND_44Khz,
@ -370,7 +365,7 @@ func TestReadTracks(t *testing.T) { @@ -370,7 +365,7 @@ func TestReadTracks(t *testing.T) {
case "metadata without codec id":
// C->S metadata
err = mw.Write(&message.MsgDataAMF0{
err = mrw.Write(&message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 1,
Payload: []interface{}{
@ -406,7 +401,7 @@ func TestReadTracks(t *testing.T) { @@ -406,7 +401,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@ -428,7 +423,7 @@ func TestReadTracks(t *testing.T) { @@ -428,7 +423,7 @@ func TestReadTracks(t *testing.T) {
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
err = mw.Write(&message.MsgVideo{
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
@ -479,7 +474,7 @@ func TestWriteTracks(t *testing.T) { @@ -479,7 +474,7 @@ func TestWriteTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err)
defer conn.Close()
br := bufio.NewReader(conn)
bc := bytecounter.NewReadWriter(conn)
// C->S handshake C0
err = handshake.C0S0{}.Write(conn)
@ -491,27 +486,26 @@ func TestWriteTracks(t *testing.T) { @@ -491,27 +486,26 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S0
err = handshake.C0S0{}.Read(br)
err = handshake.C0S0{}.Read(bc)
require.NoError(t, err)
// S->C handshake S1
s1 := handshake.C1S1{}
err = s1.Read(br, false)
err = s1.Read(bc, false)
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)
mr := message.NewReader(br)
mrw := message.NewReadWriter(bc)
// C->S connect
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
@ -531,14 +525,14 @@ func TestWriteTracks(t *testing.T) { @@ -531,14 +525,14 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C window acknowledgement size
msg, err := mr.Read()
msg, err := mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetWindowAckSize{
Value: 2500000,
}, msg)
// S->C set peer bandwidth
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 2500000,
@ -546,16 +540,14 @@ func TestWriteTracks(t *testing.T) { @@ -546,16 +540,14 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C set chunk size
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetChunkSize{
Value: 65536,
}, msg)
mr.SetChunkSize(65536)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@ -576,21 +568,19 @@ func TestWriteTracks(t *testing.T) { @@ -576,21 +568,19 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// C->S window acknowledgement size
err = mw.Write(&message.MsgSetWindowAckSize{
err = mrw.Write(&message.MsgSetWindowAckSize{
Value: 2500000,
})
require.NoError(t, err)
// C->S set chunk size
err = mw.Write(&message.MsgSetChunkSize{
err = mrw.Write(&message.MsgSetChunkSize{
Value: 65536,
})
require.NoError(t, err)
mw.SetChunkSize(65536)
// C->S createStream
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
@ -601,7 +591,7 @@ func TestWriteTracks(t *testing.T) { @@ -601,7 +591,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C result
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
@ -614,7 +604,7 @@ func TestWriteTracks(t *testing.T) { @@ -614,7 +604,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// C->S getStreamLength
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"getStreamLength",
@ -626,7 +616,7 @@ func TestWriteTracks(t *testing.T) { @@ -626,7 +616,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// C->S play
err = mw.Write(&message.MsgCommandAMF0{
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"play",
@ -639,21 +629,21 @@ func TestWriteTracks(t *testing.T) { @@ -639,21 +629,21 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C event "stream is recorded"
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgUserControlStreamIsRecorded{
StreamID: 1,
}, msg)
// S->C event "stream begin 1"
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgUserControlStreamBegin{
StreamID: 1,
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@ -671,7 +661,7 @@ func TestWriteTracks(t *testing.T) { @@ -671,7 +661,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@ -689,7 +679,7 @@ func TestWriteTracks(t *testing.T) { @@ -689,7 +679,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@ -707,7 +697,7 @@ func TestWriteTracks(t *testing.T) { @@ -707,7 +697,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onStatus
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
@ -725,7 +715,7 @@ func TestWriteTracks(t *testing.T) { @@ -725,7 +715,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C onMetadata
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgDataAMF0{
ChunkStreamID: 4,
@ -742,7 +732,7 @@ func TestWriteTracks(t *testing.T) { @@ -742,7 +732,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C H264 decoder config
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgVideo{
ChunkStreamID: 6,
@ -760,7 +750,7 @@ func TestWriteTracks(t *testing.T) { @@ -760,7 +750,7 @@ func TestWriteTracks(t *testing.T) {
}, msg)
// S->C AAC decoder config
msg, err = mr.Read()
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgAudio{
ChunkStreamID: 4,

3
internal/rtmp/handshake/c0s0.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"fmt"
"io"
)
@ -14,7 +13,7 @@ const ( @@ -14,7 +13,7 @@ const (
type C0S0 struct{}
// Read reads a C0S0.
func (C0S0) Read(r *bufio.Reader) error {
func (C0S0) Read(r io.Reader) error {
buf := make([]byte, 1)
_, err := io.ReadFull(r, buf)
if err != nil {

3
internal/rtmp/handshake/c0s0_test.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@ -14,7 +13,7 @@ var c0s0dec = C0S0{} @@ -14,7 +13,7 @@ var c0s0dec = C0S0{}
func TestC0S0Read(t *testing.T) {
var c0s0 C0S0
err := c0s0.Read(bufio.NewReader(bytes.NewReader(c0s0enc)))
err := c0s0.Read((bytes.NewReader(c0s0enc)))
require.NoError(t, err)
require.Equal(t, c0s0dec, c0s0)
}

3
internal/rtmp/handshake/c1s1.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
@ -79,7 +78,7 @@ type C1S1 struct { @@ -79,7 +78,7 @@ type C1S1 struct {
}
// Read reads a C1S1.
func (c *C1S1) Read(r *bufio.Reader, isC1 bool) error {
func (c *C1S1) Read(r io.Reader, isC1 bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {

3
internal/rtmp/handshake/c1s1_test.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@ -44,7 +43,7 @@ func TestC1S1Read(t *testing.T) { @@ -44,7 +43,7 @@ func TestC1S1Read(t *testing.T) {
)
var c1s1 C1S1
err := c1s1.Read(bufio.NewReader(bytes.NewReader(c1s1enc)), true)
err := c1s1.Read((bytes.NewReader(c1s1enc)), true)
require.NoError(t, err)
require.Equal(t, c1s1dec, c1s1)
}

3
internal/rtmp/handshake/c2s2.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/binary"
@ -18,7 +17,7 @@ type C2S2 struct { @@ -18,7 +17,7 @@ type C2S2 struct {
}
// Read reads a C2S2.
func (c *C2S2) Read(r *bufio.Reader) error {
func (c *C2S2) Read(r io.Reader) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {

3
internal/rtmp/handshake/c2s2_test.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@ -43,7 +42,7 @@ func TestC2S2Read(t *testing.T) { @@ -43,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read(bufio.NewReader(bytes.NewReader(c2s2enc)))
err := c2s2.Read((bytes.NewReader(c2s2enc)))
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}

40
internal/rtmp/message/msg_acknowledge.go

@ -0,0 +1,40 @@ @@ -0,0 +1,40 @@
package message
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
// MsgAcknowledge is an acknowledgement message.
type MsgAcknowledge struct {
Value uint32
}
// Unmarshal implements Message.
func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
if raw.ChunkStreamID != ControlChunkStreamID {
return fmt.Errorf("unexpected chunk stream ID")
}
if len(raw.Body) != 4 {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
return nil
}
// Marshal implements Message.
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeAcknowledge,
Body: body,
}, nil
}

24
internal/rtmp/message/reader.go

@ -1,10 +1,10 @@ @@ -1,10 +1,10 @@
package message
import (
"bufio"
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
@ -14,6 +14,9 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) { @@ -14,6 +14,9 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
case chunk.MessageTypeSetChunkSize:
return &MsgSetChunkSize{}, nil
case chunk.MessageTypeAcknowledge:
return &MsgAcknowledge{}, nil
case chunk.MessageTypeSetWindowAckSize:
return &MsgSetWindowAckSize{}, nil
@ -75,18 +78,13 @@ type Reader struct { @@ -75,18 +78,13 @@ type Reader struct {
}
// NewReader allocates a Reader.
func NewReader(r *bufio.Reader) *Reader {
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
return &Reader{
r: rawmessage.NewReader(r),
r: rawmessage.NewReader(r, onAckNeeded),
}
}
// SetChunkSize sets the maximum chunk size.
func (r *Reader) SetChunkSize(v int) {
r.r.SetChunkSize(v)
}
// Read reads a essage.
// Read reads a Message.
func (r *Reader) Read() (Message, error) {
raw, err := r.r.Read()
if err != nil {
@ -103,5 +101,13 @@ func (r *Reader) Read() (Message, error) { @@ -103,5 +101,13 @@ func (r *Reader) Read() (Message, error) {
return nil, err
}
switch tmsg := msg.(type) {
case *MsgSetChunkSize:
r.r.SetChunkSize(tmsg.Value)
case *MsgSetWindowAckSize:
r.r.SetWindowAckSize(tmsg.Value)
}
return msg, nil
}

46
internal/rtmp/message/readwriter.go

@ -0,0 +1,46 @@ @@ -0,0 +1,46 @@
package message
import (
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
// ReadWriter is a message reader/writer.
type ReadWriter struct {
r *Reader
w *Writer
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter {
w := NewWriter(bc.Writer)
r := NewReader(bc.Reader, func(count uint32) error {
return w.Write(&MsgAcknowledge{
Value: (count),
})
})
return &ReadWriter{
r: r,
w: w,
}
}
// Read reads a message.
func (rw *ReadWriter) Read() (Message, error) {
msg, err := rw.r.Read()
if err != nil {
return nil, err
}
if tmsg, ok := msg.(*MsgAcknowledge); ok {
rw.w.SetAcknowledgeValue(tmsg.Value)
}
return msg, nil
}
// Write writes a message.
func (rw *ReadWriter) Write(msg Message) error {
return rw.w.Write(msg)
}

28
internal/rtmp/message/writer.go

@ -1,8 +1,7 @@ @@ -1,8 +1,7 @@
package message
import (
"io"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
)
@ -12,23 +11,36 @@ type Writer struct { @@ -12,23 +11,36 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
func NewWriter(w *bytecounter.Writer) *Writer {
return &Writer{
w: rawmessage.NewWriter(w),
}
}
// SetChunkSize sets the maximum chunk size.
func (mw *Writer) SetChunkSize(v int) {
mw.w.SetChunkSize(v)
// SetAcknowledgeValue sets the value of the last received acknowledge.
func (w *Writer) SetAcknowledgeValue(v uint32) {
w.w.SetAcknowledgeValue(v)
}
// Write writes a message.
func (mw *Writer) Write(msg Message) error {
func (w *Writer) Write(msg Message) error {
raw, err := msg.Marshal()
if err != nil {
return err
}
return mw.w.Write(raw)
err = w.w.Write(raw)
if err != nil {
return err
}
switch tmsg := msg.(type) {
case *MsgSetChunkSize:
w.w.SetChunkSize(tmsg.Value)
case *MsgSetWindowAckSize:
w.w.SetWindowAckSize(tmsg.Value)
}
return nil
}

69
internal/rtmp/rawmessage/reader.go

@ -1,10 +1,10 @@ @@ -1,10 +1,10 @@
package rawmessage
import (
"bufio"
"errors"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
@ -20,7 +20,32 @@ type readerChunkStream struct { @@ -20,7 +20,32 @@ type readerChunkStream struct {
curTimestampDelta *uint32
}
func (rc *readerChunkStream) read(typ byte) (*Message, error) {
func (rc *readerChunkStream) readChunk(c chunk.Chunk, chunkBodySize uint32) error {
err := c.Read(rc.mr.r, chunkBodySize)
if err != nil {
return err
}
// check if an ack is needed
if rc.mr.ackWindowSize != 0 {
count := rc.mr.r.Count()
diff := count - rc.mr.lastAckCount
// TODO: handle overflow
if diff > (rc.mr.ackWindowSize) {
err := rc.mr.onAckNeeded(count)
if err != nil {
return err
}
rc.mr.lastAckCount += (rc.mr.ackWindowSize)
}
}
return nil
}
func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
switch typ {
case 0:
if rc.curBody != nil {
@ -28,7 +53,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -28,7 +53,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
var c0 chunk.Chunk0
err := c0.Read(rc.mr.r, rc.mr.chunkSize)
err := rc.readChunk(&c0, rc.mr.chunkSize)
if err != nil {
return nil, err
}
@ -65,7 +90,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -65,7 +90,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
var c1 chunk.Chunk1
err := c1.Read(rc.mr.r, rc.mr.chunkSize)
err := rc.readChunk(&c1, rc.mr.chunkSize)
if err != nil {
return nil, err
}
@ -100,13 +125,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -100,13 +125,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
}
chunkBodyLen := int(*rc.curBodyLen)
chunkBodyLen := (*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c2 chunk.Chunk2
err := c2.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c2, chunkBodyLen)
if err != nil {
return nil, err
}
@ -116,7 +141,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -116,7 +141,7 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
v2 := c2.TimestampDelta
rc.curTimestampDelta = &v2
if chunkBodyLen != len(c2.Body) {
if chunkBodyLen != uint32(len(c2.Body)) {
rc.curBody = &c2.Body
return nil, errMoreChunksNeeded
}
@ -134,13 +159,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -134,13 +159,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}
if rc.curBody != nil {
chunkBodyLen := int(*rc.curBodyLen) - len(*rc.curBody)
chunkBodyLen := (*rc.curBodyLen) - uint32(len(*rc.curBody))
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c3 chunk.Chunk3
err := c3.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c3, chunkBodyLen)
if err != nil {
return nil, err
}
@ -162,13 +187,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -162,13 +187,13 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
}, nil
}
chunkBodyLen := int(*rc.curBodyLen)
chunkBodyLen := (*rc.curBodyLen)
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
var c3 chunk.Chunk3
err := c3.Read(rc.mr.r, chunkBodyLen)
err := rc.readChunk(&c3, chunkBodyLen)
if err != nil {
return nil, err
}
@ -187,25 +212,35 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) { @@ -187,25 +212,35 @@ func (rc *readerChunkStream) read(typ byte) (*Message, error) {
// Reader is a raw message reader.
type Reader struct {
r *bufio.Reader
chunkSize int
chunkStreams map[byte]*readerChunkStream
r *bytecounter.Reader
onAckNeeded func(uint32) error
chunkSize uint32
ackWindowSize uint32
lastAckCount uint32
chunkStreams map[byte]*readerChunkStream
}
// NewReader allocates a Reader.
func NewReader(r *bufio.Reader) *Reader {
func NewReader(r *bytecounter.Reader, onAckNeeded func(uint32) error) *Reader {
return &Reader{
r: r,
onAckNeeded: onAckNeeded,
chunkSize: 128,
chunkStreams: make(map[byte]*readerChunkStream),
}
}
// SetChunkSize sets the maximum chunk size.
func (r *Reader) SetChunkSize(v int) {
func (r *Reader) SetChunkSize(v uint32) {
r.chunkSize = v
}
// SetWindowAckSize sets the window acknowledgement size.
func (r *Reader) SetWindowAckSize(v uint32) {
r.ackWindowSize = v
}
// Read reads a Message.
func (r *Reader) Read() (*Message, error) {
for {
@ -225,7 +260,7 @@ func (r *Reader) Read() (*Message, error) { @@ -225,7 +260,7 @@ func (r *Reader) Read() (*Message, error) {
r.r.UnreadByte()
msg, err := rc.read(typ)
msg, err := rc.readMessage(typ)
if err != nil {
if err == errMoreChunksNeeded {
continue

48
internal/rtmp/rawmessage/reader_test.go

@ -1,12 +1,13 @@ @@ -1,12 +1,13 @@
package rawmessage
import (
"bufio"
"bytes"
"testing"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
type writableChunk interface {
@ -21,7 +22,10 @@ type sequenceEntry struct { @@ -21,7 +22,10 @@ type sequenceEntry struct {
func TestReader(t *testing.T) {
testSequence := func(t *testing.T, seq []sequenceEntry) {
var buf bytes.Buffer
r := NewReader(bufio.NewReader(&buf))
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range seq {
buf2, err := entry.chunk.Write()
@ -122,7 +126,10 @@ func TestReader(t *testing.T) { @@ -122,7 +126,10 @@ func TestReader(t *testing.T) {
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
r := NewReader(bufio.NewReader(&buf))
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
@ -153,3 +160,36 @@ func TestReader(t *testing.T) { @@ -153,3 +160,36 @@ func TestReader(t *testing.T) {
}, msg)
})
}
func TestReaderAcknowledge(t *testing.T) {
onAckCalled := make(chan struct{})
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
close(onAckCalled)
return nil
})
r.SetWindowAckSize(100)
for i := 0; i < 2; i++ {
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}.Write()
require.NoError(t, err)
buf.Write(buf2)
}
for i := 0; i < 2; i++ {
_, err := r.Read()
require.NoError(t, err)
}
<-onAckCalled
}

108
internal/rtmp/rawmessage/writer.go

@ -1,8 +1,9 @@ @@ -1,8 +1,9 @@
package rawmessage
import (
"io"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
@ -10,14 +11,38 @@ type writerChunkStream struct { @@ -10,14 +11,38 @@ type writerChunkStream struct {
mw *Writer
lastMessageStreamID *uint32
lastType *chunk.MessageType
lastBodyLen *int
lastBodyLen *uint32
lastTimestamp *uint32
lastTimestampDelta *uint32
}
func (wc *writerChunkStream) write(msg *Message) error {
bodyLen := len(msg.Body)
pos := 0
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
buf, err := c.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
if err != nil {
return err
}
// check if we received an acknowledge
if wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - (wc.mw.ackValue)
// TODO: handle overflow
if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window")
}
}
return nil
}
func (wc *writerChunkStream) writeMessage(msg *Message) error {
bodyLen := uint32(len(msg.Body))
pos := uint32(0)
firstChunk := true
var timestampDelta *uint32
@ -42,65 +67,45 @@ func (wc *writerChunkStream) write(msg *Message) error { @@ -42,65 +67,45 @@ func (wc *writerChunkStream) write(msg *Message) error {
switch {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
buf, err := chunk.Chunk0{
err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp,
Type: msg.Type,
MessageStreamID: msg.MessageStreamID,
BodyLen: uint32(bodyLen),
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
buf, err := chunk.Chunk1{
err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
Type: msg.Type,
BodyLen: uint32(bodyLen),
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
buf, err := chunk.Chunk2{
err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
default:
buf, err := chunk.Chunk3{
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
@ -120,15 +125,10 @@ func (wc *writerChunkStream) write(msg *Message) error { @@ -120,15 +125,10 @@ func (wc *writerChunkStream) write(msg *Message) error {
wc.lastTimestampDelta = &v5
}
} else {
buf, err := chunk.Chunk3{
err := wc.writeChunk(&chunk.Chunk3{
ChunkStreamID: msg.ChunkStreamID,
Body: msg.Body[pos : pos+chunkBodyLen],
}.Write()
if err != nil {
return err
}
_, err = wc.mw.w.Write(buf)
})
if err != nil {
return err
}
@ -144,13 +144,15 @@ func (wc *writerChunkStream) write(msg *Message) error { @@ -144,13 +144,15 @@ func (wc *writerChunkStream) write(msg *Message) error {
// Writer is a raw message writer.
type Writer struct {
w io.Writer
chunkSize int
chunkStreams map[byte]*writerChunkStream
w *bytecounter.Writer
chunkSize uint32
ackWindowSize uint32
ackValue uint32
chunkStreams map[byte]*writerChunkStream
}
// NewWriter allocates a Writer.
func NewWriter(w io.Writer) *Writer {
func NewWriter(w *bytecounter.Writer) *Writer {
return &Writer{
w: w,
chunkSize: 128,
@ -159,10 +161,20 @@ func NewWriter(w io.Writer) *Writer { @@ -159,10 +161,20 @@ func NewWriter(w io.Writer) *Writer {
}
// SetChunkSize sets the maximum chunk size.
func (w *Writer) SetChunkSize(v int) {
func (w *Writer) SetChunkSize(v uint32) {
w.chunkSize = v
}
// SetWindowAckSize sets the window acknowledgement size.
func (w *Writer) SetWindowAckSize(v uint32) {
w.ackWindowSize = v
}
// SetAcknowledgeValue sets the acknowledge sequence number.
func (w *Writer) SetAcknowledgeValue(v uint32) {
w.ackValue = v
}
// Write writes a Message.
func (w *Writer) Write(msg *Message) error {
wc, ok := w.chunkStreams[msg.ChunkStreamID]
@ -171,5 +183,5 @@ func (w *Writer) Write(msg *Message) error { @@ -171,5 +183,5 @@ func (w *Writer) Write(msg *Message) error {
w.chunkStreams[msg.ChunkStreamID] = wc
}
return wc.write(msg)
return wc.writeMessage(msg)
}

52
internal/rtmp/rawmessage/writer_test.go

@ -1,10 +1,10 @@ @@ -1,10 +1,10 @@
package rawmessage
import (
"bufio"
"bytes"
"testing"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/stretchr/testify/require"
)
@ -12,8 +12,7 @@ import ( @@ -12,8 +12,7 @@ import (
func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk1", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@ -25,7 +24,7 @@ func TestWriter(t *testing.T) { @@ -25,7 +24,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@ -46,7 +45,7 @@ func TestWriter(t *testing.T) { @@ -46,7 +45,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c1 chunk.Chunk1
err = c1.Read(br, 128)
err = c1.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk1{
ChunkStreamID: 27,
@ -59,8 +58,7 @@ func TestWriter(t *testing.T) { @@ -59,8 +58,7 @@ func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@ -72,7 +70,7 @@ func TestWriter(t *testing.T) { @@ -72,7 +70,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@ -93,7 +91,7 @@ func TestWriter(t *testing.T) { @@ -93,7 +91,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c2 chunk.Chunk2
err = c2.Read(br, 64)
err = c2.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk2{
ChunkStreamID: 27,
@ -111,7 +109,7 @@ func TestWriter(t *testing.T) { @@ -111,7 +109,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c3 chunk.Chunk3
err = c3.Read(br, 64)
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
@ -121,8 +119,7 @@ func TestWriter(t *testing.T) { @@ -121,8 +119,7 @@ func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
br := bufio.NewReader(&buf)
w := NewWriter(&buf)
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
@ -134,7 +131,7 @@ func TestWriter(t *testing.T) { @@ -134,7 +131,7 @@ func TestWriter(t *testing.T) {
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(br, 128)
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
@ -146,7 +143,7 @@ func TestWriter(t *testing.T) { @@ -146,7 +143,7 @@ func TestWriter(t *testing.T) {
}, c0)
var c3 chunk.Chunk3
err = c3.Read(br, 64)
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
@ -154,3 +151,30 @@ func TestWriter(t *testing.T) { @@ -154,3 +151,30 @@ func TestWriter(t *testing.T) {
}, c3)
})
}
func TestWriterAcknowledge(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
w.SetWindowAckSize(100)
for i := 0; i < 2; i++ {
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
}
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.EqualError(t, err, "no acknowledge received within window")
}

Loading…
Cancel
Save