diff --git a/internal/conf/path.go b/internal/conf/path.go index d5b8d526..34c4fa5c 100644 --- a/internal/conf/path.go +++ b/internal/conf/path.go @@ -200,16 +200,11 @@ func (pconf *PathConf) check(conf *Conf, name string) error { return fmt.Errorf("a path with a regular expression (or path 'all') cannot have a HLS source. use another path") } - host, _, err := net.SplitHostPort(pconf.Source[len("udp://"):]) + _, _, err := net.SplitHostPort(pconf.Source[len("udp://"):]) if err != nil { return fmt.Errorf("'%s' is not a valid UDP URL", pconf.Source) } - ip := net.ParseIP(host) - if ip == nil { - return fmt.Errorf("'%s' is not a valid IP", host) - } - case strings.HasPrefix(pconf.Source, "srt://"): if pconf.Regexp != nil { return fmt.Errorf("a path with a regular expression (or path 'all') cannot have a SRT source. use another path") diff --git a/internal/core/udp_source.go b/internal/core/udp_source.go index 9dc211ca..9a81c36f 100644 --- a/internal/core/udp_source.go +++ b/internal/core/udp_source.go @@ -92,16 +92,18 @@ func (s *udpSource) run(ctx context.Context, cnf *conf.PathConf, _ chan *conf.Pa hostPort := cnf.Source[len("udp://"):] - pc, err := net.ListenPacket(restrictNetwork("udp", hostPort)) + addr, err := net.ResolveUDPAddr("udp", hostPort) if err != nil { return err } - defer pc.Close() - host, _, _ := net.SplitHostPort(hostPort) - ip := net.ParseIP(host) + pc, err := net.ListenPacket(restrictNetwork("udp", addr.String())) + if err != nil { + return err + } + defer pc.Close() - if ip.IsMulticast() { + if addr.IP.IsMulticast() { p := ipv4.NewPacketConn(pc) err = p.SetMulticastTTL(multicastTTL) @@ -109,7 +111,7 @@ func (s *udpSource) run(ctx context.Context, cnf *conf.PathConf, _ chan *conf.Pa return err } - err = joinMulticastGroupOnAtLeastOneInterface(p, ip) + err = joinMulticastGroupOnAtLeastOneInterface(p, addr.IP) if err != nil { return err } diff --git a/internal/core/udp_source_test.go b/internal/core/udp_source_test.go new file mode 100644 index 00000000..4b3d8e2f --- /dev/null +++ b/internal/core/udp_source_test.go @@ -0,0 +1,86 @@ +package core + +import ( + "bufio" + "net" + "testing" + "time" + + "github.com/bluenviron/gortsplib/v3" + "github.com/bluenviron/gortsplib/v3/pkg/formats" + "github.com/bluenviron/gortsplib/v3/pkg/url" + "github.com/bluenviron/mediacommon/pkg/formats/mpegts" + "github.com/pion/rtp" + "github.com/stretchr/testify/require" +) + +func TestUDPSource(t *testing.T) { + p, ok := newInstance("paths:\n" + + " proxied:\n" + + " source: udp://localhost:9999\n" + + " sourceOnDemand: yes\n") + require.Equal(t, true, ok) + defer p.Close() + + c := gortsplib.Client{} + + u, err := url.Parse("rtsp://127.0.0.1:8554/proxied") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + connected := make(chan struct{}) + received := make(chan struct{}) + + go func() { + time.Sleep(200 * time.Millisecond) + + conn, err := net.Dial("udp", "localhost:9999") + require.NoError(t, err) + defer conn.Close() + + track := &mpegts.Track{ + Codec: &mpegts.CodecH264{}, + } + + bw := bufio.NewWriter(conn) + w := mpegts.NewWriter(bw, []*mpegts.Track{track}) + require.NoError(t, err) + + err = w.WriteH26x(track, 0, 0, true, [][]byte{ + { // IDR + 0x05, 1, + }, + }) + require.NoError(t, err) + bw.Flush() + + <-connected + + err = w.WriteH26x(track, 0, 0, true, [][]byte{{5, 2}}) + require.NoError(t, err) + bw.Flush() + }() + + medias, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + var forma *formats.H264 + medi := medias.FindFormat(&forma) + + _, err = c.Setup(medi, baseURL, 0, 0) + require.NoError(t, err) + + c.OnPacketRTP(medi, forma, func(pkt *rtp.Packet) { + require.Equal(t, []byte{5, 1}, pkt.Payload) + close(received) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + + close(connected) + <-received +}