Browse Source

rtmp: use bufio reader during handshake

pull/1003/head
aler9 3 years ago
parent
commit
cd19332350
  1. 19
      internal/rtmp/conn_test.go
  2. 3
      internal/rtmp/handshake/c0s0.go
  3. 3
      internal/rtmp/handshake/c0s0_test.go
  4. 3
      internal/rtmp/handshake/c1s1.go
  5. 3
      internal/rtmp/handshake/c1s1_test.go
  6. 3
      internal/rtmp/handshake/c2s2.go
  7. 3
      internal/rtmp/handshake/c2s2_test.go
  8. 4
      internal/rtmp/message/reader.go
  9. 5
      internal/rtmp/rawmessage/reader.go

19
internal/rtmp/conn_test.go

@ -1,6 +1,7 @@
package rtmp package rtmp
import ( import (
"bufio"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@ -113,6 +114,7 @@ func TestReadTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121") conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn)
// C->S handshake C0 // C->S handshake C0
err = handshake.C0S0{}.Write(conn) err = handshake.C0S0{}.Write(conn)
@ -124,16 +126,16 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S0 // S->C handshake S0
err = handshake.C0S0{}.Read(conn) err = handshake.C0S0{}.Read(br)
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S1 // S->C handshake S1
s1 := handshake.C1S1{} s1 := handshake.C1S1{}
err = s1.Read(conn, false) err = s1.Read(br, false)
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S2 // S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn) err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
require.NoError(t, err) require.NoError(t, err)
// C->S handshake C2 // C->S handshake C2
@ -141,7 +143,7 @@ func TestReadTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mw := message.NewWriter(conn) mw := message.NewWriter(conn)
mr := message.NewReader(conn) mr := message.NewReader(br)
// C->S connect // C->S connect
err = mw.Write(&message.MsgCommandAMF0{ err = mw.Write(&message.MsgCommandAMF0{
@ -473,6 +475,7 @@ func TestWriteTracks(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:9121") conn, err := net.Dial("tcp", "127.0.0.1:9121")
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
br := bufio.NewReader(conn)
// C->S handshake C0 // C->S handshake C0
err = handshake.C0S0{}.Write(conn) err = handshake.C0S0{}.Write(conn)
@ -484,16 +487,16 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S0 // S->C handshake S0
err = handshake.C0S0{}.Read(conn) err = handshake.C0S0{}.Read(br)
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S1 // S->C handshake S1
s1 := handshake.C1S1{} s1 := handshake.C1S1{}
err = s1.Read(conn, false) err = s1.Read(br, false)
require.NoError(t, err) require.NoError(t, err)
// S->C handshake S2 // S->C handshake S2
err = (&handshake.C2S2{Digest: c1.Digest}).Read(conn) err = (&handshake.C2S2{Digest: c1.Digest}).Read(br)
require.NoError(t, err) require.NoError(t, err)
// C->S handshake C2 // C->S handshake C2
@ -501,7 +504,7 @@ func TestWriteTracks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mw := message.NewWriter(conn) mw := message.NewWriter(conn)
mr := message.NewReader(conn) mr := message.NewReader(br)
// C->S connect // C->S connect
err = mw.Write(&message.MsgCommandAMF0{ err = mw.Write(&message.MsgCommandAMF0{

3
internal/rtmp/handshake/c0s0.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
) )
@ -13,7 +14,7 @@ const (
type C0S0 struct{} type C0S0 struct{}
// Read reads a C0S0. // Read reads a C0S0.
func (C0S0) Read(r io.Reader) error { func (C0S0) Read(r *bufio.Reader) error {
buf := make([]byte, 1) buf := make([]byte, 1)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {

3
internal/rtmp/handshake/c0s0_test.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@ -13,7 +14,7 @@ var c0s0dec = C0S0{}
func TestC0S0Read(t *testing.T) { func TestC0S0Read(t *testing.T) {
var c0s0 C0S0 var c0s0 C0S0
err := c0s0.Read(bytes.NewReader(c0s0enc)) err := c0s0.Read(bufio.NewReader(bytes.NewReader(c0s0enc)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c0s0dec, c0s0) require.Equal(t, c0s0dec, c0s0)
} }

3
internal/rtmp/handshake/c1s1.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
@ -78,7 +79,7 @@ type C1S1 struct {
} }
// Read reads a C1S1. // Read reads a C1S1.
func (c *C1S1) Read(r io.Reader, isC1 bool) error { func (c *C1S1) Read(r *bufio.Reader, isC1 bool) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {

3
internal/rtmp/handshake/c1s1_test.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@ -43,7 +44,7 @@ func TestC1S1Read(t *testing.T) {
) )
var c1s1 C1S1 var c1s1 C1S1
err := c1s1.Read(bytes.NewReader(c1s1enc), true) err := c1s1.Read(bufio.NewReader(bytes.NewReader(c1s1enc)), true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c1s1dec, c1s1) require.Equal(t, c1s1dec, c1s1)
} }

3
internal/rtmp/handshake/c2s2.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
@ -17,7 +18,7 @@ type C2S2 struct {
} }
// Read reads a C2S2. // Read reads a C2S2.
func (c *C2S2) Read(r io.Reader) error { func (c *C2S2) Read(r *bufio.Reader) error {
buf := make([]byte, 1536) buf := make([]byte, 1536)
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {

3
internal/rtmp/handshake/c2s2_test.go

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@ -42,7 +43,7 @@ func TestC2S2Read(t *testing.T) {
var c2s2 C2S2 var c2s2 C2S2
c2s2.Digest = c2s2dec.Digest c2s2.Digest = c2s2dec.Digest
err := c2s2.Read(bytes.NewReader(c2s2enc)) err := c2s2.Read(bufio.NewReader(bytes.NewReader(c2s2enc)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c2s2dec, c2s2) require.Equal(t, c2s2dec, c2s2)
} }

4
internal/rtmp/message/reader.go

@ -1,9 +1,9 @@
package message package message
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
"github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage" "github.com/aler9/rtsp-simple-server/internal/rtmp/rawmessage"
@ -75,7 +75,7 @@ type Reader struct {
} }
// NewReader allocates a Reader. // NewReader allocates a Reader.
func NewReader(r io.Reader) *Reader { func NewReader(r *bufio.Reader) *Reader {
return &Reader{ return &Reader{
r: rawmessage.NewReader(r), r: rawmessage.NewReader(r),
} }

5
internal/rtmp/rawmessage/reader.go

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
"io"
"github.com/aler9/rtsp-simple-server/internal/rtmp/chunk" "github.com/aler9/rtsp-simple-server/internal/rtmp/chunk"
) )
@ -169,9 +168,9 @@ type Reader struct {
} }
// NewReader allocates a Reader. // NewReader allocates a Reader.
func NewReader(r io.Reader) *Reader { func NewReader(r *bufio.Reader) *Reader {
return &Reader{ return &Reader{
r: bufio.NewReader(r), r: r,
chunkSize: 128, chunkSize: 128,
chunkStreams: make(map[byte]*readerChunkStream), chunkStreams: make(map[byte]*readerChunkStream),
} }

Loading…
Cancel
Save