From bf1f45df324b3aa6df3f004abf534b7f1022fe3e Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 9 Jul 2022 16:19:49 +0200 Subject: [PATCH] rtmp: add conn handshake tests --- internal/core/rtmp_server_test.go | 4 +- internal/core/rtmp_source.go | 2 +- internal/rtmp/conn.go | 10 +- internal/rtmp/conn_test.go | 369 +++++++++++++++++++++++++++++- 4 files changed, 379 insertions(+), 6 deletions(-) diff --git a/internal/core/rtmp_server_test.go b/internal/core/rtmp_server_test.go index da56e0fd..433c90fd 100644 --- a/internal/core/rtmp_server_test.go +++ b/internal/core/rtmp_server_test.go @@ -139,7 +139,7 @@ func TestRTMPServerAuth(t *testing.T) { require.NoError(t, err) defer conn.Close() - err = conn.ClientHandshake() + err = conn.ClientHandshake(true) require.NoError(t, err) _, _, err = conn.ReadTracks() @@ -223,7 +223,7 @@ func TestRTMPServerAuthFail(t *testing.T) { require.NoError(t, err) defer conn.Close() - err = conn.ClientHandshake() + err = conn.ClientHandshake(true) require.Equal(t, err, io.EOF) }) } diff --git a/internal/core/rtmp_source.go b/internal/core/rtmp_source.go index 2e2b4197..5f540586 100644 --- a/internal/core/rtmp_source.go +++ b/internal/core/rtmp_source.go @@ -117,7 +117,7 @@ func (s *rtmpSource) runInner() bool { readDone <- func() error { conn.SetReadDeadline(time.Now().Add(time.Duration(s.readTimeout))) conn.SetWriteDeadline(time.Now().Add(time.Duration(s.writeTimeout))) - err = conn.ClientHandshake() + err = conn.ClientHandshake(true) if err != nil { return err } diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index a130af90..235fa838 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -35,8 +35,14 @@ func (c *Conn) Close() error { } // ClientHandshake performs the handshake of a client-side connection. -func (c *Conn) ClientHandshake() error { - return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, rtmp.PrepareReading) +func (c *Conn) ClientHandshake(isPlaying bool) error { + var flag int + if isPlaying { + flag = rtmp.PrepareReading + } else { + flag = rtmp.PrepareWriting + } + return c.rconn.Prepare(rtmp.StageGotPublishOrPlayCommand, flag) } // ServerHandshake performs the handshake of a server-side connection. diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index 3813b953..a86ac5a5 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -1,6 +1,7 @@ package rtmp import ( + "context" "net" "net/url" "strings" @@ -48,6 +49,371 @@ func getTcURL(u string) string { return nu.String() + app } +func TestClientHandshake(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:9121") + require.NoError(t, err) + defer ln.Close() + + done := make(chan struct{}) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + defer conn.Close() + bc := bytecounter.NewReadWriter(conn) + + // C->S handshake C0 + err = handshake.C0S0{}.Read(bc) + require.NoError(t, err) + + // C->S handshake C1 + c1 := handshake.C1S1{} + err = c1.Read(bc, true) + require.NoError(t, err) + + // S->C handshake S0 + err = handshake.C0S0{}.Write(bc) + require.NoError(t, err) + + // S->C handshake S1 + s1 := handshake.C1S1{} + err = s1.Write(bc, false) + require.NoError(t, err) + + // S->C handshake S2 + err = handshake.C2S2{Digest: c1.Digest}.Write(bc) + require.NoError(t, err) + + // C->S handshake C2 + err = (&handshake.C2S2{Digest: s1.Digest}).Read(bc) + require.NoError(t, err) + + mrw := message.NewReadWriter(bc) + + // C->S set window ack size + msg, err := mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetWindowAckSize{ + Value: 2500000, + }, msg) + + // C->S set peer bandwidth + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetPeerBandwidth{ + Value: 0x2625a0, + Type: 2, + }, msg) + + // C->S set chunk size + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetChunkSize{ + Value: 65536, + }, msg) + + // C->S connect + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "connect", + float64(1), + flvio.AMFMap{ + {K: "app", V: "stream"}, + {K: "flashVer", V: "LNX 9,0,124,2"}, + {K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, + {K: "fpad", V: false}, + {K: "capabilities", V: float64(15)}, + {K: "audioCodecs", V: float64(4071)}, + {K: "videoCodecs", V: float64(252)}, + {K: "videoFunction", V: float64(1)}, + }, + }, + }, msg) + + // S->C result + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "_result", + float64(1), + flvio.AMFMap{ + {K: "fmsVer", V: "LNX 9,0,124,2"}, + {K: "capabilities", V: float64(31)}, + }, + flvio.AMFMap{ + {K: "level", V: "status"}, + {K: "code", V: "NetConnection.Connect.Success"}, + {K: "description", V: "Connection succeeded."}, + {K: "objectEncoding", V: float64(0)}, + }, + }, + }) + require.NoError(t, err) + + // C->S create stream + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "createStream", + float64(2), + nil, + }, + }, msg) + + // S->C result + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "_result", + float64(2), + nil, + float64(1), + }, + }) + require.NoError(t, err) + + // C->S user control set buffer length + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgUserControlSetBufferLength{ + BufferLength: 0x64, + }, msg) + + // C->S play + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgCommandAMF0{ + ChunkStreamID: 4, + MessageStreamID: 16777216, + Payload: []interface{}{ + "play", + float64(0), + nil, + "", + }, + }, msg) + + // S->C onStatus + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 5, + MessageStreamID: 16777216, + Payload: []interface{}{ + "onStatus", + float64(4), + nil, + flvio.AMFMap{ + {K: "level", V: "status"}, + {K: "code", V: "NetStream.Play.Reset"}, + {K: "description", V: "play reset"}, + }, + }, + }) + require.NoError(t, err) + + close(done) + }() + + conn, err := DialContext(context.Background(), "rtmp://127.0.0.1:9121/stream") + require.NoError(t, err) + defer conn.Close() + + err = conn.ClientHandshake(true) + require.NoError(t, err) + + <-done +} + +func TestServerHandshake(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:9121") + require.NoError(t, err) + defer ln.Close() + + done := make(chan struct{}) + + go func() { + nconn, err := ln.Accept() + require.NoError(t, err) + defer nconn.Close() + + conn := NewServerConn(nconn) + err = conn.ServerHandshake() + require.NoError(t, err) + + close(done) + }() + + conn, err := net.Dial("tcp", "127.0.0.1:9121") + require.NoError(t, err) + defer conn.Close() + bc := bytecounter.NewReadWriter(conn) + + // C->S handshake C0 + err = handshake.C0S0{}.Write(bc) + require.NoError(t, err) + + // C->S handshake C1 + c1 := handshake.C1S1{} + err = c1.Write(bc, true) + require.NoError(t, err) + + // S->C handshake S0 + err = handshake.C0S0{}.Read(bc) + require.NoError(t, err) + + // S->C handshake S1 + s1 := handshake.C1S1{} + err = s1.Read(bc, false) + require.NoError(t, err) + + // S->C handshake S2 + err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc) + require.NoError(t, err) + + // C->S handshake C2 + err = handshake.C2S2{Digest: s1.Digest}.Write(bc) + require.NoError(t, err) + + mrw := message.NewReadWriter(bc) + + // C->S connect + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "connect", + 1, + flvio.AMFMap{ + {K: "app", V: "/stream"}, + {K: "flashVer", V: "LNX 9,0,124,2"}, + {K: "tcUrl", V: getTcURL("rtmp://127.0.0.1:9121/stream")}, + {K: "fpad", V: false}, + {K: "capabilities", V: 15}, + {K: "audioCodecs", V: 4071}, + {K: "videoCodecs", V: 252}, + {K: "videoFunction", V: 1}, + }, + }, + }) + require.NoError(t, err) + + // S->C window acknowledgement size + msg, err := mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetWindowAckSize{ + Value: 2500000, + }, msg) + + // S->C set peer bandwidth + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetPeerBandwidth{ + Value: 2500000, + Type: 2, + }, msg) + + // S->C set chunk size + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgSetChunkSize{ + Value: 65536, + }, msg) + + // S->C result + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "_result", + float64(1), + flvio.AMFMap{ + {K: "fmsVer", V: "LNX 9,0,124,2"}, + {K: "capabilities", V: float64(31)}, + }, + flvio.AMFMap{ + {K: "level", V: "status"}, + {K: "code", V: "NetConnection.Connect.Success"}, + {K: "description", V: "Connection succeeded."}, + {K: "objectEncoding", V: float64(0)}, + }, + }, + }, msg) + + // C->S set chunk size + err = mrw.Write(&message.MsgSetChunkSize{ + Value: 65536, + }) + require.NoError(t, err) + + // C->S releaseStream + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "releaseStream", + float64(2), + nil, + "", + }, + }) + require.NoError(t, err) + + // C->S FCPublish + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "FCPublish", + float64(3), + nil, + "", + }, + }) + require.NoError(t, err) + + // C->S createStream + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "createStream", + float64(4), + nil, + }, + }) + require.NoError(t, err) + + // S->C result + msg, err = mrw.Read() + require.NoError(t, err) + require.Equal(t, &message.MsgCommandAMF0{ + ChunkStreamID: 3, + Payload: []interface{}{ + "_result", + float64(4), + nil, + float64(1), + }, + }, msg) + + // C->S publish + err = mrw.Write(&message.MsgCommandAMF0{ + ChunkStreamID: 8, + MessageStreamID: 1, + Payload: []interface{}{ + "publish", + float64(5), + nil, + "", + "live", + }, + }) + require.NoError(t, err) + + <-done +} + func TestReadTracks(t *testing.T) { sps := []byte{ 0x67, 0x64, 0x00, 0x0c, 0xac, 0x3b, 0x50, 0xb0, @@ -696,7 +1062,8 @@ func TestWriteTracks(t *testing.T) { // C->S play err = mrw.Write(&message.MsgCommandAMF0{ - ChunkStreamID: 8, + ChunkStreamID: 8, + MessageStreamID: 16777216, Payload: []interface{}{ "play", float64(4),