From 27ae0b981219721b542f6609aec17c770e8cef3f Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 22 Aug 2022 11:20:23 +0200 Subject: [PATCH] rtmp client: validate command ID of results --- internal/rtmp/conn.go | 14 +++++++------- internal/rtmp/conn_test.go | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/rtmp/conn.go b/internal/rtmp/conn.go index e747e171..1b47a7c2 100644 --- a/internal/rtmp/conn.go +++ b/internal/rtmp/conn.go @@ -132,7 +132,7 @@ func (c *Conn) readCommand() (*message.MsgCommandAMF0, error) { } } -func (c *Conn) readCommandResult(commandName string, isValid func(*message.MsgCommandAMF0) bool) error { +func (c *Conn) readCommandResult(commandID int, commandName string, isValid func(*message.MsgCommandAMF0) bool) error { for { msg, err := c.mrw.Read() if err != nil { @@ -140,7 +140,7 @@ func (c *Conn) readCommandResult(commandName string, isValid func(*message.MsgCo } if cmd, ok := msg.(*message.MsgCommandAMF0); ok { - if cmd.Name == commandName { + if cmd.CommandID == commandID && cmd.Name == commandName { if !isValid(cmd) { return fmt.Errorf("server refused connect request") } @@ -203,7 +203,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult("_result", resultIsOK1) + err = c.readCommandResult(1, "_result", resultIsOK1) if err != nil { return err } @@ -221,7 +221,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult("_result", resultIsOK2) + err = c.readCommandResult(2, "_result", resultIsOK2) if err != nil { return err } @@ -247,7 +247,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - return c.readCommandResult("onStatus", resultIsOK1) + return c.readCommandResult(3, "onStatus", resultIsOK1) } err = c.mrw.Write(&message.MsgCommandAMF0{ @@ -288,7 +288,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - err = c.readCommandResult("_result", resultIsOK2) + err = c.readCommandResult(4, "_result", resultIsOK2) if err != nil { return err } @@ -308,7 +308,7 @@ func (c *Conn) InitializeClient(u *url.URL, isPublishing bool) error { return err } - return c.readCommandResult("onStatus", resultIsOK1) + return c.readCommandResult(5, "onStatus", resultIsOK1) } // InitializeServer performs the initialization of a server-side connection. diff --git a/internal/rtmp/conn_test.go b/internal/rtmp/conn_test.go index 27f02a28..f6823299 100644 --- a/internal/rtmp/conn_test.go +++ b/internal/rtmp/conn_test.go @@ -141,7 +141,7 @@ func TestInitializeClient(t *testing.T) { ChunkStreamID: 5, MessageStreamID: 0x1000000, Name: "onStatus", - CommandID: 4, + CommandID: 3, Arguments: []interface{}{ nil, flvio.AMFMap{