Browse Source

add --read-ips and --publish-ips arguments; fix #12

pull/31/head v0.6.5
aler9 6 years ago
parent
commit
fbd9f74c8b
  1. 2
      README.md
  2. 62
      main.go
  3. 2
      main_test.go
  4. 86
      server-client.go

2
README.md

@ -91,8 +91,10 @@ Flags:
--write-timeout=5s timeout of write operations --write-timeout=5s timeout of write operations
--publish-user="" optional username required to publish --publish-user="" optional username required to publish
--publish-pass="" optional password required to publish --publish-pass="" optional password required to publish
--publish-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can publish
--read-user="" optional username required to read --read-user="" optional username required to read
--read-pass="" optional password required to read --read-pass="" optional password required to read
--read-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can read
--pre-script="" optional script to run on client connect --pre-script="" optional script to run on client connect
--post-script="" optional script to run on client disconnect --post-script="" optional script to run on client disconnect
``` ```

62
main.go

@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"log" "log"
"net"
"os" "os"
"regexp" "regexp"
"strings" "strings"
@ -13,6 +14,30 @@ import (
var Version string = "v0.0.0" var Version string = "v0.0.0"
func parseIpCidrList(in string) ([]interface{}, error) {
if in == "" {
return nil, nil
}
var ret []interface{}
for _, t := range strings.Split(in, ",") {
_, ipnet, err := net.ParseCIDR(t)
if err == nil {
ret = append(ret, ipnet)
continue
}
ip := net.ParseIP(t)
if ip != nil {
ret = append(ret, ip)
continue
}
return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
}
return ret, nil
}
type trackFlow int type trackFlow int
const ( const (
@ -49,18 +74,22 @@ type args struct {
writeTimeout time.Duration writeTimeout time.Duration
publishUser string publishUser string
publishPass string publishPass string
publishIps string
readUser string readUser string
readPass string readPass string
readIps string
preScript string preScript string
postScript string postScript string
} }
type program struct { type program struct {
args args args args
protocols map[streamProtocol]struct{} protocols map[streamProtocol]struct{}
tcpl *serverTcpListener publishIps []interface{}
udplRtp *serverUdpListener readIps []interface{}
udplRtcp *serverUdpListener tcpl *serverTcpListener
udplRtp *serverUdpListener
udplRtcp *serverUdpListener
} }
func newProgram(sargs []string) (*program, error) { func newProgram(sargs []string) (*program, error) {
@ -76,8 +105,10 @@ func newProgram(sargs []string) (*program, error) {
argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration() argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration()
argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String() argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String()
argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String() argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String()
argPublishIps := kingpin.Flag("publish-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can publish").Default("").String()
argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String() argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String()
argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String() argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String()
argReadIps := kingpin.Flag("read-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can read").Default("").String()
argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String() argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String()
argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String() argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String()
@ -93,8 +124,10 @@ func newProgram(sargs []string) (*program, error) {
writeTimeout: *argWriteTimeout, writeTimeout: *argWriteTimeout,
publishUser: *argPublishUser, publishUser: *argPublishUser,
publishPass: *argPublishPass, publishPass: *argPublishPass,
publishIps: *argPublishIps,
readUser: *argReadUser, readUser: *argReadUser,
readPass: *argReadPass, readPass: *argReadPass,
readIps: *argReadIps,
preScript: *argPreScript, preScript: *argPreScript,
postScript: *argPostScript, postScript: *argPostScript,
} }
@ -120,12 +153,14 @@ func newProgram(sargs []string) (*program, error) {
if len(protocols) == 0 { if len(protocols) == 0 {
return nil, fmt.Errorf("no protocols provided") return nil, fmt.Errorf("no protocols provided")
} }
if (args.rtpPort % 2) != 0 { if (args.rtpPort % 2) != 0 {
return nil, fmt.Errorf("rtp port must be even") return nil, fmt.Errorf("rtp port must be even")
} }
if args.rtcpPort != (args.rtpPort + 1) { if args.rtcpPort != (args.rtpPort + 1) {
return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") return nil, fmt.Errorf("rtcp and rtp ports must be consecutive")
} }
if args.publishUser != "" { if args.publishUser != "" {
if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) { if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) {
return nil, fmt.Errorf("publish username must be alphanumeric") return nil, fmt.Errorf("publish username must be alphanumeric")
@ -136,6 +171,11 @@ func newProgram(sargs []string) (*program, error) {
return nil, fmt.Errorf("publish password must be alphanumeric") return nil, fmt.Errorf("publish password must be alphanumeric")
} }
} }
publishIps, err := parseIpCidrList(args.publishIps)
if err != nil {
return nil, err
}
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled") return nil, fmt.Errorf("read username and password must be both filled")
} }
@ -152,16 +192,20 @@ func newProgram(sargs []string) (*program, error) {
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled") return nil, fmt.Errorf("read username and password must be both filled")
} }
readIps, err := parseIpCidrList(args.readIps)
if err != nil {
return nil, err
}
log.Printf("rtsp-simple-server %s", Version) log.Printf("rtsp-simple-server %s", Version)
p := &program{ p := &program{
args: args, args: args,
protocols: protocols, protocols: protocols,
publishIps: publishIps,
readIps: readIps,
} }
var err error
p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP) p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
if err != nil { if err != nil {
return nil, err return nil, err

2
main_test.go

@ -142,6 +142,7 @@ func TestPublishAuth(t *testing.T) {
p, err := newProgram([]string{ p, err := newProgram([]string{
"--publish-user=testuser", "--publish-user=testuser",
"--publish-pass=testpass", "--publish-pass=testpass",
"--publish-ips=172.17.0.0/16",
}) })
require.NoError(t, err) require.NoError(t, err)
defer p.close() defer p.close()
@ -185,6 +186,7 @@ func TestReadAuth(t *testing.T) {
p, err := newProgram([]string{ p, err := newProgram([]string{
"--read-user=testuser", "--read-user=testuser",
"--read-pass=testpass", "--read-pass=testpass",
"--read-ips=172.17.0.0/16",
}) })
require.NoError(t, err) require.NoError(t, err)
defer p.close() defer p.close()

86
server-client.go

@ -202,36 +202,70 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
var errAuthCritical = errors.New("auth critical") var errAuthCritical = errors.New("auth critical")
var errAuthNotCritical = errors.New("auth not critical") var errAuthNotCritical = errors.New("auth not critical")
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer) error { func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer, ips []interface{}) error {
if user == "" { err := func() error {
return nil if ips == nil {
} return nil
}
initialRequest := false connIp := c.conn.NetConn().LocalAddr().(*net.TCPAddr).IP
if *auth == nil {
initialRequest = true
*auth = gortsplib.NewAuthServer(user, pass, nil)
}
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url) for _, item := range ips {
if err != nil { switch titem := item.(type) {
if !initialRequest { case net.IP:
c.log("ERR: Unauthorized: %s", err) if titem.Equal(connIp) {
return nil
}
case *net.IPNet:
if titem.Contains(connIp) {
return nil
}
}
} }
c.conn.WriteResponse(&gortsplib.Response{ c.log("ERR: ip '%s' not allowed", connIp)
StatusCode: gortsplib.StatusUnauthorized, return errAuthCritical
Header: gortsplib.Header{ }()
"CSeq": []string{req.Header["CSeq"][0]}, if err != nil {
"WWW-Authenticate": (*auth).GenerateHeader(), return err
}, }
})
err = func() error {
if user == "" {
return nil
}
if !initialRequest { initialRequest := false
return errAuthCritical if *auth == nil {
initialRequest = true
*auth = gortsplib.NewAuthServer(user, pass, nil)
} }
return errAuthNotCritical err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
if err != nil {
if !initialRequest {
c.log("ERR: unauthorized: %s", err)
}
c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusUnauthorized,
Header: gortsplib.Header{
"CSeq": []string{req.Header["CSeq"][0]},
"WWW-Authenticate": (*auth).GenerateHeader(),
},
})
if !initialRequest {
return errAuthCritical
}
return errAuthNotCritical
}
return nil
}()
if err != nil {
return err
} }
return nil return nil
@ -291,7 +325,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth) err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false
@ -333,7 +367,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth) err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth, c.p.publishIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false
@ -405,7 +439,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
switch c.state { switch c.state {
// play // play
case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY: case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY:
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth) err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false

Loading…
Cancel
Save