diff --git a/rtsp_client.go b/rtsp_client.go index e2f87985..872639b7 100644 --- a/rtsp_client.go +++ b/rtsp_client.go @@ -3,6 +3,7 @@ package main import ( "bufio" "encoding/binary" + "errors" "fmt" "io" "log" @@ -14,6 +15,12 @@ import ( "rtsp-server/rtsp" ) +var ( + errTeardown = errors.New("teardown") + errPlay = errors.New("play") + errRecord = errors.New("record") +) + type rtspClient struct { p *program nconn net.Conn @@ -92,336 +99,27 @@ func (c *rtspClient) run() { c.log(req.Method) - cseq, ok := req.Headers["CSeq"] - if !ok { - c.log("ERR: cseq missing") - return - } - - ur, err := url.Parse(req.Path) - if err != nil { - c.log("ERR: unable to parse path '%s'", req.Path) - return - } - - switch req.Method { - case "OPTIONS": - // do not check state, since OPTIONS can be requested - // in any state - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Public": strings.Join([]string{ - "DESCRIBE", - "ANNOUNCE", - "SETUP", - "PLAY", - "PAUSE", - "RECORD", - "TEARDOWN", - }, ", "), - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - case "DESCRIBE": - if c.state != "STARTING" { - c.log("ERR: client is in state '%s'", c.state) - return - } - - sdp, err := func() ([]byte, error) { - c.p.mutex.RLock() - defer c.p.mutex.RUnlock() - - if len(c.p.streamSdp) == 0 { - return nil, fmt.Errorf("no one is streaming") - } + res, err := c.handleRequest(req) - return c.p.streamSdp, nil - }() + switch err { + // normal response + case nil: + err = rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return } - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Content-Base": ur.String(), - "Content-Type": "application/sdp", - }, - Content: sdp, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - case "ANNOUNCE": - if c.state != "STARTING" { - c.log("ERR: client is in state '%s'", c.state) - return - } - - ct, ok := req.Headers["Content-Type"] - if !ok { - c.log("ERR: Content-Type header missing") - return - } - - if ct != "application/sdp" { - c.log("ERR: unsupported Content-Type '%s'", ct) - return - } - - err := func() error { - c.p.mutex.Lock() - defer c.p.mutex.Unlock() - - if c.p.streamAuthor != nil { - return fmt.Errorf("another client is already streaming") - } - - c.p.streamAuthor = c - c.p.streamSdp = req.Content - return nil - }() - if err != nil { - c.log("ERR: %s", err) - return - } - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - c.p.mutex.Lock() - c.state = "ANNOUNCE" - c.p.mutex.Unlock() - - case "SETUP": - transport, ok := req.Headers["Transport"] - if !ok { - c.log("ERR: transport header missing") - return - } - - transports := make(map[string]struct{}) - for _, t := range strings.Split(transport, ";") { - transports[t] = struct{}{} - } - - if _, ok := transports["unicast"]; !ok { - c.log("ERR: transport header does not contain unicast") - return - } - - getPorts := func() (int, int) { - for t := range transports { - if !strings.HasPrefix(t, "client_port=") { - continue - } - t = t[len("client_port="):] - - ports := strings.Split(t, "-") - if len(ports) != 2 { - return 0, 0 - } - - port1, err := strconv.ParseInt(ports[0], 10, 64) - if err != nil { - return 0, 0 - } - - port2, err := strconv.ParseInt(ports[1], 10, 64) - if err != nil { - return 0, 0 - } - - return int(port1), int(port2) - } - return 0, 0 - } - - switch c.state { - // play - case "STARTING": - // UDP - if _, ok := transports["RTP/AVP"]; ok { - clientPort1, clientPort2 := getPorts() - if clientPort1 == 0 || clientPort2 == 0 { - c.log("ERR: transport header does not have valid client ports (%s)", transport) - return - } - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Transport": strings.Join([]string{ - "RTP/AVP", - "unicast", - fmt.Sprintf("client_port=%d-%d", clientPort1, clientPort2), - // use two fake server ports, since we do not want to receive feedback - // from the client - fmt.Sprintf("server_port=%d-%d", c.p.rtpPort+2, c.p.rtcpPort+2), - "ssrc=1234ABCD", - }, ";"), - "Session": "12345678", - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - c.p.mutex.Lock() - c.rtpProto = "udp" - c.rtpPort = clientPort1 - c.rtcpPort = clientPort2 - c.state = "PRE_PLAY" - c.p.mutex.Unlock() - - // TCP - } else if _, ok := transports["RTP/AVP/TCP"]; ok { - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Transport": strings.Join([]string{ - "RTP/AVP/TCP", - "unicast", - "destination=127.0.0.1", - "source=127.0.0.1", - "interleaved=0-1", - }, ";"), - "Session": "12345678", - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - c.p.mutex.Lock() - c.rtpProto = "tcp" - c.state = "PRE_PLAY" - c.p.mutex.Unlock() - - } else { - c.log("ERR: transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transport) - return - } - - // record - case "ANNOUNCE": - if _, ok := transports["mode=record"]; !ok { - c.log("ERR: transport header does not contain mode=record") - return - } - - if _, ok := transports["RTP/AVP/UDP"]; ok { - clientPort1, clientPort2 := getPorts() - if clientPort1 == 0 || clientPort2 == 0 { - c.log("ERR: transport header does not have valid client ports (%s)", transport) - return - } - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Transport": strings.Join([]string{ - "RTP/AVP", - "unicast", - fmt.Sprintf("client_port=%d-%d", clientPort1, clientPort2), - fmt.Sprintf("server_port=%d-%d", c.p.rtpPort, c.p.rtcpPort), - "ssrc=1234ABCD", - }, ";"), - "Session": "12345678", - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - c.p.mutex.Lock() - c.rtpProto = "udp" - c.rtpPort = clientPort1 - c.rtcpPort = clientPort2 - c.state = "PRE_RECORD" - c.p.mutex.Unlock() - - } else if _, ok := transports["RTP/AVP/TCP"]; ok { - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Transport": strings.Join([]string{ - "RTP/AVP/TCP", - "unicast", - "destination=127.0.0.1", - "source=127.0.0.1", - }, ";"), - "Session": "12345678", - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - c.p.mutex.Lock() - c.rtpProto = "tcp" - c.state = "PRE_RECORD" - c.p.mutex.Unlock() - - } else { - c.log("ERR: transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transport) - return - } - - default: - c.log("ERR: client is in state '%s'", c.state) - return - } - - case "PLAY": - if c.state != "PRE_PLAY" { - c.log("ERR: client is in state '%s'", c.state) - return - } + // TEARDOWN, close connection silently + case errTeardown: + return - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Session": "12345678", - }, - }) + // PLAY: first write response, then set state + // otherwise, in case of TCP connections, RTP packets could be written + // before the response + // then switch to RTP if TCP + case errPlay: + err = rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return @@ -433,8 +131,8 @@ func (c *rtspClient) run() { c.state = "PLAY" c.p.mutex.Unlock() - // when rtp protocol is TCP, the RTSP connection becomes a RTP connection. - // receive RTP feedback, do not parse it, wait until connection closes. + // when rtp protocol is TCP, the RTSP connection becomes a RTP connection + // receive RTP feedback, do not parse it, wait until connection closes if c.rtpProto == "tcp" { buf := make([]byte, 1024) for { @@ -445,58 +143,22 @@ func (c *rtspClient) run() { } } - case "PAUSE": - if c.state != "PLAY" { - c.log("ERR: client is in state '%s'", c.state) - return - } - - c.log("paused receiving") - - c.p.mutex.Lock() - c.state = "PRE_PLAY" - c.p.mutex.Unlock() - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Session": "12345678", - }, - }) - if err != nil { - c.log("ERR: %s", err) - return - } - - case "RECORD": - if c.state != "PRE_RECORD" { - c.log("ERR: client is in state '%s'", c.state) - return - } - - err = rconn.WriteResponse(&rtsp.Response{ - StatusCode: 200, - Status: "OK", - Headers: map[string]string{ - "CSeq": cseq, - "Session": "12345678", - }, - }) + // RECORD: switch to RTP if TCP + case errRecord: + err = rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return } - c.log("is publishing (via %s)", c.rtpProto) - c.p.mutex.Lock() c.state = "RECORD" c.p.mutex.Unlock() - // when rtp protocol is TCP, the RTSP connection becomes a RTP connection. - // receive RTP feedback, do not parse it, wait until connection closes. + c.log("is publishing (via %s)", c.rtpProto) + + // when rtp protocol is TCP, the RTSP connection becomes a RTP connection + // receive RTP data and parse it if c.rtpProto == "tcp" { packet := make([]byte, 2048) bconn := bufio.NewReader(c.nconn) @@ -532,11 +194,354 @@ func (c *rtspClient) run() { } } - case "TEARDOWN": + // error: write and exit + default: + c.log("ERR: %s", err) + + if cseq, ok := req.Headers["cseq"]; ok { + rconn.WriteResponse(&rtsp.Response{ + StatusCode: 400, + Status: "Bad Request", + Headers: map[string]string{ + "CSeq": cseq, + }, + }) + } else { + rconn.WriteResponse(&rtsp.Response{ + StatusCode: 400, + Status: "Bad Request", + }) + } return + } + } +} + +func (c *rtspClient) handleRequest(req *rtsp.Request) (*rtsp.Response, error) { + cseq, ok := req.Headers["CSeq"] + if !ok { + return nil, fmt.Errorf("cseq missing") + } + + ur, err := url.Parse(req.Path) + if err != nil { + return nil, fmt.Errorf("unable to parse path '%s'", req.Path) + } + + switch req.Method { + case "OPTIONS": + // do not check state, since OPTIONS can be requested + // in any state + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Public": strings.Join([]string{ + "DESCRIBE", + "ANNOUNCE", + "SETUP", + "PLAY", + "PAUSE", + "RECORD", + "TEARDOWN", + }, ", "), + }, + }, nil + + case "DESCRIBE": + if c.state != "STARTING" { + return nil, fmt.Errorf("client is in state '%s'", c.state) + } + + sdp, err := func() ([]byte, error) { + c.p.mutex.RLock() + defer c.p.mutex.RUnlock() + + if len(c.p.streamSdp) == 0 { + return nil, fmt.Errorf("no one is streaming") + } + + return c.p.streamSdp, nil + }() + if err != nil { + return nil, err + } + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Content-Base": ur.String(), + "Content-Type": "application/sdp", + }, + Content: sdp, + }, nil + + case "ANNOUNCE": + if c.state != "STARTING" { + return nil, fmt.Errorf("client is in state '%s'", c.state) + } + + ct, ok := req.Headers["Content-Type"] + if !ok { + return nil, fmt.Errorf("Content-Type header missing") + } + + if ct != "application/sdp" { + return nil, fmt.Errorf("unsupported Content-Type '%s'", ct) + } + + err := func() error { + c.p.mutex.Lock() + defer c.p.mutex.Unlock() + + if c.p.streamAuthor != nil { + return fmt.Errorf("another client is already streaming") + } + + c.p.streamAuthor = c + c.p.streamSdp = req.Content + return nil + }() + if err != nil { + return nil, err + } + + c.p.mutex.Lock() + c.state = "ANNOUNCE" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + }, + }, nil + + case "SETUP": + transport, ok := req.Headers["Transport"] + if !ok { + return nil, fmt.Errorf("transport header missing") + } + + transports := make(map[string]struct{}) + for _, t := range strings.Split(transport, ";") { + transports[t] = struct{}{} + } + + if _, ok := transports["unicast"]; !ok { + return nil, fmt.Errorf("transport header does not contain unicast") + } + + getPorts := func() (int, int) { + for t := range transports { + if !strings.HasPrefix(t, "client_port=") { + continue + } + t = t[len("client_port="):] + + ports := strings.Split(t, "-") + if len(ports) != 2 { + return 0, 0 + } + + port1, err := strconv.ParseInt(ports[0], 10, 64) + if err != nil { + return 0, 0 + } + + port2, err := strconv.ParseInt(ports[1], 10, 64) + if err != nil { + return 0, 0 + } + + return int(port1), int(port2) + } + return 0, 0 + } + + switch c.state { + // play + case "STARTING": + // UDP + if _, ok := transports["RTP/AVP"]; ok { + clientPort1, clientPort2 := getPorts() + if clientPort1 == 0 || clientPort2 == 0 { + return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transport) + } + + c.p.mutex.Lock() + c.rtpProto = "udp" + c.rtpPort = clientPort1 + c.rtcpPort = clientPort2 + c.state = "PRE_PLAY" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Transport": strings.Join([]string{ + "RTP/AVP", + "unicast", + fmt.Sprintf("client_port=%d-%d", clientPort1, clientPort2), + // use two fake server ports, since we do not want to receive feedback + // from the client + fmt.Sprintf("server_port=%d-%d", c.p.rtpPort+2, c.p.rtcpPort+2), + "ssrc=1234ABCD", + }, ";"), + "Session": "12345678", + }, + }, nil + + // TCP + } else if _, ok := transports["RTP/AVP/TCP"]; ok { + c.p.mutex.Lock() + c.rtpProto = "tcp" + c.state = "PRE_PLAY" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Transport": strings.Join([]string{ + "RTP/AVP/TCP", + "unicast", + "destination=127.0.0.1", + "source=127.0.0.1", + "interleaved=0-1", + }, ";"), + "Session": "12345678", + }, + }, nil + + } else { + return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transport) + } + + // record + case "ANNOUNCE": + if _, ok := transports["mode=record"]; !ok { + return nil, fmt.Errorf("transport header does not contain mode=record") + } + + if _, ok := transports["RTP/AVP/UDP"]; ok { + clientPort1, clientPort2 := getPorts() + if clientPort1 == 0 || clientPort2 == 0 { + return nil, fmt.Errorf("transport header does not have valid client ports (%s)", transport) + } + + c.p.mutex.Lock() + c.rtpProto = "udp" + c.rtpPort = clientPort1 + c.rtcpPort = clientPort2 + c.state = "PRE_RECORD" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Transport": strings.Join([]string{ + "RTP/AVP", + "unicast", + fmt.Sprintf("client_port=%d-%d", clientPort1, clientPort2), + fmt.Sprintf("server_port=%d-%d", c.p.rtpPort, c.p.rtcpPort), + "ssrc=1234ABCD", + }, ";"), + "Session": "12345678", + }, + }, nil + + } else if _, ok := transports["RTP/AVP/TCP"]; ok { + c.p.mutex.Lock() + c.rtpProto = "tcp" + c.state = "PRE_RECORD" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Transport": strings.Join([]string{ + "RTP/AVP/TCP", + "unicast", + "destination=127.0.0.1", + "source=127.0.0.1", + }, ";"), + "Session": "12345678", + }, + }, nil + + } else { + return nil, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP or RTP/AVP/TCP) (%s)", transport) + } default: - c.log("ERR: method %s unhandled", req.Method) + return nil, fmt.Errorf("client is in state '%s'", c.state) + } + + case "PLAY": + if c.state != "PRE_PLAY" { + return nil, fmt.Errorf("client is in state '%s'", c.state) + } + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Session": "12345678", + }, + }, errPlay + + case "PAUSE": + if c.state != "PLAY" { + return nil, fmt.Errorf("client is in state '%s'", c.state) + } + + c.log("paused receiving") + + c.p.mutex.Lock() + c.state = "PRE_PLAY" + c.p.mutex.Unlock() + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Session": "12345678", + }, + }, nil + + case "RECORD": + if c.state != "PRE_RECORD" { + return nil, fmt.Errorf("client is in state '%s'", c.state) } + + return &rtsp.Response{ + StatusCode: 200, + Status: "OK", + Headers: map[string]string{ + "CSeq": cseq, + "Session": "12345678", + }, + }, errRecord + + case "TEARDOWN": + return nil, errTeardown + + default: + return nil, fmt.Errorf("unhandled method '%s'", req.Method) } }