diff --git a/internal/conf/authmethod.go b/internal/conf/authmethod.go index 49e9a98e..5866fcbc 100644 --- a/internal/conf/authmethod.go +++ b/internal/conf/authmethod.go @@ -8,34 +8,6 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) -func unmarshalStringSlice(b []byte) ([]string, error) { - var in interface{} - if err := json.Unmarshal(b, &in); err != nil { - return nil, err - } - - var slice []string - - switch it := in.(type) { - case string: // from environment variables - slice = strings.Split(it, ",") - - case []interface{}: // from yaml - for _, e := range it { - et, ok := e.(string) - if !ok { - return nil, fmt.Errorf("cannot unmarshal from %T", e) - } - slice = append(slice, et) - } - - default: - return nil, fmt.Errorf("cannot unmarshal from %T", in) - } - - return slice, nil -} - // AuthMethods is the authMethods parameter. type AuthMethods []headers.AuthMethod @@ -58,12 +30,12 @@ func (d AuthMethods) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a AuthMethods from JSON. func (d *AuthMethods) UnmarshalJSON(b []byte) error { - slice, err := unmarshalStringSlice(b) - if err != nil { + var in []string + if err := json.Unmarshal(b, &in); err != nil { return err } - for _, v := range slice { + for _, v := range in { switch v { case "basic": *d = append(*d, headers.AuthBasic) @@ -78,3 +50,8 @@ func (d *AuthMethods) UnmarshalJSON(b []byte) error { return nil } + +func (d *AuthMethods) unmarshalEnv(s string) error { + byts, _ := json.Marshal(strings.Split(s, ",")) + return d.UnmarshalJSON(byts) +} diff --git a/internal/conf/credential.go b/internal/conf/credential.go index 25a77556..f44af0c4 100644 --- a/internal/conf/credential.go +++ b/internal/conf/credential.go @@ -35,3 +35,7 @@ func (d *Credential) UnmarshalJSON(b []byte) error { *d = Credential(in) return nil } + +func (d *Credential) unmarshalEnv(s string) error { + return d.UnmarshalJSON([]byte(`"` + s + `"`)) +} diff --git a/internal/conf/encryption.go b/internal/conf/encryption.go index de6bf6d9..c474d17e 100644 --- a/internal/conf/encryption.go +++ b/internal/conf/encryption.go @@ -56,3 +56,7 @@ func (d *Encryption) UnmarshalJSON(b []byte) error { return nil } + +func (d *Encryption) unmarshalEnv(s string) error { + return d.UnmarshalJSON([]byte(`"` + s + `"`)) +} diff --git a/internal/conf/env.go b/internal/conf/env.go index 1ebaeafc..1d019100 100644 --- a/internal/conf/env.go +++ b/internal/conf/env.go @@ -1,7 +1,6 @@ package conf import ( - "encoding/json" "fmt" "os" "reflect" @@ -9,12 +8,16 @@ import ( "strings" ) +type envUnmarshaler interface { + unmarshalEnv(string) error +} + func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) error { rt := rv.Type() - if i, ok := rv.Addr().Interface().(json.Unmarshaler); ok { + if i, ok := rv.Addr().Interface().(envUnmarshaler); ok { if ev, ok := env[prefix]; ok { - err := i.UnmarshalJSON([]byte(`"` + ev + `"`)) + err := i.unmarshalEnv(ev) if err != nil { return fmt.Errorf("%s: %s", prefix, err) } diff --git a/internal/conf/ipsornets.go b/internal/conf/ipsornets.go index d1f83e74..cf5c5799 100644 --- a/internal/conf/ipsornets.go +++ b/internal/conf/ipsornets.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net" + "strings" ) // IPsOrNets is a parameter that acceps IPs or subnets. @@ -22,16 +23,16 @@ func (d IPsOrNets) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a IPsOrNets from JSON. func (d *IPsOrNets) UnmarshalJSON(b []byte) error { - slice, err := unmarshalStringSlice(b) - if err != nil { + var in []string + if err := json.Unmarshal(b, &in); err != nil { return err } - if len(slice) == 0 { + if len(in) == 0 { return nil } - for _, t := range slice { + for _, t := range in { if _, ipnet, err := net.ParseCIDR(t); err == nil { *d = append(*d, ipnet) } else if ip := net.ParseIP(t); ip != nil { @@ -43,3 +44,8 @@ func (d *IPsOrNets) UnmarshalJSON(b []byte) error { return nil } + +func (d *IPsOrNets) unmarshalEnv(s string) error { + byts, _ := json.Marshal(strings.Split(s, ",")) + return d.UnmarshalJSON(byts) +} diff --git a/internal/conf/logdestination.go b/internal/conf/logdestination.go index 9eb8efa1..d324ef0a 100644 --- a/internal/conf/logdestination.go +++ b/internal/conf/logdestination.go @@ -3,6 +3,7 @@ package conf import ( "encoding/json" "fmt" + "strings" "github.com/aler9/rtsp-simple-server/internal/logger" ) @@ -38,14 +39,14 @@ func (d LogDestinations) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a LogDestinations from JSON. func (d *LogDestinations) UnmarshalJSON(b []byte) error { - slice, err := unmarshalStringSlice(b) - if err != nil { + var in []string + if err := json.Unmarshal(b, &in); err != nil { return err } *d = make(LogDestinations) - for _, proto := range slice { + for _, proto := range in { switch proto { case "stdout": (*d)[logger.DestinationStdout] = struct{}{} @@ -63,3 +64,8 @@ func (d *LogDestinations) UnmarshalJSON(b []byte) error { return nil } + +func (d *LogDestinations) unmarshalEnv(s string) error { + byts, _ := json.Marshal(strings.Split(s, ",")) + return d.UnmarshalJSON(byts) +} diff --git a/internal/conf/loglevel.go b/internal/conf/loglevel.go index a29594ca..d942ccba 100644 --- a/internal/conf/loglevel.go +++ b/internal/conf/loglevel.go @@ -51,3 +51,7 @@ func (d *LogLevel) UnmarshalJSON(b []byte) error { return nil } + +func (d *LogLevel) unmarshalEnv(s string) error { + return d.UnmarshalJSON([]byte(`"` + s + `"`)) +} diff --git a/internal/conf/protocol.go b/internal/conf/protocol.go index 009e1180..c09df72c 100644 --- a/internal/conf/protocol.go +++ b/internal/conf/protocol.go @@ -3,6 +3,7 @@ package conf import ( "encoding/json" "fmt" + "strings" ) // Protocol is a RTSP stream protocol. @@ -46,14 +47,14 @@ func (d Protocols) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals a Protocols from JSON. func (d *Protocols) UnmarshalJSON(b []byte) error { - slice, err := unmarshalStringSlice(b) - if err != nil { + var in []string + if err := json.Unmarshal(b, &in); err != nil { return err } *d = make(Protocols) - for _, proto := range slice { + for _, proto := range in { switch proto { case "udp": (*d)[ProtocolUDP] = struct{}{} @@ -71,3 +72,8 @@ func (d *Protocols) UnmarshalJSON(b []byte) error { return nil } + +func (d *Protocols) unmarshalEnv(s string) error { + byts, _ := json.Marshal(strings.Split(s, ",")) + return d.UnmarshalJSON(byts) +} diff --git a/internal/conf/sourceprotocol.go b/internal/conf/sourceprotocol.go index 1fb47deb..3d95d6c4 100644 --- a/internal/conf/sourceprotocol.go +++ b/internal/conf/sourceprotocol.go @@ -62,3 +62,7 @@ func (d *SourceProtocol) UnmarshalJSON(b []byte) error { return nil } + +func (d *SourceProtocol) unmarshalEnv(s string) error { + return d.UnmarshalJSON([]byte(`"` + s + `"`)) +} diff --git a/internal/conf/stringduration.go b/internal/conf/stringduration.go index 69c45d3f..428b2158 100644 --- a/internal/conf/stringduration.go +++ b/internal/conf/stringduration.go @@ -29,3 +29,7 @@ func (d *StringDuration) UnmarshalJSON(b []byte) error { return nil } + +func (d *StringDuration) unmarshalEnv(s string) error { + return d.UnmarshalJSON([]byte(`"` + s + `"`)) +}