Browse Source

cleanup code

pull/2/head
aler9 6 years ago
parent
commit
77918272df
  1. 2
      README.md
  2. 403
      client.go

2
README.md

@ -13,7 +13,7 @@ Features:
* Publish multiple streams at once, each in a separate path, that can be read by multiple users * Publish multiple streams at once, each in a separate path, that can be read by multiple users
* Each stream can have multiple video and audio tracks * Each stream can have multiple video and audio tracks
* Supports the RTP/RTCP streaming protocol * Supports the RTP/RTCP streaming protocol
* Optional authentication schema for publishers * Optional publisher authentication
* Compatible with Linux and Windows, does not require any dependency or interpreter, it's a single executable * Compatible with Linux and Windows, does not require any dependency or interpreter, it's a single executable

403
client.go

@ -1,7 +1,6 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -15,13 +14,6 @@ import (
"gortc.io/sdp" "gortc.io/sdp"
) )
var (
errTeardown = errors.New("teardown")
errPlay = errors.New("play")
errRecord = errors.New("record")
errWrongKey = errors.New("wrong key")
)
func interleavedChannelToTrack(channel int) (int, trackFlow) { func interleavedChannelToTrack(channel int) (int, trackFlow) {
if (channel % 2) == 0 { if (channel % 2) == 0 {
return (channel / 2), _TRACK_FLOW_RTP return (channel / 2), _TRACK_FLOW_RTP
@ -158,152 +150,49 @@ func (c *client) run() {
return return
} }
c.log(req.Method) ok := c.handleRequest(req)
if !ok {
res, err := c.handleRequest(req)
switch err {
// normal response
case nil:
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
}
// TEARDOWN: close connection silently
case errTeardown:
return return
}
}
}
// PLAY: first write response, then set state func (c *client) writeRes(res *rtsp.Response) {
// otherwise, in case of TCP connections, RTP packets could be written c.rconn.WriteResponse(res)
// before the response }
// then switch to RTP if TCP
case errPlay:
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
}
c.log("is receiving on path %s, %d %s via %s", c.path, len(c.streamTracks), func() string {
if len(c.streamTracks) == 1 {
return "track"
}
return "tracks"
}(), c.streamProtocol)
c.p.mutex.Lock()
c.state = "PLAY"
c.p.mutex.Unlock()
// when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP feedback, do not parse it, wait until connection closes
if c.streamProtocol == _STREAM_PROTOCOL_TCP {
buf := make([]byte, 2048)
for {
_, err := c.rconn.Read(buf)
if err != nil {
if err != io.EOF {
c.log("ERR: %s", err)
}
return
}
}
}
// RECORD: switch to RTP if TCP
case errRecord:
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
}
c.p.mutex.Lock()
c.state = "RECORD"
c.p.mutex.Unlock()
c.log("is publishing on path %s, %d %s via %s", c.path, len(c.streamTracks), func() string {
if len(c.streamTracks) == 1 {
return "track"
}
return "tracks"
}(), c.streamProtocol)
// when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP data and parse it
if c.streamProtocol == _STREAM_PROTOCOL_TCP {
buf := make([]byte, 2048)
for {
channel, n, err := c.rconn.ReadInterleavedFrame(buf)
if err != nil {
if _, ok := err.(*net.OpError); ok {
} else if err == io.EOF {
} else {
c.log("ERR: %s", err)
}
return
}
trackId, trackFlow := interleavedChannelToTrack(channel)
if trackId >= len(c.streamTracks) {
c.log("ERR: invalid track id '%d'", trackId)
return
}
c.p.mutex.RLock()
c.p.forwardTrack(c.path, trackId, trackFlow, buf[:n])
c.p.mutex.RUnlock()
}
}
// wrong key: reply with 401 and exit
case errWrongKey:
c.log("ERR: %s", err)
c.rconn.WriteResponse(&rtsp.Response{
StatusCode: 401,
Status: "Unauthorized",
Headers: map[string]string{
"CSeq": req.Headers["CSeq"],
},
})
return
// generic error: reply with code 400 and exit func (c *client) writeResError(req *rtsp.Request, err error) {
default: c.log("ERR: %s", err)
c.log("ERR: %s", err)
if cseq, ok := req.Headers["CSeq"]; ok { if cseq, ok := req.Headers["CSeq"]; ok {
c.rconn.WriteResponse(&rtsp.Response{ c.rconn.WriteResponse(&rtsp.Response{
StatusCode: 400, StatusCode: 400,
Status: "Bad Request", Status: "Bad Request",
Headers: map[string]string{ Headers: map[string]string{
"CSeq": cseq, "CSeq": cseq,
}, },
}) })
} else { } else {
c.rconn.WriteResponse(&rtsp.Response{ c.rconn.WriteResponse(&rtsp.Response{
StatusCode: 400, StatusCode: 400,
Status: "Bad Request", Status: "Bad Request",
}) })
}
return
}
} }
} }
func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) { func (c *client) handleRequest(req *rtsp.Request) bool {
c.log(req.Method)
cseq, ok := req.Headers["CSeq"] cseq, ok := req.Headers["CSeq"]
if !ok { if !ok {
return nil, fmt.Errorf("cseq missing") c.writeResError(req, fmt.Errorf("cseq missing"))
return false
} }
ur, err := url.Parse(req.Url) ur, err := url.Parse(req.Url)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse path '%s'", req.Url) c.writeResError(req, fmt.Errorf("unable to parse path '%s'", req.Url))
return false
} }
path := func() string { path := func() string {
@ -327,7 +216,7 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
// do not check state, since OPTIONS can be requested // do not check state, since OPTIONS can be requested
// in any state // in any state
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -342,11 +231,13 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
"TEARDOWN", "TEARDOWN",
}, ", "), }, ", "),
}, },
}, nil })
return true
case "DESCRIBE": case "DESCRIBE":
if c.state != "STARTING" { if c.state != "STARTING" {
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
sdp, err := func() ([]byte, error) { sdp, err := func() ([]byte, error) {
@ -361,10 +252,11 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return pub.streamSdpText, nil return pub.streamSdpText, nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -373,20 +265,24 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
"Content-Type": "application/sdp", "Content-Type": "application/sdp",
}, },
Content: sdp, Content: sdp,
}, nil })
return true
case "ANNOUNCE": case "ANNOUNCE":
if c.state != "STARTING" { if c.state != "STARTING" {
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
ct, ok := req.Headers["Content-Type"] ct, ok := req.Headers["Content-Type"]
if !ok { if !ok {
return nil, fmt.Errorf("Content-Type header missing") c.writeResError(req, fmt.Errorf("Content-Type header missing"))
return false
} }
if ct != "application/sdp" { if ct != "application/sdp" {
return nil, fmt.Errorf("unsupported Content-Type '%s'", ct) c.writeResError(req, fmt.Errorf("unsupported Content-Type '%s'", ct))
return false
} }
sdpParsed, err := func() (*sdp.Message, error) { sdpParsed, err := func() (*sdp.Message, error) {
@ -405,22 +301,33 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return m, nil return m, nil
}() }()
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid SDP: %s", err) c.writeResError(req, fmt.Errorf("invalid SDP: %s", err))
return false
} }
if c.p.publishKey != "" { if c.p.publishKey != "" {
q, err := url.ParseQuery(ur.RawQuery) q, err := url.ParseQuery(ur.RawQuery)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse query") c.writeResError(req, fmt.Errorf("unable to parse query"))
return false
} }
key, ok := q["key"] key, ok := q["key"]
if !ok || len(key) == 0 { if !ok || len(key) == 0 {
return nil, fmt.Errorf("key missing") c.writeResError(req, fmt.Errorf("key missing"))
return false
} }
if key[0] != c.p.publishKey { if key[0] != c.p.publishKey {
return nil, errWrongKey // reply with 401 and exit
c.writeRes(&rtsp.Response{
StatusCode: 401,
Status: "Unauthorized",
Headers: map[string]string{
"CSeq": req.Headers["CSeq"],
},
})
return false
} }
} }
@ -441,27 +348,31 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
"CSeq": cseq, "CSeq": cseq,
}, },
}, nil })
return true
case "SETUP": case "SETUP":
transportstr, ok := req.Headers["Transport"] transportstr, ok := req.Headers["Transport"]
if !ok { if !ok {
return nil, fmt.Errorf("transport header missing") c.writeResError(req, fmt.Errorf("transport header missing"))
return false
} }
th := newTransportHeader(transportstr) th := newTransportHeader(transportstr)
if _, ok := th["unicast"]; !ok { if _, ok := th["unicast"]; !ok {
return nil, fmt.Errorf("transport header does not contain unicast") c.writeResError(req, fmt.Errorf("transport header does not contain unicast"))
return false
} }
switch c.state { switch c.state {
@ -471,11 +382,13 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
if _, ok := th["RTP/AVP"]; ok { if _, ok := th["RTP/AVP"]; ok {
rtpPort, rtcpPort := th.getClientPorts() rtpPort, rtcpPort := th.getClientPorts()
if rtpPort == 0 || rtcpPort == 0 { if rtpPort == 0 || rtcpPort == 0 {
return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr) c.writeResError(req, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr))
return false
} }
if c.path != "" && path != c.path { if c.path != "" && path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
err = func() error { err = func() error {
@ -506,10 +419,11 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -523,12 +437,14 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
}, ";"), }, ";"),
"Session": "12345678", "Session": "12345678",
}, },
}, nil })
return true
// play via TCP // play via TCP
} else if _, ok := th["RTP/AVP/TCP"]; ok { } else if _, ok := th["RTP/AVP/TCP"]; ok {
if c.path != "" && path != c.path { if c.path != "" && path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
err = func() error { err = func() error {
@ -559,12 +475,13 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
interleaved := fmt.Sprintf("%d-%d", ((len(c.streamTracks) - 1) * 2), ((len(c.streamTracks)-1)*2)+1) interleaved := fmt.Sprintf("%d-%d", ((len(c.streamTracks) - 1) * 2), ((len(c.streamTracks)-1)*2)+1)
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -576,27 +493,32 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
}, ";"), }, ";"),
"Session": "12345678", "Session": "12345678",
}, },
}, nil })
return true
} else { } else {
return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr) c.writeResError(req, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr))
return false
} }
// record // record
case "ANNOUNCE", "PRE_RECORD": case "ANNOUNCE", "PRE_RECORD":
if _, ok := th["mode=record"]; !ok { if _, ok := th["mode=record"]; !ok {
return nil, fmt.Errorf("transport header does not contain mode=record") c.writeResError(req, fmt.Errorf("transport header does not contain mode=record"))
return false
} }
if path != c.path { if path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
// record via UDP // record via UDP
if _, ok := th["RTP/AVP/UDP"]; ok { if _, ok := th["RTP/AVP/UDP"]; ok {
rtpPort, rtcpPort := th.getClientPorts() rtpPort, rtcpPort := th.getClientPorts()
if rtpPort == 0 || rtcpPort == 0 { if rtpPort == 0 || rtcpPort == 0 {
return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr) c.writeResError(req, fmt.Errorf("transport header does not have valid client ports (%s)", transportstr))
return false
} }
err = func() error { err = func() error {
@ -621,10 +543,11 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -638,7 +561,8 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
}, ";"), }, ";"),
"Session": "12345678", "Session": "12345678",
}, },
}, nil })
return true
// record via TCP // record via TCP
} else if _, ok := th["RTP/AVP/TCP"]; ok { } else if _, ok := th["RTP/AVP/TCP"]; ok {
@ -675,10 +599,11 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
@ -690,23 +615,28 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
}, ";"), }, ";"),
"Session": "12345678", "Session": "12345678",
}, },
}, nil })
return true
} else { } else {
return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr) c.writeResError(req, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transportstr))
return false
} }
default: default:
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
case "PLAY": case "PLAY":
if c.state != "PRE_PLAY" { if c.state != "PRE_PLAY" {
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
if path != c.path { if path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
err := func() error { err := func() error {
@ -725,25 +655,59 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ // first write response, then set state
// otherwise, in case of TCP connections, RTP packets could be written
// before the response
c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
"CSeq": cseq, "CSeq": cseq,
"Session": "12345678", "Session": "12345678",
}, },
}, errPlay })
c.log("is receiving on path '%s', %d %s via %s", c.path, len(c.streamTracks), func() string {
if len(c.streamTracks) == 1 {
return "track"
}
return "tracks"
}(), c.streamProtocol)
c.p.mutex.Lock()
c.state = "PLAY"
c.p.mutex.Unlock()
// when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP feedback, do not parse it, wait until connection closes
if c.streamProtocol == _STREAM_PROTOCOL_TCP {
buf := make([]byte, 2048)
for {
_, err := c.rconn.Read(buf)
if err != nil {
if err != io.EOF {
c.log("ERR: %s", err)
}
return false
}
}
}
return true
case "PAUSE": case "PAUSE":
if c.state != "PLAY" { if c.state != "PLAY" {
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
if path != c.path { if path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
c.log("paused") c.log("paused")
@ -752,22 +716,25 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
c.state = "PRE_PLAY" c.state = "PRE_PLAY"
c.p.mutex.Unlock() c.p.mutex.Unlock()
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
"CSeq": cseq, "CSeq": cseq,
"Session": "12345678", "Session": "12345678",
}, },
}, nil })
return true
case "RECORD": case "RECORD":
if c.state != "PRE_RECORD" { if c.state != "PRE_RECORD" {
return nil, fmt.Errorf("client is in state '%s'", c.state) c.writeResError(req, fmt.Errorf("client is in state '%s'", c.state))
return false
} }
if path != c.path { if path != c.path {
return nil, fmt.Errorf("path has changed") c.writeResError(req, fmt.Errorf("path has changed"))
return false
} }
err := func() error { err := func() error {
@ -781,22 +748,66 @@ func (c *client) handleRequest(req *rtsp.Request) (*rtsp.Response, error) {
return nil return nil
}() }()
if err != nil { if err != nil {
return nil, err c.writeResError(req, err)
return false
} }
return &rtsp.Response{ c.writeRes(&rtsp.Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Headers: map[string]string{
"CSeq": cseq, "CSeq": cseq,
"Session": "12345678", "Session": "12345678",
}, },
}, errRecord })
c.p.mutex.Lock()
c.state = "RECORD"
c.p.mutex.Unlock()
c.log("is publishing on path '%s', %d %s via %s", c.path, len(c.streamTracks), func() string {
if len(c.streamTracks) == 1 {
return "track"
}
return "tracks"
}(), c.streamProtocol)
// when protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP data and parse it
if c.streamProtocol == _STREAM_PROTOCOL_TCP {
buf := make([]byte, 2048)
for {
channel, n, err := c.rconn.ReadInterleavedFrame(buf)
if err != nil {
if _, ok := err.(*net.OpError); ok {
} else if err == io.EOF {
} else {
c.log("ERR: %s", err)
}
return false
}
trackId, trackFlow := interleavedChannelToTrack(channel)
if trackId >= len(c.streamTracks) {
c.log("ERR: invalid track id '%d'", trackId)
return false
}
c.p.mutex.RLock()
c.p.forwardTrack(c.path, trackId, trackFlow, buf[:n])
c.p.mutex.RUnlock()
}
}
return true
case "TEARDOWN": case "TEARDOWN":
return nil, errTeardown // close connection silently
return false
default: default:
return nil, fmt.Errorf("unhandled method '%s'", req.Method) c.writeResError(req, fmt.Errorf("unhandled method '%s'", req.Method))
return false
} }
} }

Loading…
Cancel
Save