Browse Source

add --read-ips and --publish-ips arguments; fix #12

pull/31/head v0.6.5
aler9 6 years ago
parent
commit
fbd9f74c8b
  1. 2
      README.md
  2. 48
      main.go
  3. 2
      main_test.go
  4. 44
      server-client.go

2
README.md

@ -91,8 +91,10 @@ Flags:
--write-timeout=5s timeout of write operations --write-timeout=5s timeout of write operations
--publish-user="" optional username required to publish --publish-user="" optional username required to publish
--publish-pass="" optional password required to publish --publish-pass="" optional password required to publish
--publish-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can publish
--read-user="" optional username required to read --read-user="" optional username required to read
--read-pass="" optional password required to read --read-pass="" optional password required to read
--read-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can read
--pre-script="" optional script to run on client connect --pre-script="" optional script to run on client connect
--post-script="" optional script to run on client disconnect --post-script="" optional script to run on client disconnect
``` ```

48
main.go

@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"log" "log"
"net"
"os" "os"
"regexp" "regexp"
"strings" "strings"
@ -13,6 +14,30 @@ import (
var Version string = "v0.0.0" var Version string = "v0.0.0"
func parseIpCidrList(in string) ([]interface{}, error) {
if in == "" {
return nil, nil
}
var ret []interface{}
for _, t := range strings.Split(in, ",") {
_, ipnet, err := net.ParseCIDR(t)
if err == nil {
ret = append(ret, ipnet)
continue
}
ip := net.ParseIP(t)
if ip != nil {
ret = append(ret, ip)
continue
}
return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
}
return ret, nil
}
type trackFlow int type trackFlow int
const ( const (
@ -49,8 +74,10 @@ type args struct {
writeTimeout time.Duration writeTimeout time.Duration
publishUser string publishUser string
publishPass string publishPass string
publishIps string
readUser string readUser string
readPass string readPass string
readIps string
preScript string preScript string
postScript string postScript string
} }
@ -58,6 +85,8 @@ type args struct {
type program struct { type program struct {
args args args args
protocols map[streamProtocol]struct{} protocols map[streamProtocol]struct{}
publishIps []interface{}
readIps []interface{}
tcpl *serverTcpListener tcpl *serverTcpListener
udplRtp *serverUdpListener udplRtp *serverUdpListener
udplRtcp *serverUdpListener udplRtcp *serverUdpListener
@ -76,8 +105,10 @@ func newProgram(sargs []string) (*program, error) {
argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration() argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration()
argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String() argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String()
argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String() argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String()
argPublishIps := kingpin.Flag("publish-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can publish").Default("").String()
argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String() argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String()
argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String() argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String()
argReadIps := kingpin.Flag("read-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can read").Default("").String()
argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String() argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String()
argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String() argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String()
@ -93,8 +124,10 @@ func newProgram(sargs []string) (*program, error) {
writeTimeout: *argWriteTimeout, writeTimeout: *argWriteTimeout,
publishUser: *argPublishUser, publishUser: *argPublishUser,
publishPass: *argPublishPass, publishPass: *argPublishPass,
publishIps: *argPublishIps,
readUser: *argReadUser, readUser: *argReadUser,
readPass: *argReadPass, readPass: *argReadPass,
readIps: *argReadIps,
preScript: *argPreScript, preScript: *argPreScript,
postScript: *argPostScript, postScript: *argPostScript,
} }
@ -120,12 +153,14 @@ func newProgram(sargs []string) (*program, error) {
if len(protocols) == 0 { if len(protocols) == 0 {
return nil, fmt.Errorf("no protocols provided") return nil, fmt.Errorf("no protocols provided")
} }
if (args.rtpPort % 2) != 0 { if (args.rtpPort % 2) != 0 {
return nil, fmt.Errorf("rtp port must be even") return nil, fmt.Errorf("rtp port must be even")
} }
if args.rtcpPort != (args.rtpPort + 1) { if args.rtcpPort != (args.rtpPort + 1) {
return nil, fmt.Errorf("rtcp and rtp ports must be consecutive") return nil, fmt.Errorf("rtcp and rtp ports must be consecutive")
} }
if args.publishUser != "" { if args.publishUser != "" {
if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) { if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) {
return nil, fmt.Errorf("publish username must be alphanumeric") return nil, fmt.Errorf("publish username must be alphanumeric")
@ -136,6 +171,11 @@ func newProgram(sargs []string) (*program, error) {
return nil, fmt.Errorf("publish password must be alphanumeric") return nil, fmt.Errorf("publish password must be alphanumeric")
} }
} }
publishIps, err := parseIpCidrList(args.publishIps)
if err != nil {
return nil, err
}
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled") return nil, fmt.Errorf("read username and password must be both filled")
} }
@ -152,16 +192,20 @@ func newProgram(sargs []string) (*program, error) {
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" { if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled") return nil, fmt.Errorf("read username and password must be both filled")
} }
readIps, err := parseIpCidrList(args.readIps)
if err != nil {
return nil, err
}
log.Printf("rtsp-simple-server %s", Version) log.Printf("rtsp-simple-server %s", Version)
p := &program{ p := &program{
args: args, args: args,
protocols: protocols, protocols: protocols,
publishIps: publishIps,
readIps: readIps,
} }
var err error
p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP) p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
if err != nil { if err != nil {
return nil, err return nil, err

2
main_test.go

@ -142,6 +142,7 @@ func TestPublishAuth(t *testing.T) {
p, err := newProgram([]string{ p, err := newProgram([]string{
"--publish-user=testuser", "--publish-user=testuser",
"--publish-pass=testpass", "--publish-pass=testpass",
"--publish-ips=172.17.0.0/16",
}) })
require.NoError(t, err) require.NoError(t, err)
defer p.close() defer p.close()
@ -185,6 +186,7 @@ func TestReadAuth(t *testing.T) {
p, err := newProgram([]string{ p, err := newProgram([]string{
"--read-user=testuser", "--read-user=testuser",
"--read-pass=testpass", "--read-pass=testpass",
"--read-ips=172.17.0.0/16",
}) })
require.NoError(t, err) require.NoError(t, err)
defer p.close() defer p.close()

44
server-client.go

@ -202,7 +202,36 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
var errAuthCritical = errors.New("auth critical") var errAuthCritical = errors.New("auth critical")
var errAuthNotCritical = errors.New("auth not critical") var errAuthNotCritical = errors.New("auth not critical")
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer) error { func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer, ips []interface{}) error {
err := func() error {
if ips == nil {
return nil
}
connIp := c.conn.NetConn().LocalAddr().(*net.TCPAddr).IP
for _, item := range ips {
switch titem := item.(type) {
case net.IP:
if titem.Equal(connIp) {
return nil
}
case *net.IPNet:
if titem.Contains(connIp) {
return nil
}
}
}
c.log("ERR: ip '%s' not allowed", connIp)
return errAuthCritical
}()
if err != nil {
return err
}
err = func() error {
if user == "" { if user == "" {
return nil return nil
} }
@ -216,7 +245,7 @@ func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass st
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url) err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
if err != nil { if err != nil {
if !initialRequest { if !initialRequest {
c.log("ERR: Unauthorized: %s", err) c.log("ERR: unauthorized: %s", err)
} }
c.conn.WriteResponse(&gortsplib.Response{ c.conn.WriteResponse(&gortsplib.Response{
@ -233,6 +262,11 @@ func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass st
return errAuthNotCritical return errAuthNotCritical
} }
return nil
}()
if err != nil {
return err
}
return nil return nil
} }
@ -291,7 +325,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth) err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false
@ -333,7 +367,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false return false
} }
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth) err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth, c.p.publishIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false
@ -405,7 +439,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
switch c.state { switch c.state {
// play // play
case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY: case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY:
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth) err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil { if err != nil {
if err == errAuthCritical { if err == errAuthCritical {
return false return false

Loading…
Cancel
Save