Browse Source

improve tests

pull/1405/head
aler9 2 years ago
parent
commit
97c1e68c0b
  1. 1
      internal/rtmp/handshake/c1s1.go
  2. 27
      internal/rtmp/handshake/handshake.go
  3. 54
      internal/rtmp/handshake/handshake_test.go

1
internal/rtmp/handshake/c1s1.go

@ -120,6 +120,7 @@ func (c *C1S1) Write(w io.Writer, isC1 bool) error { @@ -120,6 +120,7 @@ func (c *C1S1) Write(w io.Writer, isC1 bool) error {
if c.Random == nil {
rand.Read(buf[8:])
c.Random = buf[8:]
} else {
copy(buf[8:], c.Random)
}

27
internal/rtmp/handshake/handshake.go

@ -7,7 +7,8 @@ import ( @@ -7,7 +7,8 @@ import (
// DoClient performs a client-side handshake.
func DoClient(rw io.ReadWriter, validateSignature bool) error {
err := C0S0{}.Write(rw)
c0 := C0S0{}
err := c0.Write(rw)
if err != nil {
return err
}
@ -18,7 +19,8 @@ func DoClient(rw io.ReadWriter, validateSignature bool) error { @@ -18,7 +19,8 @@ func DoClient(rw io.ReadWriter, validateSignature bool) error {
return err
}
err = C0S0{}.Read(rw)
s0 := C0S0{}
err = s0.Read(rw)
if err != nil {
return err
}
@ -29,18 +31,20 @@ func DoClient(rw io.ReadWriter, validateSignature bool) error { @@ -29,18 +31,20 @@ func DoClient(rw io.ReadWriter, validateSignature bool) error {
return err
}
err = (&C2S2{
s2 := C2S2{
Digest: c1.Digest,
}).Read(rw, validateSignature)
}
err = s2.Read(rw, validateSignature)
if err != nil {
return err
}
err = C2S2{
c2 := C2S2{
Time: s1.Time,
Random: s1.Random,
Digest: s1.Digest,
}.Write(rw)
}
err = c2.Write(rw)
if err != nil {
return err
}
@ -61,7 +65,8 @@ func DoServer(rw io.ReadWriter, validateSignature bool) error { @@ -61,7 +65,8 @@ func DoServer(rw io.ReadWriter, validateSignature bool) error {
return err
}
err = C0S0{}.Write(rw)
s0 := C0S0{}
err = s0.Write(rw)
if err != nil {
return err
}
@ -72,16 +77,18 @@ func DoServer(rw io.ReadWriter, validateSignature bool) error { @@ -72,16 +77,18 @@ func DoServer(rw io.ReadWriter, validateSignature bool) error {
return err
}
err = C2S2{
s2 := C2S2{
Time: c1.Time,
Random: c1.Random,
Digest: c1.Digest,
}.Write(rw)
}
err = s2.Write(rw)
if err != nil {
return err
}
err = (&C2S2{Digest: s1.Digest}).Read(rw, validateSignature)
c2 := C2S2{Digest: s1.Digest}
err = c2.Read(rw, validateSignature)
if err != nil {
return err
}

54
internal/rtmp/handshake/handshake_test.go

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
package handshake
import (
"math/rand"
"net"
"testing"
@ -34,3 +35,56 @@ func TestHandshake(t *testing.T) { @@ -34,3 +35,56 @@ func TestHandshake(t *testing.T) {
<-done
}
// when C1 signature is invalid, S2 must be equal to C1.
func TestHandshakeFallback(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer ln.Close()
done := make(chan struct{})
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()
err = DoServer(conn, false)
require.NoError(t, err)
close(done)
}()
conn, err := net.Dial("tcp", "127.0.0.1:9122")
require.NoError(t, err)
defer conn.Close()
err = C0S0{}.Write(conn)
require.NoError(t, err)
c1 := make([]byte, 1536)
rand.Read(c1[8:])
_, err = conn.Write(c1)
require.NoError(t, err)
err = C0S0{}.Read(conn)
require.NoError(t, err)
s1 := C1S1{}
err = s1.Read(conn, false, false)
require.NoError(t, err)
s2 := C2S2{}
err = s2.Read(conn, false)
require.NoError(t, err)
require.Equal(t, c1[8:], s2.Random)
err = C2S2{
Time: s1.Time,
Random: s1.Random,
Digest: s1.Digest,
}.Write(conn)
require.NoError(t, err)
<-done
}

Loading…
Cancel
Save