Browse Source

rtmp: rewrite implementation of rtmp connection (#1047)

* rtmp: improve MsgCommandAMF0

* rtmp: fix MsgSetPeerBandwidth

* rtmp: add message tests

* rtmp: replace implementation with new one

* rtmp: rename handshake functions

* rtmp: avoid calling useless function

* rtmp: use time.Duration for PTSDelta

* rtmp: fix decoding chunks with relevant size

* rtmp: rewrite implementation of rtmp connection

* rtmp: fix tests

* rtmp: improve error message

* rtmp: replace h264 config implementation

* link against github.com/notedit/rtmp

* normalize MessageStreamID

* rtmp: make acknowledge optional

* rtmp: fix decoding of chunk2 + chunk3

* avoid using encoding/binary
pull/1060/head
Alessandro Ros 3 years ago committed by GitHub
parent
commit
9e6abc6e9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      go.mod
  2. 8
      go.sum
  3. 264
      internal/core/rtmp_conn.go
  4. 3
      internal/core/rtmp_server.go
  5. 16
      internal/core/rtmp_server_test.go
  6. 101
      internal/core/rtmp_source.go
  7. 3
      internal/core/rtsp_server.go
  8. 4
      internal/rtmp/chunk/chunk0.go
  9. 4
      internal/rtmp/chunk/chunk1.go
  10. 4
      internal/rtmp/chunk/chunk2.go
  11. 4
      internal/rtmp/chunk/chunk3.go
  12. 883
      internal/rtmp/conn.go
  13. 408
      internal/rtmp/conn_test.go
  14. 89
      internal/rtmp/h264conf/h264conf.go
  15. 29
      internal/rtmp/h264conf/h264conf_test.go
  16. 18
      internal/rtmp/handshake/c1s1.go
  17. 2
      internal/rtmp/handshake/c1s1_test.go
  18. 29
      internal/rtmp/handshake/c2s2.go
  19. 2
      internal/rtmp/handshake/c2s2_test.go
  20. 12
      internal/rtmp/handshake/handshake.go
  21. 4
      internal/rtmp/handshake/handshake_test.go
  22. 16
      internal/rtmp/message/msg_acknowledge.go
  23. 3
      internal/rtmp/message/msg_audio.go
  24. 30
      internal/rtmp/message/msg_command_amf0.go
  25. 16
      internal/rtmp/message/msg_setchunksize.go
  26. 20
      internal/rtmp/message/msg_setpeerbandwidth.go
  27. 16
      internal/rtmp/message/msg_setwindowacksize.go
  28. 18
      internal/rtmp/message/msg_usercontrol_pingrequest.go
  29. 18
      internal/rtmp/message/msg_usercontrol_pingresponse.go
  30. 25
      internal/rtmp/message/msg_usercontrol_setbufferlength.go
  31. 18
      internal/rtmp/message/msg_usercontrol_streambegin.go
  32. 18
      internal/rtmp/message/msg_usercontrol_streamdry.go
  33. 18
      internal/rtmp/message/msg_usercontrol_streameof.go
  34. 18
      internal/rtmp/message/msg_usercontrol_streamisrecorded.go
  35. 19
      internal/rtmp/message/msg_video.go
  36. 5
      internal/rtmp/message/reader.go
  37. 227
      internal/rtmp/message/reader_test.go
  38. 4
      internal/rtmp/message/readwriter.go
  39. 4
      internal/rtmp/message/writer.go
  40. 22
      internal/rtmp/message/writer_test.go
  41. 4
      internal/rtmp/rawmessage/message.go
  42. 15
      internal/rtmp/rawmessage/reader.go
  43. 280
      internal/rtmp/rawmessage/reader_test.go
  44. 42
      internal/rtmp/rawmessage/writer.go
  45. 304
      internal/rtmp/rawmessage/writer_test.go

6
go.mod

@ -5,14 +5,14 @@ go 1.17 @@ -5,14 +5,14 @@ go 1.17
require (
code.cloudfoundry.org/bytefmt v0.0.0-20211005130812-5bb3c17173e5
github.com/abema/go-mp4 v0.7.2
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8
github.com/asticode/go-astits v1.10.1-0.20220319093903-4abe66a9b757
github.com/fsnotify/fsnotify v1.4.9
github.com/gin-gonic/gin v1.8.1
github.com/gookit/color v1.4.2
github.com/grafov/m3u8 v0.11.1
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
github.com/notedit/rtmp v0.0.0
github.com/notedit/rtmp v0.0.2
github.com/orcaman/writerseeker v0.0.0
github.com/pion/rtp v1.7.13
github.com/stretchr/testify v1.7.1
@ -51,6 +51,4 @@ require ( @@ -51,6 +51,4 @@ require (
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
replace github.com/notedit/rtmp => github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927
replace github.com/orcaman/writerseeker => github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82

8
go.sum

@ -6,10 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo @@ -6,10 +6,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafo
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f h1:EC+MOSv3e8ZEvtdHoL1++HahNoiVIkvu2Ygjrx6LyOg=
github.com/aler9/gortsplib v0.0.0-20220709151311-234e4f4f8d6f/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927 h1:95mXJ5fUCYpBRdSOnLAQAdJHHKxxxJrVCiaqDi965YQ=
github.com/aler9/rtmp v0.0.0-20210403095203-3be4a5535927/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8 h1:GdQOJFYbcrw8bXGClhroHTBIEJAb/jPCIV33Q966rms=
github.com/aler9/gortsplib v0.0.0-20220717125404-c6972424d6b8/go.mod h1:WI3nMhY2mM6nfoeW9uyk7TyG5Qr6YnYxmFoCply0sbo=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82 h1:9WgSzBLo3a9ToSVV7sRTBYZ1GGOZUpq4+5H3SN0UZq4=
github.com/aler9/writerseeker v0.0.0-20220601075008-6f0e685b9c82/go.mod h1:qsMrZCbeBf/mCLOeF16KDkPu4gktn/pOWyaq1aYQE7U=
github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=
@ -83,6 +81,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH @@ -83,6 +81,8 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/notedit/rtmp v0.0.2 h1:5+to4yezKATiJgnrcETu9LbV5G/QsWkOV9Ts2M/p33w=
github.com/notedit/rtmp v0.0.2/go.mod h1:vzuE21rowz+lT1NGsWbreIvYulgBpCGnQyeTyFblUHc=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=

264
internal/core/rtmp_conn.go

@ -16,13 +16,14 @@ import ( @@ -16,13 +16,14 @@ import (
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
const (
@ -107,7 +108,7 @@ func newRTMPConn( @@ -107,7 +108,7 @@ func newRTMPConn(
runOnConnect: runOnConnect,
runOnConnectRestart: runOnConnectRestart,
wg: wg,
conn: rtmp.NewServerConn(nconn),
conn: rtmp.NewConn(nconn),
nconn: nconn,
externalCmdPool: externalCmdPool,
pathManager: pathManager,
@ -211,19 +212,19 @@ func (c *rtmpConn) runInner(ctx context.Context) error { @@ -211,19 +212,19 @@ func (c *rtmpConn) runInner(ctx context.Context) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.ServerHandshake()
u, isReading, err := c.conn.InitializeServer()
if err != nil {
return err
}
if c.conn.IsPublishing() {
return c.runPublish(ctx)
if isReading {
return c.runRead(ctx, u)
}
return c.runRead(ctx)
return c.runPublish(ctx, u)
}
func (c *rtmpConn) runRead(ctx context.Context) error {
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL())
func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onReaderSetupPlay(pathReaderSetupPlayReq{
author: c,
@ -410,22 +411,17 @@ func (c *rtmpConn) runRead(ctx context.Context) error { @@ -410,22 +411,17 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
sps := videoTrack.SafeSPS()
pps := videoTrack.SafePPS()
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
b = b[:n]
err = c.conn.WritePacket(av.Packet{
Type: av.H264DecoderConfig,
Data: b,
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = c.conn.WriteMessage(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: buf,
})
if err != nil {
return err
@ -438,11 +434,14 @@ func (c *rtmpConn) runRead(ctx context.Context) error { @@ -438,11 +434,14 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
}
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err = c.conn.WritePacket(av.Packet{
Type: av.H264,
Data: avcc,
Time: dts,
CTime: pts - dts,
err = c.conn.WriteMessage(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: idrPresent,
H264Type: flvio.AVC_NALU,
Payload: avcc,
DTS: dts,
PTSDelta: pts - dts,
})
if err != nil {
return err
@ -467,10 +466,15 @@ func (c *rtmpConn) runRead(ctx context.Context) error { @@ -467,10 +466,15 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
for i, au := range aus {
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WritePacket(av.Packet{
Type: av.AAC,
Data: au,
Time: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()),
err := c.conn.WriteMessage(&message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 1,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: au,
DTS: pts + time.Duration(i)*aac.SamplesPerAccessUnit*time.Second/time.Duration(audioTrack.ClockRate()),
})
if err != nil {
return err
@ -480,7 +484,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error { @@ -480,7 +484,7 @@ func (c *rtmpConn) runRead(ctx context.Context) error {
}
}
func (c *rtmpConn) runPublish(ctx context.Context) error {
func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
videoTrack, audioTrack, err := c.conn.ReadTracks()
if err != nil {
@ -513,7 +517,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error { @@ -513,7 +517,7 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
tracks = append(tracks, audioTrack)
}
pathName, query, rawQuery := pathNameAndQuery(c.conn.URL())
pathName, query, rawQuery := pathNameAndQuery(u)
res := c.pathManager.onPublisherAnnounce(pathPublisherAnnounceReq{
author: c,
@ -559,121 +563,125 @@ func (c *rtmpConn) runPublish(ctx context.Context) error { @@ -559,121 +563,125 @@ func (c *rtmpConn) runPublish(ctx context.Context) error {
for {
c.nconn.SetReadDeadline(time.Now().Add(time.Duration(c.readTimeout)))
pkt, err := c.conn.ReadPacket()
msg, err := c.conn.ReadMessage()
if err != nil {
return err
}
switch pkt.Type {
case av.H264DecoderConfig:
codec, err := nh264.FromDecoderConfig(pkt.Data)
if err != nil {
return err
}
switch tmsg := msg.(type) {
case *message.MsgVideo:
if tmsg.H264Type == flvio.AVC_SEQHDR {
var conf h264conf.Conf
err = conf.Unmarshal(tmsg.Payload)
if err != nil {
return fmt.Errorf("unable to parse H264 config: %v", err)
}
pts := pkt.Time + pkt.CTime
nalus := [][]byte{
codec.SPS[0],
codec.PPS[0],
}
pts := tmsg.DTS + tmsg.PTSDelta
nalus := [][]byte{
conf.SPS,
conf.PPS,
}
pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
h264NALUs: nalus,
h264PTS: pts,
})
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
h264NALUs: nalus,
h264PTS: pts,
})
}
}
} else if tmsg.H264Type == flvio.AVC_NALU {
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
}
case av.H264:
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil {
return fmt.Errorf("unable to decode AVCC: %v", err)
}
nalus, err := h264.AVCCUnmarshal(pkt.Data)
if err != nil {
return err
}
// skip invalid NALUs sent by DJI
n := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
n++
}
}
if n == 0 {
continue
}
// skip invalid NALUs sent by DJI
n := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
n++
validNALUs := make([][]byte, n)
pos := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
validNALUs[pos] = nalu
pos++
}
}
}
if n == 0 {
continue
}
validNALUs := make([][]byte, n)
pos := 0
for _, nalu := range nalus {
if len(nalu) != 0 {
validNALUs[pos] = nalu
pos++
pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(validNALUs, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(validNALUs),
h264NALUs: validNALUs,
h264PTS: pts,
})
}
}
}
pts := pkt.Time + pkt.CTime
case *message.MsgAudio:
if tmsg.AACType == flvio.AAC_RAW {
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := h264Encoder.Encode(validNALUs, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
for _, pkt := range pkts {
rres.stream.writeData(&data{
trackID: videoTrackID,
trackID: audioTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
rres.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(validNALUs),
h264NALUs: validNALUs,
h264PTS: pts,
ptsEqualsDTS: true,
})
}
}
case av.AAC:
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
for _, pkt := range pkts {
rres.stream.writeData(&data{
trackID: audioTrackID,
rtp: pkt,
ptsEqualsDTS: true,
})
}
}
}
}

3
internal/core/rtmp_server.go

@ -3,7 +3,6 @@ package core @@ -3,7 +3,6 @@ package core
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"strconv"
@ -259,7 +258,7 @@ func (s *rtmpServer) newConnID() (string, error) { @@ -259,7 +258,7 @@ func (s *rtmpServer) newConnID() (string, error) {
return "", err
}
u := binary.LittleEndian.Uint32(b)
u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999
u += 100000000

16
internal/core/rtmp_server_test.go

@ -141,9 +141,9 @@ func TestRTMPServerAuth(t *testing.T) { @@ -141,9 +141,9 @@ func TestRTMPServerAuth(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
require.NoError(t, err)
_, _, err = conn.ReadTracks()
@ -229,9 +229,17 @@ func TestRTMPServerAuthFail(t *testing.T) { @@ -229,9 +229,17 @@ func TestRTMPServerAuthFail(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
require.NoError(t, err)
for i := 0; i < 3; i++ {
_, err := conn.ReadMessage()
require.NoError(t, err)
}
_, err = conn.ReadMessage()
require.Equal(t, err, io.EOF)
})
}

101
internal/core/rtmp_source.go

@ -12,11 +12,12 @@ import ( @@ -12,11 +12,12 @@ import (
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtpaac"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/notedit/rtmp/av"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/logger"
"github.com/aler9/rtsp-simple-server/internal/rtmp"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
const (
@ -126,14 +127,14 @@ func (s *rtmpSource) runInner() bool { @@ -126,14 +127,14 @@ func (s *rtmpSource) runInner() bool {
return err
}
conn := rtmp.NewClientConn(nconn, u)
conn := rtmp.NewConn(nconn)
readDone := make(chan error)
go func() {
readDone <- func() error {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
nconn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout)))
err = conn.ClientHandshake(true)
err = conn.InitializeClient(u, true)
if err != nil {
return err
}
@ -187,64 +188,68 @@ func (s *rtmpSource) runInner() bool { @@ -187,64 +188,68 @@ func (s *rtmpSource) runInner() bool {
for {
nconn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout)))
pkt, err := conn.ReadPacket()
msg, err := conn.ReadMessage()
if err != nil {
return err
}
switch pkt.Type {
case av.H264:
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
switch tmsg := msg.(type) {
case *message.MsgVideo:
if tmsg.H264Type == flvio.AVC_NALU {
if videoTrack == nil {
return fmt.Errorf("received an H264 packet, but track is not set up")
}
nalus, err := h264.AVCCUnmarshal(pkt.Data)
if err != nil {
return err
}
nalus, err := h264.AVCCUnmarshal(tmsg.Payload)
if err != nil {
return fmt.Errorf("unable to decode AVCC: %v", err)
}
pts := pkt.Time + pkt.CTime
pts := tmsg.DTS + tmsg.PTSDelta
pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
pkts, err := h264Encoder.Encode(nalus, pts)
if err != nil {
return fmt.Errorf("error while encoding H264: %v", err)
}
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
res.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
res.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(nalus),
h264NALUs: nalus,
h264PTS: pts,
})
lastPkt := len(pkts) - 1
for i, pkt := range pkts {
if i != lastPkt {
res.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: false,
})
} else {
res.stream.writeData(&data{
trackID: videoTrackID,
rtp: pkt,
ptsEqualsDTS: h264.IDRPresent(nalus),
h264NALUs: nalus,
h264PTS: pts,
})
}
}
}
case av.AAC:
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
case *message.MsgAudio:
if tmsg.AACType == flvio.AAC_RAW {
if audioTrack == nil {
return fmt.Errorf("received an AAC packet, but track is not set up")
}
pkts, err := aacEncoder.Encode([][]byte{pkt.Data}, pkt.Time+pkt.CTime)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
pkts, err := aacEncoder.Encode([][]byte{tmsg.Payload}, tmsg.DTS)
if err != nil {
return fmt.Errorf("error while encoding AAC: %v", err)
}
for _, pkt := range pkts {
res.stream.writeData(&data{
trackID: audioTrackID,
rtp: pkt,
ptsEqualsDTS: true,
})
for _, pkt := range pkts {
res.stream.writeData(&data{
trackID: audioTrackID,
rtp: pkt,
ptsEqualsDTS: true,
})
}
}
}
}

3
internal/core/rtsp_server.go

@ -4,7 +4,6 @@ import ( @@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"fmt"
"strconv"
"strings"
@ -235,7 +234,7 @@ func (s *rtspServer) newSessionID() (string, error) { @@ -235,7 +234,7 @@ func (s *rtspServer) newSessionID() (string, error) {
return "", err
}
u := binary.LittleEndian.Uint32(b)
u := uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
u %= 899999999
u += 100000000

4
internal/rtmp/chunk/chunk0.go

@ -20,7 +20,7 @@ type Chunk0 struct { @@ -20,7 +20,7 @@ type Chunk0 struct {
// Read reads the chunk.
func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 12)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@ -37,7 +37,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -37,7 +37,7 @@ func (c *Chunk0) Read(r io.Reader, chunkMaxBodyLen uint32) error {
}
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

4
internal/rtmp/chunk/chunk1.go

@ -21,7 +21,7 @@ type Chunk1 struct { @@ -21,7 +21,7 @@ type Chunk1 struct {
// Read reads the chunk.
func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
header := make([]byte, 8)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@ -37,7 +37,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error { @@ -37,7 +37,7 @@ func (c *Chunk1) Read(r io.Reader, chunkMaxBodyLen uint32) error {
}
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

4
internal/rtmp/chunk/chunk2.go

@ -17,7 +17,7 @@ type Chunk2 struct { @@ -17,7 +17,7 @@ type Chunk2 struct {
// Read reads the chunk.
func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 4)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@ -26,7 +26,7 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error { @@ -26,7 +26,7 @@ func (c *Chunk2) Read(r io.Reader, chunkBodyLen uint32) error {
c.TimestampDelta = uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

4
internal/rtmp/chunk/chunk3.go

@ -18,7 +18,7 @@ type Chunk3 struct { @@ -18,7 +18,7 @@ type Chunk3 struct {
// Read reads the chunk.
func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
header := make([]byte, 1)
_, err := r.Read(header)
_, err := io.ReadFull(r, header)
if err != nil {
return err
}
@ -26,7 +26,7 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error { @@ -26,7 +26,7 @@ func (c *Chunk3) Read(r io.Reader, chunkBodyLen uint32) error {
c.ChunkStreamID = header[0] & 0x3F
c.Body = make([]byte, chunkBodyLen)
_, err = r.Read(c.Body)
_, err = io.ReadFull(r, c.Body)
return err
}

883
internal/rtmp/conn.go

File diff suppressed because it is too large Load Diff

408
internal/rtmp/conn_test.go

@ -3,52 +3,20 @@ package rtmp @@ -3,52 +3,20 @@ package rtmp
import (
"net"
"net/url"
"strings"
"testing"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/aac"
nh264 "github.com/notedit/rtmp/codec/h264"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/h264conf"
"github.com/aler9/rtsp-simple-server/internal/rtmp/handshake"
"github.com/aler9/rtsp-simple-server/internal/rtmp/message"
)
func splitPath(u *url.URL) (app, stream string) {
nu := *u
nu.ForceQuery = false
pathsegs := strings.Split(nu.RequestURI(), "/")
if len(pathsegs) == 2 {
app = pathsegs[1]
}
if len(pathsegs) == 3 {
app = pathsegs[1]
stream = pathsegs[2]
}
if len(pathsegs) > 3 {
app = strings.Join(pathsegs[1:3], "/")
stream = strings.Join(pathsegs[3:], "/")
}
return
}
func getTcURL(u string) string {
ur, err := url.Parse(u)
if err != nil {
panic(err)
}
app, _ := splitPath(ur)
nu := *ur
nu.RawQuery = ""
nu.Path = "/"
return nu.String() + app
}
func TestClientHandshake(t *testing.T) {
func TestInitializeClient(t *testing.T) {
for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121")
@ -63,10 +31,10 @@ func TestClientHandshake(t *testing.T) { @@ -63,10 +31,10 @@ func TestClientHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoServer(bc)
err = handshake.DoServer(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S set window ack size
msg, err := mrw.Read()
@ -79,7 +47,7 @@ func TestClientHandshake(t *testing.T) { @@ -79,7 +47,7 @@ func TestClientHandshake(t *testing.T) {
msg, err = mrw.Read()
require.NoError(t, err)
require.Equal(t, &message.MsgSetPeerBandwidth{
Value: 0x2625a0,
Value: 2500000,
Type: 2,
}, msg)
@ -95,13 +63,13 @@ func TestClientHandshake(t *testing.T) { @@ -95,13 +63,13 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
float64(1),
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: float64(15)},
{K: "audioCodecs", V: float64(4071)},
@ -114,9 +82,9 @@ func TestClientHandshake(t *testing.T) { @@ -114,9 +82,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@ -137,9 +105,9 @@ func TestClientHandshake(t *testing.T) { @@ -137,9 +105,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
}, msg)
@ -147,9 +115,9 @@ func TestClientHandshake(t *testing.T) { @@ -147,9 +115,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@ -168,10 +136,10 @@ func TestClientHandshake(t *testing.T) { @@ -168,10 +136,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(0),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 0,
Arguments: []interface{}{
nil,
"",
},
@ -180,10 +148,10 @@ func TestClientHandshake(t *testing.T) { @@ -180,10 +148,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -199,9 +167,9 @@ func TestClientHandshake(t *testing.T) { @@ -199,9 +167,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@ -212,9 +180,9 @@ func TestClientHandshake(t *testing.T) { @@ -212,9 +180,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@ -225,9 +193,9 @@ func TestClientHandshake(t *testing.T) { @@ -225,9 +193,9 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
}, msg)
@ -235,9 +203,9 @@ func TestClientHandshake(t *testing.T) { @@ -235,9 +203,9 @@ func TestClientHandshake(t *testing.T) {
// S->C result
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@ -249,10 +217,10 @@ func TestClientHandshake(t *testing.T) { @@ -249,10 +217,10 @@ func TestClientHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"publish",
float64(5),
MessageStreamID: 0x1000000,
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"stream",
@ -262,10 +230,10 @@ func TestClientHandshake(t *testing.T) { @@ -262,10 +230,10 @@ func TestClientHandshake(t *testing.T) {
// S->C onStatus
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(5),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 5,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -286,9 +254,9 @@ func TestClientHandshake(t *testing.T) { @@ -286,9 +254,9 @@ func TestClientHandshake(t *testing.T) {
nconn, err := net.Dial("tcp", u.Host)
require.NoError(t, err)
defer nconn.Close()
conn := NewClientConn(nconn, u)
conn := NewConn(nconn)
err = conn.ClientHandshake(ca == "read")
err = conn.InitializeClient(u, ca == "read")
require.NoError(t, err)
<-done
@ -296,7 +264,7 @@ func TestClientHandshake(t *testing.T) { @@ -296,7 +264,7 @@ func TestClientHandshake(t *testing.T) {
}
}
func TestServerHandshake(t *testing.T) {
func TestInitializeServer(t *testing.T) {
for _, ca := range []string{"read", "publish"} {
t.Run(ca, func(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9121")
@ -310,9 +278,15 @@ func TestServerHandshake(t *testing.T) { @@ -310,9 +278,15 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
defer nconn.Close()
conn := NewServerConn(nconn)
err = conn.ServerHandshake()
conn := NewConn(nconn)
u, isReading, err := conn.InitializeServer()
require.NoError(t, err)
require.Equal(t, &url.URL{
Scheme: "rtmp",
Host: "127.0.0.1:9121",
Path: "//stream/",
}, u)
require.Equal(t, ca == "read", isReading)
close(done)
}()
@ -322,21 +296,21 @@ func TestServerHandshake(t *testing.T) { @@ -322,21 +296,21 @@ func TestServerHandshake(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@ -374,9 +348,9 @@ func TestServerHandshake(t *testing.T) { @@ -374,9 +348,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@ -400,9 +374,9 @@ func TestServerHandshake(t *testing.T) { @@ -400,9 +374,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
})
@ -413,9 +387,9 @@ func TestServerHandshake(t *testing.T) { @@ -413,9 +387,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@ -430,10 +404,10 @@ func TestServerHandshake(t *testing.T) { @@ -430,10 +404,10 @@ func TestServerHandshake(t *testing.T) {
// C->S play
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(0),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 0,
Arguments: []interface{}{
nil,
"",
},
@ -443,9 +417,9 @@ func TestServerHandshake(t *testing.T) { @@ -443,9 +417,9 @@ func TestServerHandshake(t *testing.T) {
// C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@ -455,9 +429,9 @@ func TestServerHandshake(t *testing.T) { @@ -455,9 +429,9 @@ func TestServerHandshake(t *testing.T) {
// C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@ -467,9 +441,9 @@ func TestServerHandshake(t *testing.T) { @@ -467,9 +441,9 @@ func TestServerHandshake(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
})
@ -480,9 +454,9 @@ func TestServerHandshake(t *testing.T) { @@ -480,9 +454,9 @@ func TestServerHandshake(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@ -491,10 +465,10 @@ func TestServerHandshake(t *testing.T) { @@ -491,10 +465,10 @@ func TestServerHandshake(t *testing.T) {
// C->S publish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
Payload: []interface{}{
"publish",
float64(5),
MessageStreamID: 0x1000000,
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"stream",
@ -536,8 +510,8 @@ func TestReadTracks(t *testing.T) { @@ -536,8 +510,8 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
rconn := NewServerConn(conn)
err = rconn.ServerHandshake()
rconn := NewConn(conn)
_, _, err = rconn.InitializeServer()
require.NoError(t, err)
videoTrack, audioTrack, err := rconn.ReadTracks()
@ -610,21 +584,21 @@ func TestReadTracks(t *testing.T) { @@ -610,21 +584,21 @@ func TestReadTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@ -662,9 +636,9 @@ func TestReadTracks(t *testing.T) { @@ -662,9 +636,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@ -687,9 +661,9 @@ func TestReadTracks(t *testing.T) { @@ -687,9 +661,9 @@ func TestReadTracks(t *testing.T) {
// C->S releaseStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"releaseStream",
float64(2),
Name: "releaseStream",
CommandID: 2,
Arguments: []interface{}{
nil,
"",
},
@ -699,9 +673,9 @@ func TestReadTracks(t *testing.T) { @@ -699,9 +673,9 @@ func TestReadTracks(t *testing.T) {
// C->S FCPublish
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"FCPublish",
float64(3),
Name: "FCPublish",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@ -711,9 +685,9 @@ func TestReadTracks(t *testing.T) { @@ -711,9 +685,9 @@ func TestReadTracks(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(4),
Name: "createStream",
CommandID: 4,
Arguments: []interface{}{
nil,
},
})
@ -724,9 +698,9 @@ func TestReadTracks(t *testing.T) { @@ -724,9 +698,9 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(4),
Name: "_result",
CommandID: 4,
Arguments: []interface{}{
nil,
float64(1),
},
@ -736,9 +710,9 @@ func TestReadTracks(t *testing.T) { @@ -736,9 +710,9 @@ func TestReadTracks(t *testing.T) {
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 1,
Payload: []interface{}{
"publish",
float64(5),
Name: "publish",
CommandID: 5,
Arguments: []interface{}{
nil,
"",
"live",
@ -751,10 +725,10 @@ func TestReadTracks(t *testing.T) { @@ -751,10 +725,10 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(5),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 5,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -796,23 +770,16 @@ func TestReadTracks(t *testing.T) { @@ -796,23 +770,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@ -861,23 +828,16 @@ func TestReadTracks(t *testing.T) { @@ -861,23 +828,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@ -901,23 +861,16 @@ func TestReadTracks(t *testing.T) { @@ -901,23 +861,16 @@ func TestReadTracks(t *testing.T) {
case "missing metadata":
// C->S H264 decoder config
codec := nh264.Codec{
SPS: map[int][]byte{
0: sps,
},
PPS: map[int][]byte{
0: pps,
},
}
b := make([]byte, 128)
var n int
codec.ToConfig(b, &n)
buf, _ := h264conf.Conf{
SPS: sps,
PPS: pps,
}.Marshal()
err = mrw.Write(&message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 1,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: b[:n],
Payload: buf,
})
require.NoError(t, err)
@ -955,8 +908,8 @@ func TestWriteTracks(t *testing.T) { @@ -955,8 +908,8 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
rconn := NewServerConn(conn)
err = rconn.ServerHandshake()
rconn := NewConn(conn)
_, _, err = rconn.InitializeServer()
require.NoError(t, err)
videoTrack := &gortsplib.TrackH264{
@ -992,21 +945,21 @@ func TestWriteTracks(t *testing.T) { @@ -992,21 +945,21 @@ func TestWriteTracks(t *testing.T) {
defer conn.Close()
bc := bytecounter.NewReadWriter(conn)
err = handshake.DoClient(bc)
err = handshake.DoClient(bc, true)
require.NoError(t, err)
mrw := message.NewReadWriter(bc)
mrw := message.NewReadWriter(bc, true)
// C->S connect
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"connect",
1,
Name: "connect",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "app", V: "/stream"},
{K: "flashVer", V: "LNX 9,0,124,2"},
{K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")},
{K: "tcUrl", V: "rtmp://127.0.0.1:9121/stream"},
{K: "fpad", V: false},
{K: "capabilities", V: 15},
{K: "audioCodecs", V: 4071},
@ -1044,9 +997,9 @@ func TestWriteTracks(t *testing.T) { @@ -1044,9 +997,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(1),
Name: "_result",
CommandID: 1,
Arguments: []interface{}{
flvio.AMFMap{
{K: "fmsVer", V: "LNX 9,0,124,2"},
{K: "capabilities", V: float64(31)},
@ -1075,9 +1028,9 @@ func TestWriteTracks(t *testing.T) { @@ -1075,9 +1028,9 @@ func TestWriteTracks(t *testing.T) {
// C->S createStream
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"createStream",
float64(2),
Name: "createStream",
CommandID: 2,
Arguments: []interface{}{
nil,
},
})
@ -1088,9 +1041,9 @@ func TestWriteTracks(t *testing.T) { @@ -1088,9 +1041,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 3,
Payload: []interface{}{
"_result",
float64(2),
Name: "_result",
CommandID: 2,
Arguments: []interface{}{
nil,
float64(1),
},
@ -1099,9 +1052,9 @@ func TestWriteTracks(t *testing.T) { @@ -1099,9 +1052,9 @@ func TestWriteTracks(t *testing.T) {
// C->S getStreamLength
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
Payload: []interface{}{
"getStreamLength",
float64(3),
Name: "getStreamLength",
CommandID: 3,
Arguments: []interface{}{
nil,
"",
},
@ -1111,10 +1064,10 @@ func TestWriteTracks(t *testing.T) { @@ -1111,10 +1064,10 @@ func TestWriteTracks(t *testing.T) {
// C->S play
err = mrw.Write(&message.MsgCommandAMF0{
ChunkStreamID: 8,
MessageStreamID: 16777216,
Payload: []interface{}{
"play",
float64(4),
MessageStreamID: 0x1000000,
Name: "play",
CommandID: 4,
Arguments: []interface{}{
nil,
"",
float64(-2000),
@ -1141,10 +1094,10 @@ func TestWriteTracks(t *testing.T) { @@ -1141,10 +1094,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -1159,10 +1112,10 @@ func TestWriteTracks(t *testing.T) { @@ -1159,10 +1112,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -1177,10 +1130,10 @@ func TestWriteTracks(t *testing.T) { @@ -1177,10 +1130,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -1195,10 +1148,10 @@ func TestWriteTracks(t *testing.T) { @@ -1195,10 +1148,10 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgCommandAMF0{
ChunkStreamID: 5,
MessageStreamID: 16777216,
Payload: []interface{}{
"onStatus",
float64(4),
MessageStreamID: 0x1000000,
Name: "onStatus",
CommandID: 4,
Arguments: []interface{}{
nil,
flvio.AMFMap{
{K: "level", V: "status"},
@ -1213,8 +1166,9 @@ func TestWriteTracks(t *testing.T) { @@ -1213,8 +1166,9 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgDataAMF0{
ChunkStreamID: 4,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
Payload: []interface{}{
"@setDataFrame",
"onMetaData",
flvio.AMFMap{
{K: "videodatarate", V: float64(0)},
@ -1230,7 +1184,7 @@ func TestWriteTracks(t *testing.T) { @@ -1230,7 +1184,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgVideo{
ChunkStreamID: 6,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
Payload: []byte{
@ -1248,7 +1202,7 @@ func TestWriteTracks(t *testing.T) { @@ -1248,7 +1202,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &message.MsgAudio{
ChunkStreamID: 4,
MessageStreamID: 16777216,
MessageStreamID: 0x1000000,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,

89
internal/rtmp/h264conf/h264conf.go

@ -0,0 +1,89 @@ @@ -0,0 +1,89 @@
package h264conf
import (
"fmt"
)
// Conf is a RTMP H264 configuration.
type Conf struct {
SPS []byte
PPS []byte
}
// Unmarshal decodes a Conf from bytes.
func (c *Conf) Unmarshal(buf []byte) error {
if len(buf) < 8 {
return fmt.Errorf("invalid size 1")
}
pos := 5
spsCount := buf[pos] & 0x1F
pos++
if spsCount != 1 {
return fmt.Errorf("sps count != 1 is unsupported")
}
spsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < spsLen {
return fmt.Errorf("invalid size 2")
}
c.SPS = buf[pos : pos+spsLen]
pos += spsLen
if (len(buf) - pos) < 3 {
return fmt.Errorf("invalid size 3")
}
ppsCount := buf[pos]
pos++
if ppsCount != 1 {
return fmt.Errorf("pps count != 1 is unsupported")
}
ppsLen := int(uint16(buf[pos])<<8 | uint16(buf[pos+1]))
pos += 2
if (len(buf) - pos) < ppsLen {
return fmt.Errorf("invalid size")
}
c.PPS = buf[pos : pos+ppsLen]
return nil
}
// Marshal encodes a Conf into bytes.
func (c Conf) Marshal() ([]byte, error) {
spsLen := len(c.SPS)
ppsLen := len(c.PPS)
buf := make([]byte, 11+spsLen+ppsLen)
buf[0] = 1
buf[1] = c.SPS[1]
buf[2] = c.SPS[2]
buf[3] = c.SPS[3]
buf[4] = 3 | 0xFC
buf[5] = 1 | 0xE0
pos := 6
buf[pos] = byte(spsLen >> 8)
buf[pos+1] = byte(spsLen)
pos += 2
copy(buf[pos:], c.SPS)
pos += spsLen
buf[pos] = 1
pos++
buf[pos] = byte(ppsLen >> 8)
buf[pos+1] = byte(ppsLen)
pos += 2
copy(buf[pos:], c.PPS)
return buf, nil
}

29
internal/rtmp/h264conf/h264conf_test.go

@ -0,0 +1,29 @@ @@ -0,0 +1,29 @@
package h264conf
import (
"testing"
"github.com/stretchr/testify/require"
)
var decoded = Conf{
SPS: []byte{0x45, 0x32, 0xA3, 0x08},
PPS: []byte{0x45, 0x34},
}
var encoded = []byte{
0x1, 0x32, 0xa3, 0x8, 0xff, 0xe1, 0x0, 0x4, 0x45, 0x32, 0xa3, 0x8, 0x1, 0x0, 0x2, 0x45, 0x34,
}
func TestUnmarshal(t *testing.T) {
var dec Conf
err := dec.Unmarshal(encoded)
require.NoError(t, err)
require.Equal(t, decoded, dec)
}
func TestMarshal(t *testing.T) {
enc, err := decoded.Marshal()
require.NoError(t, err)
require.Equal(t, encoded, enc)
}

18
internal/rtmp/handshake/c1s1.go

@ -5,7 +5,6 @@ import ( @@ -5,7 +5,6 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
)
@ -78,14 +77,13 @@ type C1S1 struct { @@ -78,14 +77,13 @@ type C1S1 struct {
}
// Read reads a C1S1.
func (c *C1S1) Read(r io.Reader, isC1 bool) error {
func (c *C1S1) Read(r io.Reader, isC1 bool, validateSignature bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
// validate signature
var peerKey []byte
var key []byte
if isC1 {
@ -97,12 +95,15 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error { @@ -97,12 +95,15 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
}
ok, digest := hsParse1(buf, peerKey, key)
if !ok {
return fmt.Errorf("unable to validate C1/S1 signature")
if validateSignature {
return fmt.Errorf("unable to validate C1/S1 signature")
}
} else {
c.Digest = digest
}
c.Time = binary.BigEndian.Uint32(buf)
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Random = buf[8:]
c.Digest = digest
return nil
}
@ -111,7 +112,10 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error { @@ -111,7 +112,10 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
func (c *C1S1) Write(w io.Writer, isC1 bool) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
copy(buf[4:], []byte{0, 0, 0, 0})
if c.Random == nil {

2
internal/rtmp/handshake/c1s1_test.go

@ -89,7 +89,7 @@ func TestC1S1Read(t *testing.T) { @@ -89,7 +89,7 @@ func TestC1S1Read(t *testing.T) {
},
} {
var c1s1 C1S1
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1)
err := c1s1.Read((bytes.NewReader(ca.enc)), ca.isC1, true)
require.NoError(t, err)
require.Equal(t, ca.dec, c1s1)
}

29
internal/rtmp/handshake/c2s2.go

@ -3,7 +3,6 @@ package handshake @@ -3,7 +3,6 @@ package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
)
@ -17,22 +16,23 @@ type C2S2 struct { @@ -17,22 +16,23 @@ type C2S2 struct {
}
// Read reads a C2S2.
func (c *C2S2) Read(r io.Reader) error {
func (c *C2S2) Read(r io.Reader, validateSignature bool) error {
buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf)
if err != nil {
return err
}
// validate signature
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature")
if validateSignature {
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature")
}
}
c.Time = binary.BigEndian.Uint32(buf)
c.Time2 = binary.BigEndian.Uint32(buf[4:])
c.Time = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3])
c.Time2 = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])
c.Random = buf[8:]
return nil
@ -41,8 +41,15 @@ func (c *C2S2) Read(r io.Reader) error { @@ -41,8 +41,15 @@ func (c *C2S2) Read(r io.Reader) error {
// Write writes a C2S2.
func (c C2S2) Write(w io.Writer) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
binary.BigEndian.PutUint32(buf[4:], c.Time2)
buf[0] = byte(c.Time >> 24)
buf[1] = byte(c.Time >> 16)
buf[2] = byte(c.Time >> 8)
buf[3] = byte(c.Time)
buf[4] = byte(c.Time2 >> 24)
buf[5] = byte(c.Time2 >> 16)
buf[6] = byte(c.Time2 >> 8)
buf[7] = byte(c.Time2)
if c.Random == nil {
rand.Read(buf[8:])

2
internal/rtmp/handshake/c2s2_test.go

@ -42,7 +42,7 @@ func TestC2S2Read(t *testing.T) { @@ -42,7 +42,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read((bytes.NewReader(c2s2enc)))
err := c2s2.Read((bytes.NewReader(c2s2enc)), true)
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}

12
internal/rtmp/handshake/handshake.go

@ -5,7 +5,7 @@ import ( @@ -5,7 +5,7 @@ import (
)
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter) error {
func DoClient(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Write(rw)
if err != nil {
return err
@ -23,12 +23,12 @@ func DoClient(rw io.ReadWriter) error { @@ -23,12 +23,12 @@ func DoClient(rw io.ReadWriter) error {
}
s1 := C1S1{}
err = s1.Read(rw, false)
err = s1.Read(rw, false, validateSignature)
if err != nil {
return err
}
err = (&C2S2{Digest: c1.Digest}).Read(rw)
err = (&C2S2{Digest: c1.Digest}).Read(rw, validateSignature)
if err != nil {
return err
}
@ -42,14 +42,14 @@ func DoClient(rw io.ReadWriter) error { @@ -42,14 +42,14 @@ func DoClient(rw io.ReadWriter) error {
}
// DoServer performs a server-side handshake.
func DoServer(rw io.ReadWriter) error {
func DoServer(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Read(rw)
if err != nil {
return err
}
c1 := C1S1{}
err = c1.Read(rw, true)
err = c1.Read(rw, true, validateSignature)
if err != nil {
return err
}
@ -70,7 +70,7 @@ func DoServer(rw io.ReadWriter) error { @@ -70,7 +70,7 @@ func DoServer(rw io.ReadWriter) error {
return err
}
err = (&C2S2{Digest: s1.Digest}).Read(rw)
err = (&C2S2{Digest: s1.Digest}).Read(rw, validateSignature)
if err != nil {
return err
}

4
internal/rtmp/handshake/handshake_test.go

@ -19,7 +19,7 @@ func TestHandshake(t *testing.T) { @@ -19,7 +19,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
err = DoServer(conn)
err = DoServer(conn, true)
require.NoError(t, err)
close(done)
@ -29,7 +29,7 @@ func TestHandshake(t *testing.T) { @@ -29,7 +29,7 @@ func TestHandshake(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
err = DoClient(conn)
err = DoClient(conn, true)
require.NoError(t, err)
<-done

16
internal/rtmp/message/msg_acknowledge.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,18 +22,23 @@ func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error { @@ -23,18 +22,23 @@ func (m *MsgAcknowledge) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgAcknowledge) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeAcknowledge,
Body: body,
Body: buf,
}, nil
}

3
internal/rtmp/message/msg_audio.go

@ -2,6 +2,7 @@ package message @@ -2,6 +2,7 @@ package message
import (
"fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
@ -12,7 +13,7 @@ import ( @@ -12,7 +13,7 @@ import (
// MsgAudio is an audio message.
type MsgAudio struct {
ChunkStreamID byte
DTS uint32
DTS time.Duration
MessageStreamID uint32
Rate uint8
Depth uint8

30
internal/rtmp/message/msg_command_amf0.go

@ -1,6 +1,8 @@ @@ -1,6 +1,8 @@
package message
import (
"fmt"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -11,7 +13,9 @@ import ( @@ -11,7 +13,9 @@ import (
type MsgCommandAMF0 struct {
ChunkStreamID byte
MessageStreamID uint32
Payload []interface{}
Name string
CommandID int
Arguments []interface{}
}
// Unmarshal implements Message.
@ -23,7 +27,24 @@ func (m *MsgCommandAMF0) Unmarshal(raw *rawmessage.Message) error { @@ -23,7 +27,24 @@ func (m *MsgCommandAMF0) Unmarshal(raw *rawmessage.Message) error {
if err != nil {
return err
}
m.Payload = payload
if len(payload) < 3 {
return fmt.Errorf("invalid command payload")
}
var ok bool
m.Name, ok = payload[0].(string)
if !ok {
return fmt.Errorf("invalid command payload")
}
tmp, ok := payload[1].(float64)
if !ok {
return fmt.Errorf("invalid command payload")
}
m.CommandID = int(tmp)
m.Arguments = payload[2:]
return nil
}
@ -34,6 +55,9 @@ func (m MsgCommandAMF0) Marshal() (*rawmessage.Message, error) { @@ -34,6 +55,9 @@ func (m MsgCommandAMF0) Marshal() (*rawmessage.Message, error) {
ChunkStreamID: m.ChunkStreamID,
Type: chunk.MessageTypeCommandAMF0,
MessageStreamID: m.MessageStreamID,
Body: flvio.FillAMF0ValsMalloc(m.Payload),
Body: flvio.FillAMF0ValsMalloc(append([]interface{}{
m.Name,
float64(m.CommandID),
}, m.Arguments...)),
}, nil
}

16
internal/rtmp/message/msg_setchunksize.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,18 +22,23 @@ func (m *MsgSetChunkSize) Unmarshal(raw *rawmessage.Message) error { @@ -23,18 +22,23 @@ func (m *MsgSetChunkSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgSetChunkSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize,
Body: body,
Body: buf,
}, nil
}

20
internal/rtmp/message/msg_setpeerbandwidth.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -24,20 +23,25 @@ func (m *MsgSetPeerBandwidth) Unmarshal(raw *rawmessage.Message) error { @@ -24,20 +23,25 @@ func (m *MsgSetPeerBandwidth) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
m.Type = raw.Body[4]
return nil
}
// Marshal implements Message.
func (m *MsgSetPeerBandwidth) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 5)
binary.BigEndian.PutUint32(body, m.Value)
body[4] = m.Type
buf := make([]byte, 5)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
buf[4] = m.Type
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetChunkSize,
Body: body,
Type: chunk.MessageTypeSetPeerBandwidth,
Body: buf,
}, nil
}

16
internal/rtmp/message/msg_setwindowacksize.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,18 +22,23 @@ func (m *MsgSetWindowAckSize) Unmarshal(raw *rawmessage.Message) error { @@ -23,18 +22,23 @@ func (m *MsgSetWindowAckSize) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("unexpected body size")
}
m.Value = binary.BigEndian.Uint32(raw.Body)
m.Value = uint32(raw.Body[0])<<24 | uint32(raw.Body[1])<<16 | uint32(raw.Body[2])<<8 | uint32(raw.Body[3])
return nil
}
// Marshal implements Message.
func (m *MsgSetWindowAckSize) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body, m.Value)
buf := make([]byte, 4)
buf[0] = byte(m.Value >> 24)
buf[1] = byte(m.Value >> 16)
buf[2] = byte(m.Value >> 8)
buf[3] = byte(m.Value)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeSetWindowAckSize,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_pingrequest.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlPingRequest) Unmarshal(raw *rawmessage.Message) error { @@ -23,20 +22,25 @@ func (m *MsgUserControlPingRequest) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:])
m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlPingRequest) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingRequest)
binary.BigEndian.PutUint32(body[2:], m.ServerTime)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypePingRequest >> 8)
buf[1] = byte(UserControlTypePingRequest)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_pingresponse.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlPingResponse) Unmarshal(raw *rawmessage.Message) error { @@ -23,20 +22,25 @@ func (m *MsgUserControlPingResponse) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.ServerTime = binary.BigEndian.Uint32(raw.Body[2:])
m.ServerTime = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlPingResponse) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypePingResponse)
binary.BigEndian.PutUint32(body[2:], m.ServerTime)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypePingResponse >> 8)
buf[1] = byte(UserControlTypePingResponse)
buf[2] = byte(m.ServerTime >> 24)
buf[3] = byte(m.ServerTime >> 16)
buf[4] = byte(m.ServerTime >> 8)
buf[5] = byte(m.ServerTime)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

25
internal/rtmp/message/msg_usercontrol_setbufferlength.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -24,22 +23,30 @@ func (m *MsgUserControlSetBufferLength) Unmarshal(raw *rawmessage.Message) error @@ -24,22 +23,30 @@ func (m *MsgUserControlSetBufferLength) Unmarshal(raw *rawmessage.Message) error
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.BufferLength = binary.BigEndian.Uint32(raw.Body[6:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
m.BufferLength = uint32(raw.Body[6])<<24 | uint32(raw.Body[7])<<16 | uint32(raw.Body[8])<<8 | uint32(raw.Body[9])
return nil
}
// Marshal implements Message.
func (m MsgUserControlSetBufferLength) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 10)
binary.BigEndian.PutUint16(body, UserControlTypeSetBufferLength)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
binary.BigEndian.PutUint32(body[6:], m.BufferLength)
buf := make([]byte, 10)
buf[0] = byte(UserControlTypeSetBufferLength >> 8)
buf[1] = byte(UserControlTypeSetBufferLength)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
buf[6] = byte(m.BufferLength >> 24)
buf[7] = byte(m.BufferLength >> 16)
buf[8] = byte(m.BufferLength >> 8)
buf[9] = byte(m.BufferLength)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_streambegin.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlStreamBegin) Unmarshal(raw *rawmessage.Message) error { @@ -23,20 +22,25 @@ func (m *MsgUserControlStreamBegin) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamBegin) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamBegin)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamBegin >> 8)
buf[1] = byte(UserControlTypeStreamBegin)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_streamdry.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlStreamDry) Unmarshal(raw *rawmessage.Message) error { @@ -23,20 +22,25 @@ func (m *MsgUserControlStreamDry) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamDry) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamDry)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamDry >> 8)
buf[1] = byte(UserControlTypeStreamDry)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_streameof.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlStreamEOF) Unmarshal(raw *rawmessage.Message) error { @@ -23,20 +22,25 @@ func (m *MsgUserControlStreamEOF) Unmarshal(raw *rawmessage.Message) error {
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamEOF) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamEOF)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamEOF >> 8)
buf[1] = byte(UserControlTypeStreamEOF)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

18
internal/rtmp/message/msg_usercontrol_streamisrecorded.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
package message //nolint:dupl
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -23,20 +22,25 @@ func (m *MsgUserControlStreamIsRecorded) Unmarshal(raw *rawmessage.Message) erro @@ -23,20 +22,25 @@ func (m *MsgUserControlStreamIsRecorded) Unmarshal(raw *rawmessage.Message) erro
return fmt.Errorf("invalid body size")
}
m.StreamID = binary.BigEndian.Uint32(raw.Body[2:])
m.StreamID = uint32(raw.Body[2])<<24 | uint32(raw.Body[3])<<16 | uint32(raw.Body[4])<<8 | uint32(raw.Body[5])
return nil
}
// Marshal implements Message.
func (m MsgUserControlStreamIsRecorded) Marshal() (*rawmessage.Message, error) {
body := make([]byte, 6)
binary.BigEndian.PutUint16(body, UserControlTypeStreamIsRecorded)
binary.BigEndian.PutUint32(body[2:], m.StreamID)
buf := make([]byte, 6)
buf[0] = byte(UserControlTypeStreamIsRecorded >> 8)
buf[1] = byte(UserControlTypeStreamIsRecorded)
buf[2] = byte(m.StreamID >> 24)
buf[3] = byte(m.StreamID >> 16)
buf[4] = byte(m.StreamID >> 8)
buf[5] = byte(m.StreamID)
return &rawmessage.Message{
ChunkStreamID: ControlChunkStreamID,
Type: chunk.MessageTypeUserControl,
Body: body,
Body: buf,
}, nil
}

19
internal/rtmp/message/msg_video.go

@ -2,6 +2,7 @@ package message @@ -2,6 +2,7 @@ package message
import (
"fmt"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
@ -12,11 +13,11 @@ import ( @@ -12,11 +13,11 @@ import (
// MsgVideo is a video message.
type MsgVideo struct {
ChunkStreamID byte
DTS uint32
DTS time.Duration
MessageStreamID uint32
IsKeyFrame bool
H264Type uint8
PTSDelta uint32
PTSDelta time.Duration
Payload []byte
}
@ -38,7 +39,10 @@ func (m *MsgVideo) Unmarshal(raw *rawmessage.Message) error { @@ -38,7 +39,10 @@ func (m *MsgVideo) Unmarshal(raw *rawmessage.Message) error {
}
m.H264Type = raw.Body[1]
m.PTSDelta = uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
tmp := uint32(raw.Body[2])<<16 | uint32(raw.Body[3])<<8 | uint32(raw.Body[4])
m.PTSDelta = time.Duration(tmp) * time.Millisecond
m.Payload = raw.Body[5:]
return nil
@ -55,9 +59,12 @@ func (m MsgVideo) Marshal() (*rawmessage.Message, error) { @@ -55,9 +59,12 @@ func (m MsgVideo) Marshal() (*rawmessage.Message, error) {
}
body[0] |= flvio.VIDEO_H264
body[1] = m.H264Type
body[2] = uint8(m.PTSDelta >> 16)
body[3] = uint8(m.PTSDelta >> 8)
body[4] = uint8(m.PTSDelta)
tmp := uint32(m.PTSDelta / time.Millisecond)
body[2] = uint8(tmp >> 16)
body[3] = uint8(tmp >> 8)
body[4] = uint8(tmp)
copy(body[5:], m.Payload)
return &rawmessage.Message{

5
internal/rtmp/message/reader.go

@ -1,7 +1,6 @@ @@ -1,7 +1,6 @@
package message
import (
"encoding/binary"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
@ -28,7 +27,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) { @@ -28,7 +27,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return nil, fmt.Errorf("invalid body size")
}
subType := binary.BigEndian.Uint16(raw.Body)
subType := uint16(raw.Body[0])<<8 | uint16(raw.Body[1])
switch subType {
case UserControlTypeStreamBegin:
return &MsgUserControlStreamBegin{}, nil
@ -68,7 +67,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) { @@ -68,7 +67,7 @@ func allocateMessage(raw *rawmessage.Message) (Message, error) {
return &MsgVideo{}, nil
default:
return nil, fmt.Errorf("unhandled message")
return nil, fmt.Errorf("unhandled message type (%v)", raw.Type)
}
}

227
internal/rtmp/message/reader_test.go

@ -0,0 +1,227 @@ @@ -0,0 +1,227 @@
package message
import (
"bytes"
"testing"
"time"
"github.com/notedit/rtmp/format/flv/flvio"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
var readWriterCases = []struct {
name string
dec Message
enc []byte
}{
{
"acknowledge",
&MsgAcknowledge{
Value: 45953968,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x3,
0x0, 0x0, 0x0, 0x0, 0x2, 0xbd, 0x33, 0xb0,
},
},
{
"audio",
&MsgAudio{
ChunkStreamID: 7,
DTS: 6013806 * time.Millisecond,
MessageStreamID: 4534543,
Rate: flvio.SOUND_44Khz,
Depth: flvio.SOUND_16BIT,
Channels: flvio.SOUND_STEREO,
AACType: flvio.AAC_RAW,
Payload: []byte{0x5A, 0xC0, 0x77, 0x40},
},
[]byte{
0x7, 0x5b, 0xc3, 0x6e, 0x0, 0x0, 0x6, 0x8,
0x0, 0x45, 0x31, 0xf, 0xaf, 0x1, 0x5a, 0xc0,
0x77, 0x40,
},
},
{
"command amf0",
&MsgCommandAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Name: "i8yythrergre",
CommandID: 56456,
Arguments: []interface{}{
flvio.AMFMap{
{K: "k1", V: "v1"},
{K: "k2", V: "v2"},
},
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2f, 0x14,
0x0, 0x5, 0x44, 0x9b, 0x2, 0x0, 0xc, 0x69,
0x38, 0x79, 0x79, 0x74, 0x68, 0x72, 0x65, 0x72,
0x67, 0x72, 0x65, 0x0, 0x40, 0xeb, 0x91, 0x0,
0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x2, 0x6b,
0x31, 0x2, 0x0, 0x2, 0x76, 0x31, 0x0, 0x2,
0x6b, 0x32, 0x2, 0x0, 0x2, 0x76, 0x32, 0x0,
0x0, 0x9, 0x5,
},
},
{
"data amf0",
&MsgDataAMF0{
ChunkStreamID: 3,
MessageStreamID: 345243,
Payload: []interface{}{
float64(234),
"string",
nil,
},
},
[]byte{
0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x13, 0x12,
0x0, 0x5, 0x44, 0x9b, 0x0, 0x40, 0x6d, 0x40,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x6,
0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x05,
},
},
{
"set chunk size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set peer bandwidth",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"set window ack size",
&MsgSetChunkSize{
Value: 10000,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x1,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x27, 0x10,
},
},
{
"user control ping request",
&MsgUserControlPingRequest{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control ping response",
&MsgUserControlPingResponse{
ServerTime: 569834435,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x21, 0xf6,
0xfb, 0xc3,
},
},
{
"user control set buffer length",
&MsgUserControlSetBufferLength{
StreamID: 35534,
BufferLength: 235345,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0,
0x8a, 0xce, 0x0, 0x3, 0x97, 0x51,
},
},
{
"user control stream begin",
&MsgUserControlStreamBegin{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream dry",
&MsgUserControlStreamDry{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream eof",
&MsgUserControlStreamEOF{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"user control stream is recorded",
&MsgUserControlStreamIsRecorded{
StreamID: 35534,
},
[]byte{
0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x6, 0x4,
0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0,
0x8a, 0xce,
},
},
{
"video",
&MsgVideo{
ChunkStreamID: 6,
DTS: 2543534 * time.Millisecond,
MessageStreamID: 0x1000000,
IsKeyFrame: true,
H264Type: flvio.AVC_SEQHDR,
PTSDelta: 10 * time.Millisecond,
Payload: []byte{0x01, 0x02, 0x03},
},
[]byte{
0x6, 0x26, 0xcf, 0xae, 0x0, 0x0, 0x8, 0x9,
0x1, 0x0, 0x0, 0x0, 0x17, 0x0, 0x0, 0x0,
0xa, 0x1, 0x2, 0x3,
},
},
}
func TestReader(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
r := NewReader(bytecounter.NewReader(bytes.NewReader(ca.enc)), nil)
dec, err := r.Read()
require.NoError(t, err)
require.Equal(t, ca.dec, dec)
})
}
}

4
internal/rtmp/message/readwriter.go

@ -11,8 +11,8 @@ type ReadWriter struct { @@ -11,8 +11,8 @@ type ReadWriter struct {
}
// NewReadWriter allocates a ReadWriter.
func NewReadWriter(bc *bytecounter.ReadWriter) *ReadWriter {
w := NewWriter(bc.Writer)
func NewReadWriter(bc *bytecounter.ReadWriter, checkAcknowledge bool) *ReadWriter {
w := NewWriter(bc.Writer, checkAcknowledge)
r := NewReader(bc.Reader, func(count uint32) error {
return w.Write(&MsgAcknowledge{

4
internal/rtmp/message/writer.go

@ -11,9 +11,9 @@ type Writer struct { @@ -11,9 +11,9 @@ type Writer struct {
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer {
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{
w: rawmessage.NewWriter(w),
w: rawmessage.NewWriter(w, checkAcknowledge),
}
}

22
internal/rtmp/message/writer_test.go

@ -0,0 +1,22 @@ @@ -0,0 +1,22 @@
package message
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
)
func TestWriter(t *testing.T) {
for _, ca := range readWriterCases {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
r := NewWriter(bytecounter.NewWriter(&buf), true)
err := r.Write(ca.dec)
require.NoError(t, err)
require.Equal(t, ca.enc, buf.Bytes())
})
}
}

4
internal/rtmp/rawmessage/message.go

@ -1,13 +1,15 @@ @@ -1,13 +1,15 @@
package rawmessage
import (
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
// Message is a raw message.
type Message struct {
ChunkStreamID byte
Timestamp uint32
Timestamp time.Duration
Type chunk.MessageType
MessageStreamID uint32
Body []byte

15
internal/rtmp/rawmessage/reader.go

@ -3,6 +3,7 @@ package rawmessage @@ -3,6 +3,7 @@ package rawmessage
import (
"errors"
"fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -73,7 +74,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -73,7 +74,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
}
return &Message{
Timestamp: c0.Timestamp,
Timestamp: time.Duration(c0.Timestamp) * time.Millisecond,
Type: c0.Type,
MessageStreamID: c0.MessageStreamID,
Body: c0.Body,
@ -109,7 +110,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -109,7 +110,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
}
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: c1.Type,
MessageStreamID: *rc.curMessageStreamID,
Body: c1.Body,
@ -124,7 +125,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -124,7 +125,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
return nil, fmt.Errorf("received type 2 chunk but expected type 3 chunk")
}
chunkBodyLen := (*rc.curBodyLen)
chunkBodyLen := *rc.curBodyLen
if chunkBodyLen > rc.mr.chunkSize {
chunkBodyLen = rc.mr.chunkSize
}
@ -140,13 +141,13 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -140,13 +141,13 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
v2 := c2.TimestampDelta
rc.curTimestampDelta = &v2
if chunkBodyLen != uint32(len(c2.Body)) {
if *rc.curBodyLen != uint32(len(c2.Body)) {
rc.curBody = &c2.Body
return nil, errMoreChunksNeeded
}
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: c2.Body,
@ -179,7 +180,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -179,7 +180,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curBody = nil
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: body,
@ -201,7 +202,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) { @@ -201,7 +202,7 @@ func (rc *readerChunkStream) readMessage(typ byte) (*Message, error) {
rc.curTimestamp = &v1
return &Message{
Timestamp: *rc.curTimestamp,
Timestamp: time.Duration(*rc.curTimestamp) * time.Millisecond,
Type: *rc.curType,
MessageStreamID: *rc.curMessageStreamID,
Body: c3.Body,

280
internal/rtmp/rawmessage/reader_test.go

@ -3,6 +3,7 @@ package rawmessage @@ -3,6 +3,7 @@ package rawmessage
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/require"
@ -10,151 +11,174 @@ import ( @@ -10,151 +11,174 @@ import (
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
)
type sequenceEntry struct {
chunk chunk.Chunk
msg *Message
}
func TestReader(t *testing.T) {
testSequence := func(t *testing.T, seq []sequenceEntry) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
for _, entry := range seq {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
type sequenceEntry struct {
chunk chunk.Chunk
msg *Message
}
t.Run("chunk0 + chunk1", func(t *testing.T) {
testSequence(t, []sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
for _, ca := range []struct {
name string
sequence []sequenceEntry
}{
{
"chunk0 + chunk1",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
},
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
{
&chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetPeerBandwidth,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
},
},
{
&chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetPeerBandwidth,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
"chunk0 + chunk2 + chunk3",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
},
},
},
})
})
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
testSequence(t, []sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x02}, 64),
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x02}, 64),
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
},
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]sequenceEntry{
{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
},
nil,
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
},
},
},
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
{
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
nil,
},
&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
{
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&Message{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
},
})
})
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err := chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
bcr := bytecounter.NewReader(&buf)
r := NewReader(bcr, func(count uint32) error {
return nil
})
buf2, err = chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}.Marshal()
require.NoError(t, err)
buf.Write(buf2)
for _, entry := range ca.sequence {
buf2, err := entry.chunk.Marshal()
require.NoError(t, err)
buf.Write(buf2)
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, &Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
}, msg)
})
if entry.msg != nil {
msg, err := r.Read()
require.NoError(t, err)
require.Equal(t, entry.msg, msg)
}
}
})
}
}
func TestReaderAcknowledge(t *testing.T) {

42
internal/rtmp/rawmessage/writer.go

@ -2,6 +2,7 @@ package rawmessage @@ -2,6 +2,7 @@ package rawmessage
import (
"fmt"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -12,14 +13,14 @@ type writerChunkStream struct { @@ -12,14 +13,14 @@ type writerChunkStream struct {
lastMessageStreamID *uint32
lastType *chunk.MessageType
lastBodyLen *uint32
lastTimestamp *uint32
lastTimestampDelta *uint32
lastTimestamp *time.Duration
lastTimestampDelta *time.Duration
}
func (wc *writerChunkStream) writeChunk(c chunk.Chunk) error {
// check if we received an acknowledge
if wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - (wc.mw.ackValue)
if wc.mw.checkAcknowledge && wc.mw.ackWindowSize != 0 {
diff := wc.mw.w.Count() - wc.mw.ackValue
if diff > (wc.mw.ackWindowSize * 3 / 2) {
return fmt.Errorf("no acknowledge received within window")
@ -44,14 +45,13 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -44,14 +45,13 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
pos := uint32(0)
firstChunk := true
var timestampDelta *uint32
var timestampDelta *time.Duration
if wc.lastTimestamp != nil {
diff := int64(msg.Timestamp) - int64(*wc.lastTimestamp)
diff := msg.Timestamp - *wc.lastTimestamp
// use delta only if it is positive
if diff >= 0 {
v := uint32(diff)
timestampDelta = &v
timestampDelta = &diff
}
}
@ -68,7 +68,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -68,7 +68,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastMessageStreamID == nil || timestampDelta == nil || *wc.lastMessageStreamID != msg.MessageStreamID:
err := wc.writeChunk(&chunk.Chunk0{
ChunkStreamID: msg.ChunkStreamID,
Timestamp: msg.Timestamp,
Timestamp: uint32(msg.Timestamp / time.Millisecond),
Type: msg.Type,
MessageStreamID: msg.MessageStreamID,
BodyLen: (bodyLen),
@ -81,7 +81,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -81,7 +81,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case *wc.lastType != msg.Type || *wc.lastBodyLen != bodyLen:
err := wc.writeChunk(&chunk.Chunk1{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Type: msg.Type,
BodyLen: (bodyLen),
Body: msg.Body[pos : pos+chunkBodyLen],
@ -93,7 +93,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -93,7 +93,7 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
case wc.lastTimestampDelta == nil || *wc.lastTimestampDelta != *timestampDelta:
err := wc.writeChunk(&chunk.Chunk2{
ChunkStreamID: msg.ChunkStreamID,
TimestampDelta: *timestampDelta,
TimestampDelta: uint32(*timestampDelta / time.Millisecond),
Body: msg.Body[pos : pos+chunkBodyLen],
})
if err != nil {
@ -143,19 +143,21 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error { @@ -143,19 +143,21 @@ func (wc *writerChunkStream) writeMessage(msg *Message) error {
// Writer is a raw message writer.
type Writer struct {
w *bytecounter.Writer
chunkSize uint32
ackWindowSize uint32
ackValue uint32
chunkStreams map[byte]*writerChunkStream
w *bytecounter.Writer
checkAcknowledge bool
chunkSize uint32
ackWindowSize uint32
ackValue uint32
chunkStreams map[byte]*writerChunkStream
}
// NewWriter allocates a Writer.
func NewWriter(w *bytecounter.Writer) *Writer {
func NewWriter(w *bytecounter.Writer, checkAcknowledge bool) *Writer {
return &Writer{
w: w,
chunkSize: 128,
chunkStreams: make(map[byte]*writerChunkStream),
w: w,
checkAcknowledge: checkAcknowledge,
chunkSize: 128,
chunkStreams: make(map[byte]*writerChunkStream),
}
}

304
internal/rtmp/rawmessage/writer_test.go

@ -2,7 +2,9 @@ package rawmessage @@ -2,7 +2,9 @@ package rawmessage
import (
"bytes"
"reflect"
"testing"
"time"
"github.com/aler9/rtsp-simple-server/internal/rtmp/bytecounter"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
@ -10,146 +12,168 @@ import ( @@ -10,146 +12,168 @@ import (
)
func TestWriter(t *testing.T) {
t.Run("chunk0 + chunk1", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c1 chunk.Chunk1
err = c1.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetWindowAckSize,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c1)
})
t.Run("chunk0 + chunk2 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c0)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
})
require.NoError(t, err)
var c2 chunk.Chunk2
err = c2.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 64),
}, c2)
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576 + 15 + 15,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
})
require.NoError(t, err)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x05}, 64),
}, c3)
})
for _, ca := range []struct {
name string
messages []*Message
chunks []chunk.Chunk
chunkSizes []uint32
}{
{
"chunk0 + chunk1",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetWindowAckSize,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk1{
ChunkStreamID: 27,
TimestampDelta: 15,
Type: chunk.MessageTypeSetWindowAckSize,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
128,
},
},
{
"chunk0 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 64),
},
{
ChunkStreamID: 27,
Timestamp: (18576 + 15 + 15) * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 64,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 64),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x05}, 64),
},
},
[]uint32{
128,
64,
64,
},
},
{
"chunk0 + chunk3 + chunk2 + chunk3",
[]*Message{
{
ChunkStreamID: 27,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
},
{
ChunkStreamID: 27,
Timestamp: 18591 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x04}, 192),
},
},
[]chunk.Chunk{
&chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
},
&chunk.Chunk2{
ChunkStreamID: 27,
TimestampDelta: 15,
Body: bytes.Repeat([]byte{0x04}, 128),
},
&chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x04}, 64),
},
},
[]uint32{
128,
64,
128,
64,
},
},
} {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf), true)
t.Run("chunk0 + chunk3", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(bytecounter.NewWriter(&buf))
for _, msg := range ca.messages {
err := w.Write(msg)
require.NoError(t, err)
}
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 192),
for i, cach := range ca.chunks {
ch := reflect.New(reflect.TypeOf(cach).Elem()).Interface().(chunk.Chunk)
err := ch.Read(&buf, ca.chunkSizes[i])
require.NoError(t, err)
require.Equal(t, cach, ch)
}
})
require.NoError(t, err)
var c0 chunk.Chunk0
err = c0.Read(&buf, 128)
require.NoError(t, err)
require.Equal(t, chunk.Chunk0{
ChunkStreamID: 27,
Timestamp: 18576,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
BodyLen: 192,
Body: bytes.Repeat([]byte{0x03}, 128),
}, c0)
var c3 chunk.Chunk3
err = c3.Read(&buf, 64)
require.NoError(t, err)
require.Equal(t, chunk.Chunk3{
ChunkStreamID: 27,
Body: bytes.Repeat([]byte{0x03}, 64),
}, c3)
})
}
}
func TestWriterAcknowledge(t *testing.T) {
@ -157,7 +181,7 @@ func TestWriterAcknowledge(t *testing.T) { @@ -157,7 +181,7 @@ func TestWriterAcknowledge(t *testing.T) {
t.Run(ca, func(t *testing.T) {
var buf bytes.Buffer
bcw := bytecounter.NewWriter(&buf)
w := NewWriter(bcw)
w := NewWriter(bcw, true)
if ca == "overflow" {
bcw.SetCount(4294967096)
@ -169,7 +193,7 @@ func TestWriterAcknowledge(t *testing.T) { @@ -169,7 +193,7 @@ func TestWriterAcknowledge(t *testing.T) {
err := w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200),
@ -178,7 +202,7 @@ func TestWriterAcknowledge(t *testing.T) { @@ -178,7 +202,7 @@ func TestWriterAcknowledge(t *testing.T) {
err = w.Write(&Message{
ChunkStreamID: 27,
Timestamp: 18576,
Timestamp: 18576 * time.Millisecond,
Type: chunk.MessageTypeSetPeerBandwidth,
MessageStreamID: 3123,
Body: bytes.Repeat([]byte{0x03}, 200),

Loading…
Cancel
Save