Browse Source

rtmp: improve handshake

pull/1003/head
aler9 3 years ago
parent
commit
eb09c7c965
  1. 14
      internal/rtmp/conn_test.go
  2. 18
      internal/rtmp/handshake/c1s1.go
  3. 4
      internal/rtmp/handshake/c1s1_test.go
  4. 20
      internal/rtmp/handshake/c2s2.go
  5. 79
      internal/rtmp/handshake/c2s2_test.go

14
internal/rtmp/conn_test.go

@ -119,7 +119,8 @@ func TestReadTracks(t *testing.T) { @@ -119,7 +119,8 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// C->S handshake C1
err = handshake.C1S1{}.Write(conn, true)
c1 := handshake.C1S1{}
err = c1.Write(conn, true)
require.NoError(t, err)
// S->C handshake S0
@ -132,11 +133,11 @@ func TestReadTracks(t *testing.T) { @@ -132,11 +133,11 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{}).Read(conn)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{}.Write(conn, s1.Key)
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)
@ -478,7 +479,8 @@ func TestWriteTracks(t *testing.T) { @@ -478,7 +479,8 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// C-> handshake C1
err = handshake.C1S1{}.Write(conn, true)
c1 := handshake.C1S1{}
err = c1.Write(conn, true)
require.NoError(t, err)
// S->C handshake S0
@ -491,11 +493,11 @@ func TestWriteTracks(t *testing.T) { @@ -491,11 +493,11 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err)
// S->C handshake S2
err = (&handshake.C2S2{}).Read(conn)
err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn)
require.NoError(t, err)
// C->S handshake C2
err = handshake.C2S2{}.Write(conn, s1.Key)
err = handshake.C2S2{Digest: s1.Digest}.Write(conn)
require.NoError(t, err)
mw := message.NewWriter(conn)

18
internal/rtmp/handshake/c1s1.go

@ -60,23 +60,21 @@ func hsFindDigest(p []byte, key []byte, base int) int { @@ -60,23 +60,21 @@ func hsFindDigest(p []byte, key []byte, base int) int {
return gap
}
func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) {
func hsParse1(p []byte, peerkey []byte, key []byte) (bool, []byte) {
var pos int
if pos = hsFindDigest(p, peerkey, 772); pos == -1 {
if pos = hsFindDigest(p, peerkey, 8); pos == -1 {
return
return false, nil
}
}
ok = true
digest = hsMakeDigest(key, p[pos:pos+32], -1)
return
return true, hsMakeDigest(key, p[pos:pos+32], -1)
}
// C1S1 is a C1 or S1 packet.
type C1S1 struct {
Time uint32
Random []byte
Key []byte
Digest []byte
}
// Read reads a C1S1.
@ -97,20 +95,20 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error { @@ -97,20 +95,20 @@ func (c *C1S1) Read(r io.Reader, isC1 bool) error {
peerKey = hsServerPartialKey
key = hsClientFullKey
}
ok, key := hsParse1(buf, peerKey, key)
ok, digest := hsParse1(buf, peerKey, key)
if !ok {
return fmt.Errorf("unable to validate C1/S1 signature")
}
c.Time = binary.BigEndian.Uint32(buf)
c.Random = buf[8:]
c.Key = key
c.Digest = digest
return nil
}
// Write writes a C1S1.
func (c C1S1) Write(w io.Writer, isC1 bool) error {
func (c *C1S1) Write(w io.Writer, isC1 bool) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
@ -132,6 +130,8 @@ func (c C1S1) Write(w io.Writer, isC1 bool) error { @@ -132,6 +130,8 @@ func (c C1S1) Write(w io.Writer, isC1 bool) error {
}
digest := hsMakeDigest(key, buf, gap)
copy(buf[gap:], digest)
pos := hsFindDigest(buf, hsClientPartialKey, 8)
c.Digest = hsMakeDigest(hsServerFullKey, buf[pos:pos+32], -1)
_, err := w.Write(buf)
return err

4
internal/rtmp/handshake/c1s1_test.go

@ -21,7 +21,7 @@ func TestC1S1Read(t *testing.T) { @@ -21,7 +21,7 @@ func TestC1S1Read(t *testing.T) {
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 370)...,
),
Key: []byte{
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
@ -52,7 +52,7 @@ func TestC1S1Write(t *testing.T) { @@ -52,7 +52,7 @@ func TestC1S1Write(t *testing.T) {
c1s1dec := C1S1{
Time: 435234723,
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
Key: []byte{
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,

20
internal/rtmp/handshake/c2s2.go

@ -1,8 +1,10 @@ @@ -1,8 +1,10 @@
package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
)
@ -11,6 +13,7 @@ type C2S2 struct { @@ -11,6 +13,7 @@ type C2S2 struct {
Time uint32
Time2 uint32
Random []byte
Digest []byte
}
// Read reads a C2S2.
@ -21,6 +24,13 @@ func (c *C2S2) Read(r io.Reader) error { @@ -21,6 +24,13 @@ func (c *C2S2) Read(r io.Reader) error {
return err
}
// validate signature
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
if !bytes.Equal(buf[gap:gap+32], digest) {
return fmt.Errorf("unable to validate C2/S2 signature")
}
c.Time = binary.BigEndian.Uint32(buf)
c.Time2 = binary.BigEndian.Uint32(buf[4:])
c.Random = buf[8:]
@ -29,7 +39,7 @@ func (c *C2S2) Read(r io.Reader) error { @@ -29,7 +39,7 @@ func (c *C2S2) Read(r io.Reader) error {
}
// Write writes a C2S2.
func (c C2S2) Write(w io.Writer, key []byte) error {
func (c C2S2) Write(w io.Writer) error {
buf := make([]byte, 1536)
binary.BigEndian.PutUint32(buf, c.Time)
binary.BigEndian.PutUint32(buf[4:], c.Time2)
@ -41,9 +51,11 @@ func (c C2S2) Write(w io.Writer, key []byte) error { @@ -41,9 +51,11 @@ func (c C2S2) Write(w io.Writer, key []byte) error {
}
// signature
gap := len(buf) - 32
digest := hsMakeDigest(key, buf, gap)
copy(buf[gap:], digest)
if c.Digest != nil {
gap := len(buf) - 32
digest := hsMakeDigest(c.Digest, buf, gap)
copy(buf[gap:], digest)
}
_, err := w.Write(buf)
return err

79
internal/rtmp/handshake/c2s2_test.go

@ -0,0 +1,79 @@ @@ -0,0 +1,79 @@
package handshake
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestC2S2Read(t *testing.T) {
c2s2dec := C2S2{
Time: 435234723,
Time2: 7893542,
Random: append(
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 372),
[]byte{
0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04,
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
}
c2s2enc := append(append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
), []byte{
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...)
var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest
err := c2s2.Read(bytes.NewReader(c2s2enc))
require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2)
}
func TestC2S2Write(t *testing.T) {
c2s2dec := C2S2{
Time: 435234723,
Time2: 7893542,
Random: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 382),
Digest: []byte{
0x3f, 0xd0, 0xb1, 0xdf, 0xed, 0x6c, 0x9b, 0xc3,
0x73, 0x68, 0xe2, 0x47, 0x92, 0x59, 0x32, 0x9a,
0x3a, 0xc9, 0x1e, 0xeb, 0xfc, 0xad, 0x8e, 0x9d,
0x4e, 0xf4, 0x30, 0xb1, 0x9a, 0xc9, 0x15, 0x99,
},
}
c2s2enc := append(append(
[]byte{
0x19, 0xf1, 0x27, 0xa3, 0x00, 0x78, 0x72, 0x26,
},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 374)...,
), []byte{
0x96, 0x07, 0x2f, 0xe4, 0x04, 0xc5, 0x84, 0xa2,
0x21, 0x05, 0xcc, 0xb5, 0x7f, 0x93, 0x02, 0x14,
0xaf, 0xb0, 0x76, 0x75, 0xfd, 0x82, 0x29, 0xbe,
0xb9, 0x27, 0x9d, 0x4b, 0x0c, 0x81, 0x13, 0xec,
}...)
var buf bytes.Buffer
err := c2s2dec.Write(&buf)
require.NoError(t, err)
require.Equal(t, c2s2enc, buf.Bytes())
}
Loading…
Cancel
Save