From b3eaec50c1a1b13bbc3f5e5a30ade5d4d25c20df Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Thu, 18 Jan 2024 23:23:51 +0100 Subject: [PATCH] srt: support standard streamID syntax (#2469) (#2919) --- internal/servers/srt/conn.go | 65 ++++++----------- internal/servers/srt/streamid.go | 100 ++++++++++++++++++++++++++ internal/servers/srt/streamid_test.go | 61 ++++++++++++++++ 3 files changed, 183 insertions(+), 43 deletions(-) create mode 100644 internal/servers/srt/streamid.go create mode 100644 internal/servers/srt/streamid_test.go diff --git a/internal/servers/srt/conn.go b/internal/servers/srt/conn.go index 9c4dd9b3..1719bf83 100644 --- a/internal/servers/srt/conn.go +++ b/internal/servers/srt/conn.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net" - "strings" "sync" "time" @@ -151,50 +150,30 @@ func (c *conn) runInner() error { } func (c *conn) runInner2(req srtNewConnReq) (bool, error) { - parts := strings.Split(req.connReq.StreamId(), ":") - if (len(parts) < 2 || len(parts) > 5) || (parts[0] != "read" && parts[0] != "publish") { - return false, fmt.Errorf("invalid streamid '%s':"+ - " it must be 'action:pathname[:query]' or 'action:pathname:user:pass[:query]', "+ - "where action is either read or publish, pathname is the path name, user and pass are the credentials, "+ - "query is an optional token containing additional information", - req.connReq.StreamId()) - } - - pathName := parts[1] - user := "" - pass := "" - query := "" - - if len(parts) == 4 || len(parts) == 5 { - user, pass = parts[2], parts[3] - } - - if len(parts) == 3 { - query = parts[2] - } - - if len(parts) == 5 { - query = parts[4] + var streamID streamID + err := streamID.unmarshal(req.connReq.StreamId()) + if err != nil { + return false, fmt.Errorf("invalid stream ID '%s': %w", req.connReq.StreamId(), err) } - if parts[0] == "publish" { - return c.runPublish(req, pathName, user, pass, query) + if streamID.mode == streamIDModePublish { + return c.runPublish(req, &streamID) } - return c.runRead(req, pathName, user, pass, query) + return c.runRead(req, &streamID) } -func (c *conn) runPublish(req srtNewConnReq, pathName string, user string, pass string, query string) (bool, error) { +func (c *conn) runPublish(req srtNewConnReq, streamID *streamID) (bool, error) { res := c.pathManager.AddPublisher(defs.PathAddPublisherReq{ Author: c, AccessRequest: defs.PathAccessRequest{ - Name: pathName, + Name: streamID.path, IP: c.ip(), Publish: true, - User: user, - Pass: pass, + User: streamID.user, + Pass: streamID.pass, Proto: defs.AuthProtocolSRT, ID: &c.uuid, - Query: query, + Query: streamID.query, }, }) @@ -222,8 +201,8 @@ func (c *conn) runPublish(req srtNewConnReq, pathName string, user string, pass c.mutex.Lock() c.state = connStatePublish - c.pathName = pathName - c.query = query + c.pathName = streamID.path + c.query = streamID.query c.sconn = sconn c.mutex.Unlock() @@ -283,17 +262,17 @@ func (c *conn) runPublishReader(sconn srt.Conn, path defs.Path) error { } } -func (c *conn) runRead(req srtNewConnReq, pathName string, user string, pass string, query string) (bool, error) { +func (c *conn) runRead(req srtNewConnReq, streamID *streamID) (bool, error) { res := c.pathManager.AddReader(defs.PathAddReaderReq{ Author: c, AccessRequest: defs.PathAccessRequest{ - Name: pathName, + Name: streamID.path, IP: c.ip(), - User: user, - Pass: pass, + User: streamID.user, + Pass: streamID.pass, Proto: defs.AuthProtocolSRT, ID: &c.uuid, - Query: query, + Query: streamID.query, }, }) @@ -322,8 +301,8 @@ func (c *conn) runRead(req srtNewConnReq, pathName string, user string, pass str c.mutex.Lock() c.state = connStateRead - c.pathName = pathName - c.query = query + c.pathName = streamID.path + c.query = streamID.query c.sconn = sconn c.mutex.Unlock() @@ -347,7 +326,7 @@ func (c *conn) runRead(req srtNewConnReq, pathName string, user string, pass str Conf: res.Path.SafeConf(), ExternalCmdEnv: res.Path.ExternalCmdEnv(), Reader: c.APIReaderDescribe(), - Query: query, + Query: streamID.query, }) defer onUnreadHook() diff --git a/internal/servers/srt/streamid.go b/internal/servers/srt/streamid.go new file mode 100644 index 00000000..fcf535e8 --- /dev/null +++ b/internal/servers/srt/streamid.go @@ -0,0 +1,100 @@ +package srt + +import ( + "fmt" + "strings" +) + +type streamIDMode int + +const ( + streamIDModeRead streamIDMode = iota + streamIDModePublish +) + +type streamID struct { + mode streamIDMode + path string + query string + user string + pass string +} + +func (s *streamID) unmarshal(raw string) error { + // standard syntax + // https://github.com/Haivision/srt/blob/master/docs/features/access-control.md + if strings.HasPrefix(raw, "#!::") { + for _, kv := range strings.Split(raw[len("#!::"):], ",") { + kv2 := strings.SplitN(kv, "=", 2) + if len(kv2) != 2 { + return fmt.Errorf("invalid value") + } + + key, value := kv2[0], kv2[1] + + switch key { + case "u": + s.user = value + + case "r": + s.path = value + + case "h": + + case "s": + s.pass = value + + case "t": + + case "m": + switch value { + case "request": + s.mode = streamIDModeRead + + case "publish": + s.mode = streamIDModePublish + + default: + return fmt.Errorf("unsupported mode '%s'", value) + } + + default: + return fmt.Errorf("unsupported key '%s'", key) + } + } + } else { + parts := strings.Split(raw, ":") + if len(parts) < 2 || len(parts) > 5 { + return fmt.Errorf("stream ID must be 'action:pathname[:query]' or 'action:pathname:user:pass[:query]', " + + "where action is either read or publish, pathname is the path name, user and pass are the credentials, " + + "query is an optional token containing additional information") + } + + switch parts[0] { + case "read": + s.mode = streamIDModeRead + + case "publish": + s.mode = streamIDModePublish + + default: + return fmt.Errorf("stream ID must be 'action:pathname[:query]' or 'action:pathname:user:pass[:query]', " + + "where action is either read or publish, pathname is the path name, user and pass are the credentials, " + + "query is an optional token containing additional information") + } + + s.path = parts[1] + + if len(parts) == 4 || len(parts) == 5 { + s.user, s.pass = parts[2], parts[3] + } + + if len(parts) == 3 { + s.query = parts[2] + } else if len(parts) == 5 { + s.query = parts[4] + } + } + + return nil +} diff --git a/internal/servers/srt/streamid_test.go b/internal/servers/srt/streamid_test.go new file mode 100644 index 00000000..30c3a4c6 --- /dev/null +++ b/internal/servers/srt/streamid_test.go @@ -0,0 +1,61 @@ +package srt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStreamIDUnmarshal(t *testing.T) { + for _, ca := range []struct { + name string + raw string + dec streamID + }{ + { + "mediamtx syntax 1", + "read:mypath", + streamID{ + mode: streamIDModeRead, + path: "mypath", + }, + }, + { + "mediamtx syntax 2", + "publish:mypath:myquery", + streamID{ + mode: streamIDModePublish, + path: "mypath", + query: "myquery", + }, + }, + { + "mediamtx syntax 3", + "read:mypath:myuser:mypass:myquery", + streamID{ + mode: streamIDModeRead, + path: "mypath", + user: "myuser", + pass: "mypass", + query: "myquery", + }, + }, + { + "standard syntax", + "#!::u=johnny,t=file,m=publish,r=results.csv,s=mypass,h=myhost.com", + streamID{ + mode: streamIDModePublish, + path: "results.csv", + user: "johnny", + pass: "mypass", + }, + }, + } { + t.Run(ca.name, func(t *testing.T) { + var streamID streamID + err := streamID.unmarshal(ca.raw) + require.NoError(t, err) + require.Equal(t, ca.dec, streamID) + }) + } +}