diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index 0c1275f3..56f499c7 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -63,30 +63,7 @@ func TestClientHandshake(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - // C->S handshake C0 - err = handshake.C0S0{}.Read(bc) - require.NoError(t, err) - - // C->S handshake C1 - c1 := handshake.C1S1{} - err = c1.Read(bc, true) - require.NoError(t, err) - - // S->C handshake S0 - err = handshake.C0S0{}.Write(bc) - require.NoError(t, err) - - // S->C handshake S1 - s1 := handshake.C1S1{} - err = s1.Write(bc, false) - require.NoError(t, err) - - // S->C handshake S2 - err = handshake.C2S2{Digest: c1.Digest}.Write(bc) - require.NoError(t, err) - - // C->S handshake C2 - err = (&handshake.C2S2{Digest: s1.Digest}).Read(bc) + err = handshake.DoServer(bc) require.NoError(t, err) mrw := message.NewReadWriter(bc) @@ -345,30 +322,7 @@ func TestServerHandshake(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - // C->S handshake C0 - err = handshake.C0S0{}.Write(bc) - require.NoError(t, err) - - // C->S handshake C1 - c1 := handshake.C1S1{} - err = c1.Write(bc, true) - require.NoError(t, err) - - // S->C handshake S0 - err = handshake.C0S0{}.Read(bc) - require.NoError(t, err) - - // S->C handshake S1 - s1 := handshake.C1S1{} - err = s1.Read(bc, false) - require.NoError(t, err) - - // S->C handshake S2 - err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc) - require.NoError(t, err) - - // C->S handshake C2 - err = handshake.C2S2{Digest: s1.Digest}.Write(bc) + err = handshake.DoClient(bc) require.NoError(t, err) mrw := message.NewReadWriter(bc) @@ -656,30 +610,7 @@ func TestReadTracks(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - // C->S handshake C0 - err = handshake.C0S0{}.Write(bc) - require.NoError(t, err) - - // C->S handshake C1 - c1 := handshake.C1S1{} - err = c1.Write(bc, true) - require.NoError(t, err) - - // S->C handshake S0 - err = handshake.C0S0{}.Read(bc) - require.NoError(t, err) - - // S->C handshake S1 - s1 := handshake.C1S1{} - err = s1.Read(bc, false) - require.NoError(t, err) - - // S->C handshake S2 - err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc) - require.NoError(t, err) - - // C->S handshake C2 - err = handshake.C2S2{Digest: s1.Digest}.Write(bc) + err = handshake.DoClient(bc) require.NoError(t, err) mrw := message.NewReadWriter(bc) @@ -1061,30 +992,7 @@ func TestWriteTracks(t *testing.T) { defer conn.Close() bc := bytecounter.NewReadWriter(conn) - // C->S handshake C0 - err = handshake.C0S0{}.Write(bc) - require.NoError(t, err) - - // C-> handshake C1 - c1 := handshake.C1S1{} - err = c1.Write(bc, true) - require.NoError(t, err) - - // S->C handshake S0 - err = handshake.C0S0{}.Read(bc) - require.NoError(t, err) - - // S->C handshake S1 - s1 := handshake.C1S1{} - err = s1.Read(bc, false) - require.NoError(t, err) - - // S->C handshake S2 - err = (&handshake.C2S2{Digest: c1.Digest}).Read(bc) - require.NoError(t, err) - - // C->S handshake C2 - err = handshake.C2S2{Digest: s1.Digest}.Write(bc) + err = handshake.DoClient(bc) require.NoError(t, err) mrw := message.NewReadWriter(bc) diff --git a/internal/rtmp/handshake/handshake.go b/internal/rtmp/handshake/handshake.go new file mode 100644 index 00000000..94a1637e --- /dev/null +++ b/internal/rtmp/handshake/handshake.go @@ -0,0 +1,79 @@ +package handshake + +import ( + "io" +) + +// DoClient performs a client-side handshake. +func DoClient(rw io.ReadWriter) error { + err := C0S0{}.Write(rw) + if err != nil { + return err + } + + c1 := C1S1{} + err = c1.Write(rw, true) + if err != nil { + return err + } + + err = C0S0{}.Read(rw) + if err != nil { + return err + } + + s1 := C1S1{} + err = s1.Read(rw, false) + if err != nil { + return err + } + + err = (&C2S2{Digest: c1.Digest}).Read(rw) + if err != nil { + return err + } + + err = C2S2{Digest: s1.Digest}.Write(rw) + if err != nil { + return err + } + + return nil +} + +// DoServer performs a server-side handshake. +func DoServer(rw io.ReadWriter) error { + err := C0S0{}.Read(rw) + if err != nil { + return err + } + + c1 := C1S1{} + err = c1.Read(rw, true) + if err != nil { + return err + } + + err = C0S0{}.Write(rw) + if err != nil { + return err + } + + s1 := C1S1{} + err = s1.Write(rw, false) + if err != nil { + return err + } + + err = C2S2{Digest: c1.Digest}.Write(rw) + if err != nil { + return err + } + + err = (&C2S2{Digest: s1.Digest}).Read(rw) + if err != nil { + return err + } + + return nil +} diff --git a/internal/rtmp/handshake/handshake_test.go b/internal/rtmp/handshake/handshake_test.go new file mode 100644 index 00000000..113aa824 --- /dev/null +++ b/internal/rtmp/handshake/handshake_test.go @@ -0,0 +1,36 @@ +package handshake + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandshake(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:9122") + require.NoError(t, err) + defer ln.Close() + + done := make(chan struct{}) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + defer conn.Close() + + err = DoServer(conn) + require.NoError(t, err) + + close(done) + }() + + conn, err := net.Dial("tcp", "127.0.0.1:9122") + require.NoError(t, err) + defer conn.Close() + + err = DoClient(conn) + require.NoError(t, err) + + <-done +}