commit e9952937dc8ddaf9cc5b6de9bbdf26956d80b36b Author: halwu(吴浩麟) Date: Sun May 28 16:11:47 2017 +0800 Initial commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..5c48649 --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +# AV streaming server + +## Feature +- write in pure golang, can run in any platform +- for live streaming +- support `RTMP` and `FLV` `HLS` over HTTP + +## Use +1. run `git clone ` +2. run `go run main.go` to start livego server +3. push `RTMP` stream to `rtmp://localhost/live/movie`, eg use `ffmpeg -re -i demo.flv -c copy -f flv rtmp://localhost/live/movie` +4. play stream use [VLC](http://www.videolan.org/vlc/index.html) or other players + - play `RTMP` from `rtmp://localhost/live/movie` + - play `FLV` from `http://127.0.0.1:8081/live/movie.flv` + - play `HLS` from `http://127.0.0.1:8080/live/movie.m3u8` diff --git a/av/av.go b/av/av.go new file mode 100755 index 0000000..faeb926 --- /dev/null +++ b/av/av.go @@ -0,0 +1,152 @@ +package av + +import "io" +import "fmt" + +const ( + TAG_AUDIO = 8 + TAG_VIDEO = 9 + TAG_SCRIPTDATAAMF0 = 18 + TAG_SCRIPTDATAAMF3 = 0xf +) + +const ( + MetadatAMF0 = 0x12 + MetadataAMF3 = 0xf +) + +const ( + SOUND_MP3 = 2 + SOUND_NELLYMOSER_16KHZ_MONO = 4 + SOUND_NELLYMOSER_8KHZ_MONO = 5 + SOUND_NELLYMOSER = 6 + SOUND_ALAW = 7 + SOUND_MULAW = 8 + SOUND_AAC = 10 + SOUND_SPEEX = 11 + + SOUND_5_5Khz = 0 + SOUND_11Khz = 1 + SOUND_22Khz = 2 + SOUND_44Khz = 3 + + SOUND_8BIT = 0 + SOUND_16BIT = 1 + + SOUND_MONO = 0 + SOUND_STEREO = 1 + + AAC_SEQHDR = 0 + AAC_RAW = 1 +) + +const ( + AVC_SEQHDR = 0 + AVC_NALU = 1 + AVC_EOS = 2 + + FRAME_KEY = 1 + FRAME_INTER = 2 + + VIDEO_H264 = 7 +) + +var ( + PUBLISH = "publish" + PLAY = "play" +) + +// Header can be converted to AudioHeaderInfo or VideoHeaderInfo +type Packet struct { + IsAudio bool + IsVideo bool + IsMetadata bool + TimeStamp uint32 // dts + Header PacketHeader + Data []byte +} + +type PacketHeader interface { +} + +type AudioPacketHeader interface { + PacketHeader + SoundFormat() uint8 + AACPacketType() uint8 +} + +type VideoPacketHeader interface { + PacketHeader + IsKeyFrame() bool + IsSeq() bool + CodecID() uint8 + CompositionTime() int32 +} + +type Demuxer interface { + Demux(*Packet) (ret *Packet, err error) +} + +type Muxer interface { + Mux(*Packet, io.Writer) error +} + +type SampleRater interface { + SampleRate() (int, error) +} + +type CodecParser interface { + SampleRater + Parse(*Packet, io.Writer) error +} + +type GetWriter interface { + GetWriter(Info) WriteCloser +} + +type Handler interface { + HandleReader(ReadCloser) + HandleWriter(WriteCloser) +} + +type Alive interface { + Alive() bool +} + +type Closer interface { + Info() Info + Close(error) +} + +type CalcTime interface { + CalcBaseTimestamp() +} + +type Info struct { + Key string + URL string + UID string + Inter bool +} + +func (self Info) IsInterval() bool { + return self.Inter +} + +func (i Info) String() string { + return fmt.Sprintf("", + i.Key, i.URL, i.UID, i.Inter) +} + +type ReadCloser interface { + Closer + Alive + Read(*Packet) error +} + +type WriteCloser interface { + Closer + Alive + CalcTime + Write(Packet) error +} diff --git a/av/rwbase.go b/av/rwbase.go new file mode 100755 index 0000000..6385dd8 --- /dev/null +++ b/av/rwbase.go @@ -0,0 +1,53 @@ +package av + +import "time" +import "sync" + +type RWBaser struct { + lock sync.Mutex + timeout time.Duration + PreTime time.Time + BaseTimestamp uint32 + LastVideoTimestamp uint32 + LastAudioTimestamp uint32 +} + +func NewRWBaser(duration 
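The `av` package above is the hub of the server: every container and protocol exchanges `av.Packet` values, and consumers recover the concrete header by type-asserting `Packet.Header` against `AudioPacketHeader` or `VideoPacketHeader`, which is what the TS muxer and the `parser` package do later in this commit. A minimal sketch of that contract; the `describePacket` helper is illustrative and not part of the repository:

```go
package main

import (
	"fmt"

	"github.com/gwuhaolin/livego/av"
)

// describePacket shows the expected way to consume av.Packet: check the
// Is* flags, then type-assert Header to the matching header interface.
func describePacket(p *av.Packet) {
	switch {
	case p.IsVideo:
		if vh, ok := p.Header.(av.VideoPacketHeader); ok {
			fmt.Println("video, key frame:", vh.IsKeyFrame(), "cts:", vh.CompositionTime())
		}
	case p.IsAudio:
		if ah, ok := p.Header.(av.AudioPacketHeader); ok {
			fmt.Println("audio, sound format:", ah.SoundFormat(), "aac type:", ah.AACPacketType())
		}
	case p.IsMetadata:
		fmt.Println("metadata,", len(p.Data), "bytes")
	}
}

func main() {
	describePacket(&av.Packet{IsMetadata: true, Data: []byte("onMetaData")})
}
```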
time.Duration) RWBaser { + return RWBaser{ + timeout: duration, + PreTime: time.Now(), + } +} + +func (self *RWBaser) BaseTimeStamp() uint32 { + return self.BaseTimestamp +} + +func (self *RWBaser) CalcBaseTimestamp() { + if self.LastAudioTimestamp > self.LastVideoTimestamp { + self.BaseTimestamp = self.LastAudioTimestamp + } else { + self.BaseTimestamp = self.LastVideoTimestamp + } +} + +func (self *RWBaser) RecTimeStamp(timestamp, typeID uint32) { + if typeID == TAG_VIDEO { + self.LastVideoTimestamp = timestamp + } else if typeID == TAG_AUDIO { + self.LastAudioTimestamp = timestamp + } +} + +func (self *RWBaser) SetPreTime() { + self.lock.Lock() + self.PreTime = time.Now() + self.lock.Unlock() +} + +func (self *RWBaser) Alive() bool { + self.lock.Lock() + b := !(time.Now().Sub(self.PreTime) >= self.timeout) + self.lock.Unlock() + return b +} diff --git a/container/flv/demuxer.go b/container/flv/demuxer.go new file mode 100755 index 0000000..730ed35 --- /dev/null +++ b/container/flv/demuxer.go @@ -0,0 +1,44 @@ +package flv + +import ( + "errors" + "github.com/gwuhaolin/livego/av" +) + +var ( + ErrAvcEndSEQ = errors.New("avc end sequence") +) + +type Demuxer struct { +} + +func NewDemuxer() *Demuxer { + return &Demuxer{} +} + +func (self *Demuxer) DemuxH(p *av.Packet) error { + var tag Tag + _, err := tag.ParseMeidaTagHeader(p.Data, p.IsVideo) + if err != nil { + return err + } + p.Header = &tag + + return nil +} + +func (self *Demuxer) Demux(p *av.Packet) error { + var tag Tag + n, err := tag.ParseMeidaTagHeader(p.Data, p.IsVideo) + if err != nil { + return err + } + if tag.CodecID() == av.VIDEO_H264 && + p.Data[0] == 0x17 && p.Data[1] == 0x02 { + return ErrAvcEndSEQ + } + p.Header = &tag + p.Data = p.Data[n:] + + return nil +} diff --git a/container/flv/muxer.go b/container/flv/muxer.go new file mode 100755 index 0000000..dadb70d --- /dev/null +++ b/container/flv/muxer.go @@ -0,0 +1,141 @@ +package flv + +import ( + "strings" + "time" + "flag" + "os" + "log" + "github.com/gwuhaolin/livego/utils/uid" + "github.com/gwuhaolin/livego/protocol/amf" + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/utils/pio" +) + +var ( + flvHeader = []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00, 0x09} +) + +var ( + flvFile = flag.String("filFile", "./out.flv", "output flv file name") +) + +func NewFlv(handler av.Handler, info av.Info) { + patths := strings.SplitN(info.Key, "/", 2) + + if len(patths) != 2 { + log.Println("invalid info") + return + } + + w, err := os.OpenFile(*flvFile, os.O_CREATE|os.O_RDWR, 0755) + if err != nil { + log.Println("open file error: ", err) + } + + writer := NewFLVWriter(patths[0], patths[1], info.URL, w) + + handler.HandleWriter(writer) + + writer.Wait() + // close flv file + log.Println("close flv file") + writer.ctx.Close() +} + +const ( + headerLen = 11 +) + +type FLVWriter struct { + Uid string + av.RWBaser + app, title, url string + buf []byte + closed chan struct{} + ctx *os.File +} + +func NewFLVWriter(app, title, url string, ctx *os.File) *FLVWriter { + ret := &FLVWriter{ + Uid: uid.NEWID(), + app: app, + title: title, + url: url, + ctx: ctx, + RWBaser: av.NewRWBaser(time.Second * 10), + closed: make(chan struct{}), + buf: make([]byte, headerLen), + } + + ret.ctx.Write(flvHeader) + pio.PutI32BE(ret.buf[:4], 0) + ret.ctx.Write(ret.buf[:4]) + + return ret +} + +func (self *FLVWriter) Write(p av.Packet) error { + self.RWBaser.SetPreTime() + h := self.buf[:headerLen] + typeID := av.TAG_VIDEO + if !p.IsVideo { + if p.IsMetadata { + var err error + 
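`RWBaser` above carries the keep-alive state shared by readers and writers: `SetPreTime` is refreshed on each successful read or write, and `Alive` reports whether the last activity happened within the configured timeout (the `FLVWriter` above embeds it with a 10-second window). A small sketch of that contract, using only the constructor and methods defined above:

```go
package main

import (
	"fmt"
	"time"

	"github.com/gwuhaolin/livego/av"
)

func main() {
	// Short timeout just for demonstration; FLVWriter uses 10 seconds.
	base := av.NewRWBaser(100 * time.Millisecond)

	base.SetPreTime()
	fmt.Println("right after a write:", base.Alive()) // true

	time.Sleep(150 * time.Millisecond)
	fmt.Println("after the timeout:", base.Alive()) // false
}
```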
typeID = av.TAG_SCRIPTDATAAMF0 + p.Data, err = amf.MetaDataReform(p.Data, amf.DEL) + if err != nil { + return err + } + } else { + typeID = av.TAG_AUDIO + } + } + dataLen := len(p.Data) + timestamp := p.TimeStamp + timestamp += self.BaseTimeStamp() + self.RWBaser.RecTimeStamp(timestamp, uint32(typeID)) + + preDataLen := dataLen + headerLen + timestampbase := timestamp & 0xffffff + timestampExt := timestamp >> 24 & 0xff + + pio.PutU8(h[0:1], uint8(typeID)) + pio.PutI24BE(h[1:4], int32(dataLen)) + pio.PutI24BE(h[4:7], int32(timestampbase)) + pio.PutU8(h[7:8], uint8(timestampExt)) + + if _, err := self.ctx.Write(h); err != nil { + return err + } + + if _, err := self.ctx.Write(p.Data); err != nil { + return err + } + + pio.PutI32BE(h[:4], int32(preDataLen)) + if _, err := self.ctx.Write(h[:4]); err != nil { + return err + } + + return nil +} + +func (self *FLVWriter) Wait() { + select { + case <-self.closed: + return + } +} + +func (self *FLVWriter) Close(error) { + self.ctx.Close() + close(self.closed) +} + +func (self *FLVWriter) Info() (ret av.Info) { + ret.UID = self.Uid + ret.URL = self.url + ret.Key = self.app + "/" + self.title + return +} diff --git a/container/flv/tag.go b/container/flv/tag.go new file mode 100755 index 0000000..8f5fa76 --- /dev/null +++ b/container/flv/tag.go @@ -0,0 +1,180 @@ +package flv + +import ( + "fmt" + "github.com/gwuhaolin/livego/av" +) + +type flvTag struct { + fType uint8 + dataSize uint32 + timeStamp uint32 + streamID uint32 // always 0 +} + +type mediaTag struct { + /* + SoundFormat: UB[4] + 0 = Linear PCM, platform endian + 1 = ADPCM + 2 = MP3 + 3 = Linear PCM, little endian + 4 = Nellymoser 16-kHz mono + 5 = Nellymoser 8-kHz mono + 6 = Nellymoser + 7 = G.711 A-law logarithmic PCM + 8 = G.711 mu-law logarithmic PCM + 9 = reserved + 10 = AAC + 11 = Speex + 14 = MP3 8-Khz + 15 = Device-specific sound + Formats 7, 8, 14, and 15 are reserved for internal use + AAC is supported in Flash Player 9,0,115,0 and higher. + Speex is supported in Flash Player 10 and higher. + */ + soundFormat uint8 + + /* + SoundRate: UB[2] + Sampling rate + 0 = 5.5-kHz For AAC: always 3 + 1 = 11-kHz + 2 = 22-kHz + 3 = 44-kHz + */ + soundRate uint8 + + /* + SoundSize: UB[1] + 0 = snd8Bit + 1 = snd16Bit + Size of each sample. + This parameter only pertains to uncompressed formats. 
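For reference, the byte layout `FLVWriter.Write` above produces for every tag: an 11-byte tag header (type, 24-bit payload size, a timestamp split into 24 low bits plus an extended byte, and a 3-byte stream ID left at zero), the payload, then a 4-byte `PreviousTagSize` equal to 11 plus the payload length. A self-contained sketch of the same packing, with an illustrative helper name:

```go
package main

import "fmt"

func flvTag(tagType byte, timestamp uint32, payload []byte) []byte {
	out := make([]byte, 0, 11+len(payload)+4)

	dataLen := uint32(len(payload))
	out = append(out,
		tagType, // 8 = audio, 9 = video, 18 = script data
		byte(dataLen>>16), byte(dataLen>>8), byte(dataLen), // DataSize, 24-bit big endian
		byte(timestamp>>16), byte(timestamp>>8), byte(timestamp), // lower 24 bits of the timestamp
		byte(timestamp>>24), // TimestampExtended
		0, 0, 0, // StreamID, always 0
	)
	out = append(out, payload...)

	prev := uint32(11) + dataLen // PreviousTagSize written after every tag
	out = append(out, byte(prev>>24), byte(prev>>16), byte(prev>>8), byte(prev))
	return out
}

func main() {
	fmt.Printf("% x\n", flvTag(9, 0x01234567, []byte{0x17, 0x00}))
}
```

Running it prints the 11-byte header, the two payload bytes, and the trailing previous-tag-size `00 00 00 0d`.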
+ Compressed formats always decode to 16 bits internally + */ + soundSize uint8 + + /* + SoundType: UB[1] + 0 = sndMono + 1 = sndStereo + Mono or stereo sound For Nellymoser: always 0 + For AAC: always 1 + */ + soundType uint8 + + /* + 0: AAC sequence header + 1: AAC raw + */ + aacPacketType uint8 + + /* + 1: keyframe (for AVC, a seekable frame) + 2: inter frame (for AVC, a non- seekable frame) + 3: disposable inter frame (H.263 only) + 4: generated keyframe (reserved for server use only) + 5: video info/command frame + */ + frameType uint8 + + /* + 1: JPEG (currently unused) + 2: Sorenson H.263 + 3: Screen video + 4: On2 VP6 + 5: On2 VP6 with alpha channel + 6: Screen video version 2 + 7: AVC + */ + codecID uint8 + + /* + 0: AVC sequence header + 1: AVC NALU + 2: AVC end of sequence (lower level NALU sequence ender is not required or supported) + */ + avcPacketType uint8 + + compositionTime int32 +} + +type Tag struct { + flvt flvTag + mediat mediaTag +} + +func (self *Tag) SoundFormat() uint8 { + return self.mediat.soundFormat +} + +func (self *Tag) AACPacketType() uint8 { + return self.mediat.aacPacketType +} + +func (self *Tag) IsKeyFrame() bool { + return self.mediat.frameType == av.FRAME_KEY +} + +func (self *Tag) IsSeq() bool { + return self.mediat.frameType == av.FRAME_KEY && + self.mediat.avcPacketType == av.AVC_SEQHDR +} + +func (self *Tag) CodecID() uint8 { + return self.mediat.codecID +} + +func (self *Tag) CompositionTime() int32 { + return self.mediat.compositionTime +} + +// ParseMeidaTagHeader, parse video, audio, tag header +func (self *Tag) ParseMeidaTagHeader(b []byte, isVideo bool) (n int, err error) { + switch isVideo { + case false: + n, err = self.parseAudioHeader(b) + case true: + n, err = self.parseVideoHeader(b) + } + return +} + +func (self *Tag) parseAudioHeader(b []byte) (n int, err error) { + if len(b) < n+1 { + err = fmt.Errorf("invalid audiodata len=%d", len(b)) + return + } + flags := b[0] + self.mediat.soundFormat = flags >> 4 + self.mediat.soundRate = (flags >> 2) & 0x3 + self.mediat.soundSize = (flags >> 1) & 0x1 + self.mediat.soundType = flags & 0x1 + n++ + switch self.mediat.soundFormat { + case av.SOUND_AAC: + self.mediat.aacPacketType = b[1] + n++ + } + return +} + +func (self *Tag) parseVideoHeader(b []byte) (n int, err error) { + if len(b) < n+5 { + err = fmt.Errorf("invalid videodata len=%d", len(b)) + return + } + flags := b[0] + self.mediat.frameType = flags >> 4 + self.mediat.codecID = flags & 0xf + n++ + if self.mediat.frameType == av.FRAME_INTER || self.mediat.frameType == av.FRAME_KEY { + self.mediat.avcPacketType = b[1] + for i := 2; i < 5; i++ { + self.mediat.compositionTime = self.mediat.compositionTime<<8 + int32(b[i]) + } + n += 4 + } + return +} diff --git a/container/mp4/muxer.go b/container/mp4/muxer.go new file mode 100755 index 0000000..c39d11c --- /dev/null +++ b/container/mp4/muxer.go @@ -0,0 +1 @@ +package mp4 diff --git a/container/ts/crc32.go b/container/ts/crc32.go new file mode 100755 index 0000000..ec8c480 --- /dev/null +++ b/container/ts/crc32.go @@ -0,0 +1,78 @@ +package ts + +func GenCrc32(src []byte) uint32 { + crcTable := []uint32{ + 0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9, + 0x130476dc, 0x17c56b6b, 0x1a864db2, 0x1e475005, + 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61, + 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, + 0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9, + 0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75, + 0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011, + 0x791d4014, 0x7ddc5da3, 0x709f7b7a, 
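A quick usage sketch for the `Tag` parser above, fed a hand-made FLV video payload (`0x17` = key frame + AVC, `0x01` = AVC NALU, then a 3-byte composition time); the trailing byte stands in for real NALU data:

```go
package main

import (
	"fmt"

	"github.com/gwuhaolin/livego/container/flv"
)

func main() {
	payload := []byte{0x17, 0x01, 0x00, 0x00, 0x2a, 0xaa}

	var tag flv.Tag
	n, err := tag.ParseMeidaTagHeader(payload, true) // true: this is a video tag
	if err != nil {
		panic(err)
	}

	fmt.Println("header bytes:", n)             // 5 for AVC
	fmt.Println("key frame:", tag.IsKeyFrame()) // true
	fmt.Println("codec id:", tag.CodecID())     // 7 (H.264)
	fmt.Println("cts:", tag.CompositionTime())  // 42
}
```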
0x745e66cd, + 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, + 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, + 0xbe2b5b58, 0xbaea46ef, 0xb7a96036, 0xb3687d81, + 0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d, + 0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49, + 0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95, + 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, + 0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, + 0x34867077, 0x30476dc0, 0x3d044b19, 0x39c556ae, + 0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072, + 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, + 0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca, + 0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde, + 0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02, + 0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, 0x53dc6066, + 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, + 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, + 0xbfa1b04b, 0xbb60adfc, 0xb6238b25, 0xb2e29692, + 0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6, + 0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a, + 0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e, + 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, + 0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686, + 0xd5b88683, 0xd1799b34, 0xdc3abded, 0xd8fba05a, + 0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637, + 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, + 0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f, + 0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53, + 0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47, + 0x36194d42, 0x32d850f5, 0x3f9b762c, 0x3b5a6b9b, + 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, + 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, + 0xf12f560e, 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7, + 0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b, + 0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f, + 0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3, + 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, + 0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, + 0x9b3660c6, 0x9ff77d71, 0x92b45ba8, 0x9675461f, + 0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3, + 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, + 0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c, + 0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8, + 0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24, + 0x119b4be9, 0x155a565e, 0x18197087, 0x1cd86d30, + 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, + 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, + 0x2497d08d, 0x2056cd3a, 0x2d15ebe3, 0x29d4f654, + 0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0, + 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c, + 0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18, + 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, + 0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0, + 0x9abc8bd5, 0x9e7d9662, 0x933eb0bb, 0x97ffad0c, + 0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668, + 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4} + + j := byte(0) + crc32 := uint32(0xFFFFFFFF) + for i := 0; i < len(src); i++ { + j = ((byte(crc32>>24) ^ src[i]) & 0xff) + crc32 = uint32(uint32(crc32<<8) ^ uint32(crcTable[j])) + } + + return crc32 +} diff --git a/container/ts/muxer.go b/container/ts/muxer.go new file mode 100755 index 0000000..30a0e38 --- /dev/null +++ b/container/ts/muxer.go @@ -0,0 +1,365 @@ +package ts + +import ( + "io" + "github.com/gwuhaolin/livego/av" +) + +const ( + tsDefaultDataLen = 184 + tsPacketLen = 188 + h264DefaultHZ = 90 +) + +const ( + videoPID = 0x100 + audioPID = 0x101 + videoSID = 0xe0 + audioSID = 0xc0 +) + +type Muxer struct { + videoCc byte + audioCc byte + patCc byte + pmtCc byte + pat [tsPacketLen]byte + pmt [tsPacketLen]byte + tsPacket 
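`GenCrc32` above is the table-driven MPEG-2 CRC-32 (polynomial 0x04C11DB7, initial value 0xFFFFFFFF, no reflection) used to protect PSI sections. A short usage sketch: run it over the section body and append the result most-significant byte first, which is what the PAT/PMT builders later in this file do. The `patHeader` bytes are the ones from `PAT()`:

```go
package main

import (
	"fmt"

	"github.com/gwuhaolin/livego/container/ts"
)

func main() {
	patHeader := []byte{0x00, 0xb0, 0x0d, 0x00, 0x01, 0xc1, 0x00, 0x00, 0x00, 0x01, 0xf0, 0x01}

	crc := ts.GenCrc32(patHeader)

	// Serialize most-significant byte first, as the muxer does.
	section := append(patHeader, byte(crc>>24), byte(crc>>16), byte(crc>>8), byte(crc))
	fmt.Printf("crc32 = 0x%08x, section = % x\n", crc, section)
}
```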
[tsPacketLen]byte +} + +func NewMuxer() *Muxer { + return &Muxer{} +} + +func (self *Muxer) Mux(p *av.Packet, w io.Writer) error { + first := true + wBytes := 0 + pesIndex := 0 + tmpLen := byte(0) + dataLen := byte(0) + + var pes pesHeader + dts := int64(p.TimeStamp) * int64(h264DefaultHZ) + pts := dts + pid := audioPID + var videoH av.VideoPacketHeader + if p.IsVideo { + pid = videoPID + videoH, _ = p.Header.(av.VideoPacketHeader) + pts = dts + int64(videoH.CompositionTime())*int64(h264DefaultHZ) + } + err := pes.packet(p, pts, dts) + if err != nil { + return err + } + pesHeaderLen := pes.len + packetBytesLen := len(p.Data) + int(pesHeaderLen) + + for { + if packetBytesLen <= 0 { + break + } + if p.IsVideo { + self.videoCc++ + if self.videoCc > 0xf { + self.videoCc = 0 + } + } else { + self.audioCc++ + if self.audioCc > 0xf { + self.audioCc = 0 + } + } + + i := byte(0) + + //sync byte + self.tsPacket[i] = 0x47 + i++ + + //error indicator, unit start indicator,ts priority,pid + self.tsPacket[i] = byte(pid >> 8) //pid high 5 bits + if first { + self.tsPacket[i] = self.tsPacket[i] | 0x40 //unit start indicator + } + i++ + + //pid low 8 bits + self.tsPacket[i] = byte(pid) + i++ + + //scram control, adaptation control, counter + if p.IsVideo { + self.tsPacket[i] = 0x10 | byte(self.videoCc&0x0f) + } else { + self.tsPacket[i] = 0x10 | byte(self.audioCc&0x0f) + } + i++ + + //关键帧需要加pcr + if first && p.IsVideo && videoH.IsKeyFrame() { + self.tsPacket[3] |= 0x20 + self.tsPacket[i] = 7 + i++ + self.tsPacket[i] = 0x50 + i++ + self.writePcr(self.tsPacket[0:], i, dts) + i += 6 + } + + //frame data + if packetBytesLen >= tsDefaultDataLen { + dataLen = tsDefaultDataLen + if first { + dataLen -= (i - 4) + } + } else { + self.tsPacket[3] |= 0x20 //have adaptation + remainBytes := byte(0) + dataLen = byte(packetBytesLen) + if first { + remainBytes = tsDefaultDataLen - dataLen - (i - 4) + } else { + remainBytes = tsDefaultDataLen - dataLen + } + self.adaptationBufInit(self.tsPacket[i:], byte(remainBytes)) + i += remainBytes + } + if first && i < tsPacketLen && pesHeaderLen > 0 { + tmpLen = tsPacketLen - i + if pesHeaderLen <= tmpLen { + tmpLen = pesHeaderLen + } + copy(self.tsPacket[i:], pes.data[pesIndex:pesIndex+int(tmpLen)]) + i += tmpLen + packetBytesLen -= int(tmpLen) + dataLen -= tmpLen + pesHeaderLen -= tmpLen + pesIndex += int(tmpLen) + } + + if i < tsPacketLen { + tmpLen = tsPacketLen - i + if tmpLen <= dataLen { + dataLen = tmpLen + } + copy(self.tsPacket[i:], p.Data[wBytes:wBytes+int(dataLen)]) + wBytes += int(dataLen) + packetBytesLen -= int(dataLen) + } + if w != nil { + if _, err := w.Write(self.tsPacket[0:]); err != nil { + return err + } + } + first = false + } + + return nil +} + +//PAT return pat data +func (self *Muxer) PAT() []byte { + i := 0 + remainByte := 0 + tsHeader := []byte{0x47, 0x40, 0x00, 0x10, 0x00} + patHeader := []byte{0x00, 0xb0, 0x0d, 0x00, 0x01, 0xc1, 0x00, 0x00, 0x00, 0x01, 0xf0, 0x01} + + if self.patCc > 0xf { + self.patCc = 0 + } + tsHeader[3] |= self.patCc & 0x0f + self.patCc++ + + copy(self.pat[i:], tsHeader) + i += len(tsHeader) + + copy(self.pat[i:], patHeader) + i += len(patHeader) + + crc32Value := GenCrc32(patHeader) + self.pat[i] = byte(crc32Value >> 24) + i++ + self.pat[i] = byte(crc32Value >> 16) + i++ + self.pat[i] = byte(crc32Value >> 8) + i++ + self.pat[i] = byte(crc32Value) + i++ + + remainByte = int(tsPacketLen - i) + for j := 0; j < remainByte; j++ { + self.pat[i+j] = 0xff + } + + return self.pat[0:] +} + +// PMT return pmt data +func (self *Muxer) 
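`Mux` above assembles each 188-byte TS packet by hand; its first four bytes follow the standard ISO 13818-1 header layout. A compact sketch of just that header packing (the helper is illustrative):

```go
package main

import "fmt"

// tsHeader packs sync byte, payload_unit_start_indicator, PID,
// adaptation_field_control and continuity_counter the same way Mux() does.
func tsHeader(pid uint16, unitStart, hasAdaptation bool, cc byte) [4]byte {
	var h [4]byte
	h[0] = 0x47           // sync byte
	h[1] = byte(pid >> 8) // top 5 bits of the 13-bit PID
	if unitStart {
		h[1] |= 0x40 // payload_unit_start_indicator
	}
	h[2] = byte(pid)          // low 8 bits of the PID
	h[3] = 0x10 | (cc & 0x0f) // payload present + continuity counter
	if hasAdaptation {
		h[3] |= 0x20 // adaptation field present
	}
	return h
}

func main() {
	// First packet of a video PES (PID 0x100), continuity counter 1.
	fmt.Printf("% x\n", tsHeader(0x100, true, false, 1)) // 47 41 00 11
}
```

The unit test below expects the same pattern for the audio PID: `47 41 01 31`, with the adaptation-field bit also set.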
PMT(soundFormat byte, hasVideo bool) []byte { + i := int(0) + j := int(0) + var progInfo []byte + remainBytes := int(0) + tsHeader := []byte{0x47, 0x50, 0x01, 0x10, 0x00} + pmtHeader := []byte{0x02, 0xb0, 0xff, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x00} + if !hasVideo { + pmtHeader[9] = 0x01 + progInfo = []byte{0x0f, 0xe1, 0x01, 0xf0, 0x00} + } else { + progInfo = []byte{0x1b, 0xe1, 0x00, 0xf0, 0x00, //h264 or h265* + 0x0f, 0xe1, 0x01, 0xf0, 0x00, //mp3 or aac + } + } + pmtHeader[2] = byte(len(progInfo) + 9 + 4) + + if self.pmtCc > 0xf { + self.pmtCc = 0 + } + tsHeader[3] |= self.pmtCc & 0x0f + self.pmtCc++ + + if soundFormat == 2 || + soundFormat == 14 { + if hasVideo { + progInfo[5] = 0x4 + } else { + progInfo[0] = 0x4 + } + } + + copy(self.pmt[i:], tsHeader) + i += len(tsHeader) + + copy(self.pmt[i:], pmtHeader) + i += len(pmtHeader) + + copy(self.pmt[i:], progInfo[0:]) + i += len(progInfo) + + crc32Value := GenCrc32(self.pmt[5: 5+len(pmtHeader)+len(progInfo)]) + self.pmt[i] = byte(crc32Value >> 24) + i++ + self.pmt[i] = byte(crc32Value >> 16) + i++ + self.pmt[i] = byte(crc32Value >> 8) + i++ + self.pmt[i] = byte(crc32Value) + i++ + + remainBytes = int(tsPacketLen - i) + for j = 0; j < remainBytes; j++ { + self.pmt[i+j] = 0xff + } + + return self.pmt[0:] +} + +func (self *Muxer) adaptationBufInit(src []byte, remainBytes byte) { + src[0] = byte(remainBytes - 1) + if remainBytes == 1 { + } else { + src[1] = 0x00 + for i := 2; i < len(src); i++ { + src[i] = 0xff + } + } + return +} + +func (self *Muxer) writePcr(b []byte, i byte, pcr int64) error { + b[i] = byte(pcr >> 25) + i++ + b[i] = byte((pcr >> 17) & 0xff) + i++ + b[i] = byte((pcr >> 9) & 0xff) + i++ + b[i] = byte((pcr >> 1) & 0xff) + i++ + b[i] = byte(((pcr & 0x1) << 7) | 0x7e) + i++ + b[i] = 0x00 + + return nil +} + +type pesHeader struct { + len byte + data [tsPacketLen]byte +} + +//pesPacket return pes packet +func (self *pesHeader) packet(p *av.Packet, pts, dts int64) error { + //PES header + i := 0 + self.data[i] = 0x00 + i++ + self.data[i] = 0x00 + i++ + self.data[i] = 0x01 + i++ + + sid := audioSID + if p.IsVideo { + sid = videoSID + } + self.data[i] = byte(sid) + i++ + + flag := 0x80 + ptslen := 5 + dtslen := ptslen + headerSize := ptslen + if p.IsVideo && pts != dts { + flag |= 0x40 + headerSize += 5 //add dts + } + size := len(p.Data) + headerSize + 3 + if size > 0xffff { + size = 0 + } + self.data[i] = byte(size >> 8) + i++ + self.data[i] = byte(size) + i++ + + self.data[i] = 0x80 + i++ + self.data[i] = byte(flag) + i++ + self.data[i] = byte(headerSize) + i++ + + self.writeTs(self.data[0:], i, flag>>6, pts) + i += ptslen + if p.IsVideo && pts != dts { + self.writeTs(self.data[0:], i, 1, dts) + i += dtslen + } + + self.len = byte(i) + + return nil +} + +func (self *pesHeader) writeTs(src []byte, i int, fb int, ts int64) { + val := uint32(0) + if ts > 0x1ffffffff { + ts -= 0x1ffffffff + } + val = uint32(fb<<4) | ((uint32(ts>>30) & 0x07) << 1) | 1 + src[i] = byte(val) + i++ + + val = ((uint32(ts>>15) & 0x7fff) << 1) | 1 + src[i] = byte(val >> 8) + i++ + src[i] = byte(val) + i++ + + val = (uint32(ts&0x7fff) << 1) | 1 + src[i] = byte(val >> 8) + i++ + src[i] = byte(val) +} diff --git a/container/ts/muxer_test.go b/container/ts/muxer_test.go new file mode 100755 index 0000000..9d46f49 --- /dev/null +++ b/container/ts/muxer_test.go @@ -0,0 +1,51 @@ +package ts + +import ( + "testing" + "github.com/gwuhaolin/livego/av" + "github.com/stretchr/testify/assert" +) + +type TestWriter struct { + buf []byte + count int +} + 
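`writeTs` above packs a 33-bit PTS or DTS into five bytes: a 4-bit prefix (`0010`/`0011` for PTS, `0001` for DTS), then the timestamp split into 3 + 15 + 15 bits with a marker bit after each group. The same layout as a standalone sketch:

```go
package main

import "fmt"

// encodePesTs mirrors the bit layout used by writeTs: prefix<<4, bits 32..30,
// marker, bits 29..15, marker, bits 14..0, marker.
func encodePesTs(prefix byte, ts uint64) [5]byte {
	var b [5]byte
	b[0] = prefix<<4 | byte((ts>>30)&0x07)<<1 | 1
	mid := uint16((ts>>15)&0x7fff)<<1 | 1
	b[1], b[2] = byte(mid>>8), byte(mid)
	low := uint16(ts&0x7fff)<<1 | 1
	b[3], b[4] = byte(low>>8), byte(low)
	return b
}

func main() {
	// A 90 kHz PTS of 0 with the "PTS only" prefix gives 21 00 01 00 01,
	// the same bytes that appear in the muxer test's expected output below.
	fmt.Printf("% x\n", encodePesTs(2, 0))
	fmt.Printf("% x\n", encodePesTs(2, 90000)) // one second: 21 00 05 bf 21
}
```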
+//Write write p to w.buf +func (w *TestWriter) Write(p []byte) (int, error) { + w.count++ + w.buf = p + return len(p), nil +} + +func TestTSEncoder(t *testing.T) { + at := assert.New(t) + m := NewMuxer() + + w := &TestWriter{} + data := []byte{0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d, 0x0b, 0x6d, 0x44, 0xae, 0x81, + 0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04, 0x80, 0x00, 0x5b, 0xb7, + 0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00, 0x06, 0x00, 0x38, + } + p := av.Packet{ + IsVideo: false, + Data: data, + } + err := m.Mux(&p, w) + at.Equal(err, nil) + at.Equal(w.count, 1) + at.Equal(w.buf, []byte{0x47, 0x41, 0x01, 0x31, 0x81, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, 0xc0, 0x00, 0x30, + 0x80, 0x80, 0x05, 0x21, 0x00, 0x01, 0x00, 0x01, 0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d, + 0x0b, 0x6d, 0x44, 0xae, 0x81, 0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04, + 0x80, 0x00, 0x5b, 0xb7, 0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00, + 0x06, 0x00, 0x38}) +} diff --git a/main.go b/main.go new file mode 100755 index 0000000..8839ef5 --- /dev/null +++ b/main.go @@ -0,0 +1,172 @@ +package main + +import ( + "flag" + "net" + "os" + "os/signal" + "syscall" + "time" + "github.com/gwuhaolin/livego/protocol/rtmp" + "github.com/gwuhaolin/livego/protocol/hls" + "github.com/gwuhaolin/livego/protocol/httpflv" + "github.com/gwuhaolin/livego/protocol/httpopera" + "path/filepath" + "strings" + "io/ioutil" + "strconv" + "log" +) + +var ( + rtmpAddr = flag.String("rtmpAddr", ":1935", "The rtmp server address to bind.") + flvAddr = flag.String("flvAddr", ":8081", "the http-flv server address to bind.") + hlsAddr = flag.String("hlsAddr", ":8080", "the hls server address to bind.") + operaAddr = flag.String("operaAddr", "", "the http operation or config address to bind: 8082.") + CurDir string // save pid +) + +func getParentDirectory(dirctory string) string { + return substr(dirctory, 0, strings.LastIndex(dirctory, "/")) +} + +func getCurrentDirectory() string { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + log.Fatal(err) + } + return strings.Replace(dir, "\\", "/", -1) +} + +func substr(s string, pos, length int) string { + runes := []rune(s) + l := pos + length + if l > len(runes) { + l = len(runes) + } + return string(runes[pos:l]) +} + +func SavePid() error { + pidFilename := CurDir + "/pid/" + filepath.Base(os.Args[0]) + ".pid" + pid := os.Getpid() + return ioutil.WriteFile(pidFilename, []byte(strconv.Itoa(pid)), 0755) +} + +func init() { + CurDir = getParentDirectory(getCurrentDirectory()) + flag.Usage = func() { + flag.PrintDefaults() + } + flag.Parse() +} + +func catchSignal() { + sig := make(chan os.Signal) + signal.Notify(sig, syscall.SIGSTOP, syscall.SIGTERM) + 
<-sig + log.Println("recieved signal!") +} + +func startHls() *hls.Server { + hlsListen, err := net.Listen("tcp", *hlsAddr) + if err != nil { + log.Fatal(err) + } + + hlsServer := hls.NewServer() + go func() { + defer func() { + if r := recover(); r != nil { + log.Println("hls server panic: ", r) + } + }() + hlsServer.Serve(hlsListen) + }() + return hlsServer +} + +func startRtmp(stream *rtmp.RtmpStream, hlsServer *hls.Server) { + rtmplisten, err := net.Listen("tcp", *rtmpAddr) + if err != nil { + log.Fatal(err) + } + + rtmpServer := rtmp.NewRtmpServer(stream, hlsServer) + go func() { + defer func() { + if r := recover(); r != nil { + log.Println("hls server panic: ", r) + } + }() + rtmpServer.Serve(rtmplisten) + }() +} + +func startHTTPFlv(stream *rtmp.RtmpStream) { + flvListen, err := net.Listen("tcp", *flvAddr) + if err != nil { + log.Fatal(err) + } + + hdlServer := httpflv.NewServer(stream) + go func() { + defer func() { + if r := recover(); r != nil { + log.Println("hls server panic: ", r) + } + }() + hdlServer.Serve(flvListen) + }() +} + +func startHTTPOpera(stream *rtmp.RtmpStream) { + if *operaAddr != "" { + opListen, err := net.Listen("tcp", *operaAddr) + if err != nil { + log.Fatal(err) + } + opServer := httpopera.NewServer(stream) + go func() { + defer func() { + if r := recover(); r != nil { + log.Println("hls server panic: ", r) + } + }() + opServer.Serve(opListen) + }() + } +} + +func startLog() { + log.Println("RTMP Listen On", *rtmpAddr) + log.Println("HLS Listen On", *hlsAddr) + log.Println("HTTP-FLV Listen On", *flvAddr) + if *operaAddr != "" { + log.Println("HTTP-Operation Listen On", *operaAddr) + } + SavePid() +} + +func main() { + defer func() { + if r := recover(); r != nil { + log.Println("main panic: ", r) + time.Sleep(1 * time.Second) + } + }() + + stream := rtmp.NewRtmpStream() + // hls + h := startHls() + // rtmp + startRtmp(stream, h) + // http-flv + startHTTPFlv(stream) + // http-opera + startHTTPOpera(stream) + // my log + startLog() + // block + catchSignal() +} diff --git a/parser/aac/parser.go b/parser/aac/parser.go new file mode 100755 index 0000000..22fdf8a --- /dev/null +++ b/parser/aac/parser.go @@ -0,0 +1,113 @@ +package aac + +import ( + "errors" + "io" + "github.com/gwuhaolin/livego/av" +) + +type mpegExtension struct { + objectType byte + sampleRate byte +} + +type mpegCfgInfo struct { + objectType byte + sampleRate byte + channel byte + sbr byte + ps byte + frameLen byte + exceptionLogTs int64 + extension *mpegExtension +} + +var aacRates = []int{96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350} + +var ( + specificBufInvalid = errors.New("audio mpegspecific error") + audioBufInvalid = errors.New("audiodata invalid") +) + +const ( + adtsHeaderLen = 7 +) + +type Parser struct { + gettedSpecific bool + adtsHeader []byte + cfgInfo *mpegCfgInfo +} + +func NewParser() *Parser { + return &Parser{ + gettedSpecific: false, + cfgInfo: &mpegCfgInfo{}, + adtsHeader: make([]byte, adtsHeaderLen), + } +} + +func (self *Parser) specificInfo(src []byte) error { + if len(src) < 2 { + return specificBufInvalid + } + self.gettedSpecific = true + self.cfgInfo.objectType = (src[0] >> 3) & 0xff + self.cfgInfo.sampleRate = ((src[0] & 0x07) << 1) | src[1]>>7 + self.cfgInfo.channel = (src[1] >> 3) & 0x0f + return nil +} + +func (self *Parser) adts(src []byte, w io.Writer) error { + if len(src) <= 0 || !self.gettedSpecific { + return audioBufInvalid + } + + frameLen := uint16(len(src)) + 7 + + //first write adts header + 
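`specificInfo` above reads the two-byte AAC AudioSpecificConfig carried in the FLV audio sequence-header packet: 5 bits audio object type, 4 bits sampling-frequency index, 4 bits channel configuration. A tiny sketch of the same bit extraction on an assumed typical config (`0x12 0x10`, i.e. AAC-LC, 44.1 kHz, stereo):

```go
package main

import "fmt"

var aacRates = []int{96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350}

func main() {
	cfg := []byte{0x12, 0x10}

	objectType := cfg[0] >> 3                  // 5 bits: audio object type
	rateIndex := (cfg[0]&0x07)<<1 | cfg[1]>>7  // 4 bits: sampling-frequency index
	channels := (cfg[1] >> 3) & 0x0f           // 4 bits: channel configuration

	fmt.Println(objectType, aacRates[rateIndex], channels) // 2 44100 2
}
```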
self.adtsHeader[0] = 0xff + self.adtsHeader[1] = 0xf1 + + self.adtsHeader[2] &= 0x00 + self.adtsHeader[2] = self.adtsHeader[2] | (self.cfgInfo.objectType-1)<<6 + self.adtsHeader[2] = self.adtsHeader[2] | (self.cfgInfo.sampleRate)<<2 + + self.adtsHeader[3] &= 0x00 + self.adtsHeader[3] = self.adtsHeader[3] | (self.cfgInfo.channel<<2)<<4 + self.adtsHeader[3] = self.adtsHeader[3] | byte((frameLen<<3)>>14) + + self.adtsHeader[4] &= 0x00 + self.adtsHeader[4] = self.adtsHeader[4] | byte((frameLen<<5)>>8) + + self.adtsHeader[5] &= 0x00 + self.adtsHeader[5] = self.adtsHeader[5] | byte(((frameLen<<13)>>13)<<5) + self.adtsHeader[5] = self.adtsHeader[5] | (0x7C<<1)>>3 + self.adtsHeader[6] = 0xfc + + if _, err := w.Write(self.adtsHeader[0:]); err != nil { + return err + } + if _, err := w.Write(src); err != nil { + return err + } + return nil +} + +func (self *Parser) SampleRate() int { + rate := 44100 + if self.cfgInfo.sampleRate <= byte(len(aacRates)-1) { + rate = aacRates[self.cfgInfo.sampleRate] + } + return rate +} + +func (self *Parser) Parse(b []byte, packetType uint8, w io.Writer) (err error) { + switch packetType { + case av.AAC_SEQHDR: + err = self.specificInfo(b) + case av.AAC_RAW: + err = self.adts(b, w) + } + return +} diff --git a/parser/h264/parser.go b/parser/h264/parser.go new file mode 100755 index 0000000..d659f8a --- /dev/null +++ b/parser/h264/parser.go @@ -0,0 +1,232 @@ +package h264 + +import ( + "bytes" + "errors" + "io" +) + +const ( + i_frame byte = 0 + p_frame byte = 1 + b_frame byte = 2 +) + +const ( + nalu_type_not_define byte = 0 + nalu_type_slice byte = 1 //slice_layer_without_partioning_rbsp() sliceheader + nalu_type_dpa byte = 2 // slice_data_partition_a_layer_rbsp( ), slice_header + nalu_type_dpb byte = 3 // slice_data_partition_b_layer_rbsp( ) + nalu_type_dpc byte = 4 // slice_data_partition_c_layer_rbsp( ) + nalu_type_idr byte = 5 // slice_layer_without_partitioning_rbsp( ),sliceheader + nalu_type_sei byte = 6 //sei_rbsp( ) + nalu_type_sps byte = 7 //seq_parameter_set_rbsp( ) + nalu_type_pps byte = 8 //pic_parameter_set_rbsp( ) + nalu_type_aud byte = 9 // access_unit_delimiter_rbsp( ) + nalu_type_eoesq byte = 10 //end_of_seq_rbsp( ) + nalu_type_eostream byte = 11 //end_of_stream_rbsp( ) + nalu_type_filler byte = 12 //filler_data_rbsp( ) +) + +const ( + naluBytesLen int = 4 + maxSpsPpsLen int = 2 * 1024 +) + +var ( + decDataNil = errors.New("dec buf is nil") + spsDataError = errors.New("sps data error") + ppsHeaderError = errors.New("pps header error") + ppsDataError = errors.New("pps data error") + naluHeaderInvalid = errors.New("nalu header invalid") + videoDataInvalid = errors.New("video data not match") + dataSizeNotMatch = errors.New("data size not match") + naluBodyLenError = errors.New("nalu body len error") +) + +var startCode = []byte{0x00, 0x00, 0x00, 0x01} +var naluAud = []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0} + +type Parser struct { + frameType byte + specificInfo []byte + pps *bytes.Buffer +} + +type sequenceHeader struct { + configVersion byte //8bits + avcProfileIndication byte //8bits + profileCompatility byte //8bits + avcLevelIndication byte //8bits + reserved1 byte //6bits + naluLen byte //2bits + reserved2 byte //3bits + spsNum byte //5bits + ppsNum byte //8bits + spsLen int + ppsLen int +} + +func NewParser() *Parser { + return &Parser{ + pps: bytes.NewBuffer(make([]byte, maxSpsPpsLen)), + } +} + +//return value 1:sps, value2 :pps +func (self *Parser) parseSpecificInfo(src []byte) error { + if len(src) < 9 { + return decDataNil + } + sps := 
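The `adts` method in `parser/aac` above prepends a 7-byte ADTS header to every raw AAC frame so the resulting stream is self-describing. A sketch of that header built from the standard field layout rather than copied from the code (profile = object type - 1, a 13-bit frame length that includes the 7 header bytes, buffer fullness fixed at 0x7FF):

```go
package main

import "fmt"

func adtsHeader(objectType, rateIndex, channels byte, payloadLen int) [7]byte {
	frameLen := payloadLen + 7 // the header is counted in the frame length

	var h [7]byte
	h[0] = 0xff // syncword 0xFFF…
	h[1] = 0xf1 // …MPEG-4, layer 0, no CRC
	h[2] = (objectType-1)<<6 | rateIndex<<2 | channels>>2
	h[3] = (channels&0x03)<<6 | byte(frameLen>>11)
	h[4] = byte(frameLen >> 3)
	h[5] = byte(frameLen&0x07)<<5 | 0x1f // buffer fullness = 0x7FF (VBR)
	h[6] = 0xfc
	return h
}

func main() {
	h := adtsHeader(2, 4, 2, 371) // AAC-LC, 44.1 kHz, stereo, 371-byte raw frame
	fmt.Printf("% x\n", h)
}
```

The code above packs the same fields with slightly different shift arithmetic; for the common AAC-LC stereo case the resulting bytes agree.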
[]byte{} + pps := []byte{} + + var seq sequenceHeader + seq.configVersion = src[0] + seq.avcProfileIndication = src[1] + seq.profileCompatility = src[2] + seq.avcLevelIndication = src[3] + seq.reserved1 = src[4] & 0xfc + seq.naluLen = src[4]&0x03 + 1 + seq.reserved2 = src[5] >> 5 + + //get sps + seq.spsNum = src[5] & 0x1f + seq.spsLen = int(src[6])<<8 | int(src[7]) + + if len(src[8:]) < seq.spsLen || seq.spsLen <= 0 { + return spsDataError + } + sps = append(sps, startCode...) + sps = append(sps, src[8:(8 + seq.spsLen)]...) + + //get pps + tmpBuf := src[(8 + seq.spsLen):] + if len(tmpBuf) < 4 { + return ppsHeaderError + } + seq.ppsNum = tmpBuf[0] + seq.ppsLen = int(0)<<16 | int(tmpBuf[1])<<8 | int(tmpBuf[2]) + if len(tmpBuf[3:]) < seq.ppsLen || seq.ppsLen <= 0 { + return ppsDataError + } + + pps = append(pps, startCode...) + pps = append(pps, tmpBuf[3:]...) + + self.specificInfo = append(self.specificInfo, sps...) + self.specificInfo = append(self.specificInfo, pps...) + + return nil +} + +func (self *Parser) isNaluHeader(src []byte) bool { + if len(src) < naluBytesLen { + return false + } + return src[0] == 0x00 && + src[1] == 0x00 && + src[2] == 0x00 && + src[3] == 0x01 +} + +func (self *Parser) naluSize(src []byte) (int, error) { + if len(src) < naluBytesLen { + return 0, errors.New("nalusizedata invalid") + } + buf := src[:naluBytesLen] + size := int(0) + for i := 0; i < len(buf); i++ { + size = size<<8 + int(buf[i]) + } + return size, nil +} + +func (self *Parser) getAnnexbH264(src []byte, w io.Writer) error { + dataSize := len(src) + if dataSize < naluBytesLen { + return videoDataInvalid + } + self.pps.Reset() + _, err := w.Write(naluAud) + if err != nil { + return err + } + + index := 0 + nalLen := 0 + hasSpsPps := false + hasWriteSpsPps := false + + for dataSize > 0 { + nalLen, err = self.naluSize(src[index:]) + if err != nil { + return dataSizeNotMatch + } + index += naluBytesLen + dataSize -= naluBytesLen + if dataSize >= nalLen && len(src[index:]) >= nalLen && nalLen > 0 { + nalType := src[index] & 0x1f + switch nalType { + case nalu_type_aud: + case nalu_type_idr: + if !hasWriteSpsPps { + hasWriteSpsPps = true + if !hasSpsPps { + if _, err := w.Write(self.specificInfo); err != nil { + return err + } + } else { + if _, err := w.Write(self.pps.Bytes()); err != nil { + return err + } + } + } + fallthrough + case nalu_type_slice: + fallthrough + case nalu_type_sei: + _, err := w.Write(startCode) + if err != nil { + return err + } + _, err = w.Write(src[index: index+nalLen]) + if err != nil { + return err + } + case nalu_type_sps: + fallthrough + case nalu_type_pps: + hasSpsPps = true + _, err := self.pps.Write(startCode) + if err != nil { + return err + } + _, err = self.pps.Write(src[index: index+nalLen]) + if err != nil { + return err + } + } + index += nalLen + dataSize -= nalLen + } else { + return naluBodyLenError + } + } + return nil +} + +func (self *Parser) Parse(b []byte, isSeq bool, w io.Writer) (err error) { + switch isSeq { + case true: + err = self.parseSpecificInfo(b) + case false: + // is annexb + if self.isNaluHeader(b) { + _, err = w.Write(b) + } else { + err = self.getAnnexbH264(b, w) + } + } + return +} diff --git a/parser/h264/parser_test.go b/parser/h264/parser_test.go new file mode 100755 index 0000000..4773168 --- /dev/null +++ b/parser/h264/parser_test.go @@ -0,0 +1,91 @@ +package h264 + +import ( + "bytes" + "errors" + "testing" + "github.com/stretchr/testify/assert" +) + +func TestH264SeqDemux(t *testing.T) { + at := assert.New(t) + seq := []byte{ + 0x01, 
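`getAnnexbH264` above rewrites length-prefixed (AVCC/MP4-style) NALUs into Annex-B form: each 4-byte big-endian length is replaced by the `00 00 00 01` start code, with an AUD written up front and cached SPS/PPS injected before IDR frames. A bare-bones sketch of just the length-to-start-code rewrite, leaving the AUD and SPS/PPS handling out:

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

var startCode = []byte{0x00, 0x00, 0x00, 0x01}

func avccToAnnexB(src []byte) ([]byte, error) {
	out := make([]byte, 0, len(src)+4)
	for len(src) >= 4 {
		n := int(binary.BigEndian.Uint32(src[:4]))
		src = src[4:]
		if n <= 0 || n > len(src) {
			return nil, errors.New("nalu body len error")
		}
		out = append(out, startCode...)
		out = append(out, src[:n]...)
		src = src[n:]
	}
	if len(src) != 0 {
		return nil, errors.New("data size not match")
	}
	return out, nil
}

func main() {
	// One 2-byte NALU (0x65 = IDR slice), length-prefixed.
	b, err := avccToAnnexB([]byte{0x00, 0x00, 0x00, 0x02, 0x65, 0x23})
	fmt.Printf("%v % x\n", err, b) // <nil> 00 00 00 01 65 23
}
```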
0x4d, 0x00, 0x1e, 0xff, 0xe1, 0x00, 0x17, 0x67, 0x4d, 0x00, + 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, 0x28, 0x2f, + 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x01, 0x00, + 0x04, 0x68, 0xde, 0x31, 0x12, + } + d := NewParser() + w := bytes.NewBuffer(nil) + err := d.Parse(seq, true, w) + at.Equal(err, nil) + at.Equal(d.specificInfo, []byte{0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00, + 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, 0x28, 0x2f, + 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68, 0xde, 0x31, 0x12}) +} + +func TestH264AnnexbDemux(t *testing.T) { + at := assert.New(t) + nalu := []byte{ + 0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, + 0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68, + 0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x01, 0x65, 0x23, + } + d := NewParser() + w := bytes.NewBuffer(nil) + err := d.Parse(nalu, false, w) + at.Equal(err, nil) + at.Equal(w.Len(), 41) +} + +func TestH264NalueSizeException(t *testing.T) { + at := assert.New(t) + nalu := []byte{ + 0x00, 0x00, 0x10, + } + d := NewParser() + w := bytes.NewBuffer(nil) + err := d.Parse(nalu, false, w) + at.Equal(err, errors.New("video data not match")) +} + +func TestH264Mp4Demux(t *testing.T) { + at := assert.New(t) + nalu := []byte{ + 0x00, 0x00, 0x00, 0x17, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, + 0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x04, 0x68, + 0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x02, 0x65, 0x23, + } + d := NewParser() + w := bytes.NewBuffer(nil) + err := d.Parse(nalu, false, w) + at.Equal(err, nil) + at.Equal(w.Len(), 47) + at.Equal(w.Bytes(), []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0, 0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, + 0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68, + 0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x01, 0x65, 0x23}) +} + +func TestH264Mp4DemuxException1(t *testing.T) { + at := assert.New(t) + nalu := []byte{ + 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, + } + d := NewParser() + w := bytes.NewBuffer(nil) + + err := d.Parse(nalu, false, w) + at.Equal(err, naluBodyLenError) +} + +func TestH264Mp4DemuxException2(t *testing.T) { + at := assert.New(t) + nalu := []byte{ + 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x17, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, + 0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, + } + d := NewParser() + w := bytes.NewBuffer(nil) + err := d.Parse(nalu, false, w) + at.Equal(err, naluBodyLenError) +} diff --git a/parser/mp3/parser.go b/parser/mp3/parser.go new file mode 100755 index 0000000..d27dc91 --- /dev/null +++ b/parser/mp3/parser.go @@ -0,0 +1,41 @@ +package mp3 + +import "errors" + +type Parser struct { + samplingFrequency int +} + +func NewParser() *Parser { + return &Parser{} +} + +// sampling_frequency - indicates the sampling frequency, according to the following table. 
+// '00' 44.1 kHz +// '01' 48 kHz +// '10' 32 kHz +// '11' reserved +var mp3Rates = []int{44100, 48000, 32000} +var ( + errMp3DataInvalid = errors.New("mp3data invalid") + errIndexInvalid = errors.New("invalid rate index") +) + +func (self *Parser) Parse(src []byte) error { + if len(src) < 3 { + return errMp3DataInvalid + } + index := (src[2] >> 2) & 0x3 + if index <= byte(len(mp3Rates)-1) { + self.samplingFrequency = mp3Rates[index] + return nil + } + return errIndexInvalid +} + +func (self *Parser) SampleRate() int { + if self.samplingFrequency == 0 { + self.samplingFrequency = 44100 + } + return self.samplingFrequency +} diff --git a/parser/parser.go b/parser/parser.go new file mode 100755 index 0000000..42fc602 --- /dev/null +++ b/parser/parser.go @@ -0,0 +1,68 @@ +package parser + +import ( + "errors" + "io" + "github.com/gwuhaolin/livego/parser/mp3" + "github.com/gwuhaolin/livego/parser/aac" + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/parser/h264" +) + +var ( + errNoAudio = errors.New("demuxer no audio") +) + +type CodecParser struct { + aac *aac.Parser + mp3 *mp3.Parser + h264 *h264.Parser +} + +func NewCodecParser() *CodecParser { + return &CodecParser{} +} + +func (self *CodecParser) SampleRate() (int, error) { + if self.aac == nil && self.mp3 == nil { + return 0, errNoAudio + } + if self.aac != nil { + return self.aac.SampleRate(), nil + } + return self.mp3.SampleRate(), nil +} + +func (self *CodecParser) Parse(p *av.Packet, w io.Writer) (err error) { + + switch p.IsVideo { + case true: + f, ok := p.Header.(av.VideoPacketHeader) + if ok { + if f.CodecID() == av.VIDEO_H264 { + if self.h264 == nil { + self.h264 = h264.NewParser() + } + err = self.h264.Parse(p.Data, f.IsSeq(), w) + } + } + case false: + f, ok := p.Header.(av.AudioPacketHeader) + if ok { + switch f.SoundFormat() { + case av.SOUND_AAC: + if self.aac == nil { + self.aac = aac.NewParser() + } + err = self.aac.Parse(p.Data, f.AACPacketType(), w) + case av.SOUND_MP3: + if self.mp3 == nil { + self.mp3 = mp3.NewParser() + } + err = self.mp3.Parse(p.Data) + } + } + + } + return +} diff --git a/protocol/amf/amf.go b/protocol/amf/amf.go new file mode 100755 index 0000000..e17fd47 --- /dev/null +++ b/protocol/amf/amf.go @@ -0,0 +1,50 @@ +package amf + +import ( + "errors" + "fmt" + "io" +) + +func (d *Decoder) DecodeBatch(r io.Reader, ver Version) (ret []interface{}, err error) { + var v interface{} + for { + v, err = d.Decode(r, ver) + if err != nil { + break + } + ret = append(ret, v) + } + return +} + +func (d *Decoder) Decode(r io.Reader, ver Version) (interface{}, error) { + switch ver { + case 0: + return d.DecodeAmf0(r) + case 3: + return d.DecodeAmf3(r) + } + + return nil, errors.New(fmt.Sprintf("decode amf: unsupported version %d", ver)) +} + +func (e *Encoder) EncodeBatch(w io.Writer, ver Version, val ...interface{}) (int, error) { + for _, v := range val { + if _, err := e.Encode(w, v, ver); err != nil { + return 0, err + } + } + return 0, nil +} + +func (e *Encoder) Encode(w io.Writer, val interface{}, ver Version) (int, error) { + switch ver { + case AMF0: + return e.EncodeAmf0(w, val) + case AMF3: + return e.EncodeAmf3(w, val) + } + + return 0, Error("encode amf: unsupported version %d", ver) +} diff --git a/protocol/amf/amf_test.go b/protocol/amf/amf_test.go new file mode 100755 index 0000000..9796d67 --- /dev/null +++ b/protocol/amf/amf_test.go @@ -0,0 +1,206 @@ +package amf + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + "time" +) + +func EncodeAndDecode(val interface{}, 
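Tying the demux side together: `flv.Demuxer` fills in the packet header and strips the FLV tag bytes, and `parser.CodecParser` above then lazily picks the matching codec parser (H.264, AAC or MP3) from that header and writes elementary-stream bytes into an `io.Writer`. The wiring below is a sketch of mine, but the types and calls are the ones defined in this commit; the payloads are hand-made (an AAC sequence header followed by one raw frame):

```go
package main

import (
	"bytes"
	"log"

	"github.com/gwuhaolin/livego/av"
	"github.com/gwuhaolin/livego/container/flv"
	"github.com/gwuhaolin/livego/parser"
)

func main() {
	demux := flv.NewDemuxer()
	codec := parser.NewCodecParser()
	var es bytes.Buffer

	// Packet 1: AAC sequence header (0xaf = AAC/44kHz/16bit/stereo, 0x00 = seq
	// header, then two AudioSpecificConfig bytes). Packet 2: one raw AAC frame.
	packets := []*av.Packet{
		{IsAudio: true, Data: []byte{0xaf, 0x00, 0x12, 0x10}},
		{IsAudio: true, Data: []byte{0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40}},
	}

	for _, p := range packets {
		if err := demux.Demux(p); err != nil {
			log.Fatal(err)
		}
		if err := codec.Parse(p, &es); err != nil {
			log.Fatal(err)
		}
	}
	log.Printf("ADTS stream: % x", es.Bytes()) // 7-byte ADTS header + the raw frame
}
```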
ver Version) (result interface{}, err error) { + enc := new(Encoder) + dec := new(Decoder) + + buf := new(bytes.Buffer) + + _, err = enc.Encode(buf, val, ver) + if err != nil { + return nil, errors.New(fmt.Sprintf("error in encode: %s", err)) + } + + result, err = dec.Decode(buf, ver) + if err != nil { + return nil, errors.New(fmt.Sprintf("error in decode: %s", err)) + } + + return +} + +func Compare(val interface{}, ver Version, name string, t *testing.T) { + result, err := EncodeAndDecode(val, ver) + if err != nil { + t.Errorf("%s: %s", name, err) + } + + if !reflect.DeepEqual(val, result) { + val_v := reflect.ValueOf(val) + result_v := reflect.ValueOf(result) + + t.Errorf("%s: comparison failed between %+v (%s) and %+v (%s)", name, val, val_v.Type(), result, result_v.Type()) + + Dump("expected", val) + Dump("got", result) + } + + // if val != result { + // t.Errorf("%s: comparison failed between %+v and %+v", name, val, result) + // } +} + +func TestAmf0Number(t *testing.T) { + Compare(float64(3.14159), 0, "amf0 number float", t) + Compare(float64(124567890), 0, "amf0 number high", t) + Compare(float64(-34.2), 0, "amf0 number negative", t) +} + +func TestAmf0String(t *testing.T) { + Compare("a pup!", 0, "amf0 string simple", t) + Compare("日本語", 0, "amf0 string utf8", t) +} + +func TestAmf0Boolean(t *testing.T) { + Compare(true, 0, "amf0 boolean true", t) + Compare(false, 0, "amf0 boolean false", t) +} + +func TestAmf0Null(t *testing.T) { + Compare(nil, 0, "amf0 boolean nil", t) +} + +func TestAmf0Object(t *testing.T) { + obj := make(Object) + obj["dog"] = "alfie" + obj["coffee"] = true + obj["drugs"] = false + obj["pi"] = 3.14159 + + res, err := EncodeAndDecode(obj, 0) + if err != nil { + t.Errorf("amf0 object: %s", err) + } + + result, ok := res.(Object) + if ok != true { + t.Errorf("amf0 object conversion failed") + } + + if result["dog"] != "alfie" { + t.Errorf("amf0 object string: comparison failed") + } + + if result["coffee"] != true { + t.Errorf("amf0 object true: comparison failed") + } + + if result["drugs"] != false { + t.Errorf("amf0 object false: comparison failed") + } + + if result["pi"] != float64(3.14159) { + t.Errorf("amf0 object float: comparison failed") + } +} + +func TestAmf0Array(t *testing.T) { + arr := [5]float64{1, 2, 3, 4, 5} + + res, err := EncodeAndDecode(arr, 0) + if err != nil { + t.Error("amf0 object: %s", err) + } + + result, ok := res.(Array) + if ok != true { + t.Errorf("amf0 array conversion failed") + } + + for i := 0; i < len(arr); i++ { + if arr[i] != result[i] { + t.Errorf("amf0 array %d comparison failed: %v / %v", i, arr[i], result[i]) + } + } +} + +func TestAmf3Integer(t *testing.T) { + Compare(int32(0), 3, "amf3 integer zero", t) + Compare(int32(1245), 3, "amf3 integer low", t) + Compare(int32(123456), 3, "amf3 integer high", t) +} + +func TestAmf3Double(t *testing.T) { + Compare(float64(3.14159), 3, "amf3 double float", t) + Compare(float64(1234567890), 3, "amf3 double high", t) + Compare(float64(-12345), 3, "amf3 double negative", t) +} + +func TestAmf3String(t *testing.T) { + Compare("a pup!", 0, "amf0 string simple", t) + Compare("日本語", 0, "amf0 string utf8", t) +} + +func TestAmf3Boolean(t *testing.T) { + Compare(true, 3, "amf3 boolean true", t) + Compare(false, 3, "amf3 boolean false", t) +} + +func TestAmf3Null(t *testing.T) { + Compare(nil, 3, "amf3 boolean nil", t) +} + +func TestAmf3Date(t *testing.T) { + t1 := time.Unix(time.Now().Unix(), 0).UTC() // nanoseconds discarded + t2 := time.Date(1983, 9, 4, 12, 4, 8, 0, time.UTC) + + 
Compare(t1, 3, "amf3 date now", t) + Compare(t2, 3, "amf3 date earlier", t) +} + +func TestAmf3Array(t *testing.T) { + obj := make(Object) + obj["key"] = "val" + + var arr Array + arr = append(arr, "amf") + arr = append(arr, float64(2)) + arr = append(arr, -34.95) + arr = append(arr, true) + arr = append(arr, false) + + res, err := EncodeAndDecode(arr, 3) + if err != nil { + t.Error("amf3 object: %s", err) + } + + result, ok := res.(Array) + if ok != true { + t.Errorf("amf3 array conversion failed: %+v", res) + } + + for i := 0; i < len(arr); i++ { + if arr[i] != result[i] { + t.Errorf("amf3 array %d comparison failed: %v / %v", i, arr[i], result[i]) + } + } +} + +func TestAmf3ByteArray(t *testing.T) { + enc := new(Encoder) + dec := new(Decoder) + + buf := new(bytes.Buffer) + + expect := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00} + + enc.EncodeAmf3ByteArray(buf, expect, true) + + result, err := dec.DecodeAmf3ByteArray(buf, true) + if err != nil { + t.Errorf("err: %s", err) + } + + if bytes.Compare(result, expect) != 0 { + t.Errorf("expected: %+v, got %+v", expect, buf) + } +} diff --git a/protocol/amf/const.go b/protocol/amf/const.go new file mode 100755 index 0000000..02321ef --- /dev/null +++ b/protocol/amf/const.go @@ -0,0 +1,105 @@ +package amf + +import ( + "io" +) + +const ( + AMF0 = 0x00 + AMF3 = 0x03 +) + +const ( + AMF0_NUMBER_MARKER = 0x00 + AMF0_BOOLEAN_MARKER = 0x01 + AMF0_STRING_MARKER = 0x02 + AMF0_OBJECT_MARKER = 0x03 + AMF0_MOVIECLIP_MARKER = 0x04 + AMF0_NULL_MARKER = 0x05 + AMF0_UNDEFINED_MARKER = 0x06 + AMF0_REFERENCE_MARKER = 0x07 + AMF0_ECMA_ARRAY_MARKER = 0x08 + AMF0_OBJECT_END_MARKER = 0x09 + AMF0_STRICT_ARRAY_MARKER = 0x0a + AMF0_DATE_MARKER = 0x0b + AMF0_LONG_STRING_MARKER = 0x0c + AMF0_UNSUPPORTED_MARKER = 0x0d + AMF0_RECORDSET_MARKER = 0x0e + AMF0_XML_DOCUMENT_MARKER = 0x0f + AMF0_TYPED_OBJECT_MARKER = 0x10 + AMF0_ACMPLUS_OBJECT_MARKER = 0x11 +) + +const ( + AMF0_BOOLEAN_FALSE = 0x00 + AMF0_BOOLEAN_TRUE = 0x01 + AMF0_STRING_MAX = 65535 + AMF3_INTEGER_MAX = 536870911 +) + +const ( + AMF3_UNDEFINED_MARKER = 0x00 + AMF3_NULL_MARKER = 0x01 + AMF3_FALSE_MARKER = 0x02 + AMF3_TRUE_MARKER = 0x03 + AMF3_INTEGER_MARKER = 0x04 + AMF3_DOUBLE_MARKER = 0x05 + AMF3_STRING_MARKER = 0x06 + AMF3_XMLDOC_MARKER = 0x07 + AMF3_DATE_MARKER = 0x08 + AMF3_ARRAY_MARKER = 0x09 + AMF3_OBJECT_MARKER = 0x0a + AMF3_XMLSTRING_MARKER = 0x0b + AMF3_BYTEARRAY_MARKER = 0x0c +) + +type ExternalHandler func(*Decoder, io.Reader) (interface{}, error) + +type Decoder struct { + refCache []interface{} + stringRefs []string + objectRefs []interface{} + traitRefs []Trait + externalHandlers map[string]ExternalHandler +} + +func NewDecoder() *Decoder { + return &Decoder{ + externalHandlers: make(map[string]ExternalHandler), + } +} + +func (d *Decoder) RegisterExternalHandler(name string, f ExternalHandler) { + d.externalHandlers[name] = f +} + +type Encoder struct { +} + +type Version uint8 + +type Array []interface{} +type Object map[string]interface{} + +type TypedObject struct { + Type string + Object Object +} + +type Trait struct { + Type string + Externalizable bool + Dynamic bool + Properties []string +} + +func NewTrait() *Trait { + return &Trait{} +} + +func NewTypedObject() *TypedObject { + return &TypedObject{ + Type: "", + Object: make(Object), + } +} diff --git a/protocol/amf/decoder_amf0.go b/protocol/amf/decoder_amf0.go new file mode 100755 index 0000000..2de5e21 --- /dev/null +++ b/protocol/amf/decoder_amf0.go @@ -0,0 +1,335 @@ +package amf + +import ( + "encoding/binary" + "io" +) + +// amf0 
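To make the AMF0 markers above concrete, here is a round trip through the `Encoder`/`Decoder` defined in this package: strings are marker `0x02` plus a 16-bit big-endian length, and objects are marker `0x03` followed by key/value pairs, terminated by an empty key and the `0x09` object-end marker. The byte dump in the comment is the form the decoder tests below feed in; the encoder is assumed to emit the same encoding:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/gwuhaolin/livego/protocol/amf"
)

func main() {
	obj := amf.Object{"foo": "bar"}

	buf := new(bytes.Buffer)
	enc := new(amf.Encoder)
	if _, err := enc.Encode(buf, obj, amf.AMF0); err != nil {
		panic(err)
	}
	// Expected wire form: 03 0003 'foo' 02 0003 'bar' 0000 09
	fmt.Printf("wire: % x\n", buf.Bytes())

	dec := amf.NewDecoder()
	v, err := dec.Decode(buf, amf.AMF0)
	fmt.Println(v, err) // map[foo:bar] <nil>
}
```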
polymorphic router +func (d *Decoder) DecodeAmf0(r io.Reader) (interface{}, error) { + marker, err := ReadMarker(r) + if err != nil { + return nil, err + } + + switch marker { + case AMF0_NUMBER_MARKER: + return d.DecodeAmf0Number(r, false) + case AMF0_BOOLEAN_MARKER: + return d.DecodeAmf0Boolean(r, false) + case AMF0_STRING_MARKER: + return d.DecodeAmf0String(r, false) + case AMF0_OBJECT_MARKER: + return d.DecodeAmf0Object(r, false) + case AMF0_MOVIECLIP_MARKER: + return nil, Error("decode amf0: unsupported type movieclip") + case AMF0_NULL_MARKER: + return d.DecodeAmf0Null(r, false) + case AMF0_UNDEFINED_MARKER: + return d.DecodeAmf0Undefined(r, false) + case AMF0_REFERENCE_MARKER: + return nil, Error("decode amf0: unsupported type reference") + case AMF0_ECMA_ARRAY_MARKER: + return d.DecodeAmf0EcmaArray(r, false) + case AMF0_STRICT_ARRAY_MARKER: + return d.DecodeAmf0StrictArray(r, false) + case AMF0_DATE_MARKER: + return d.DecodeAmf0Date(r, false) + case AMF0_LONG_STRING_MARKER: + return d.DecodeAmf0LongString(r, false) + case AMF0_UNSUPPORTED_MARKER: + return d.DecodeAmf0Unsupported(r, false) + case AMF0_RECORDSET_MARKER: + return nil, Error("decode amf0: unsupported type recordset") + case AMF0_XML_DOCUMENT_MARKER: + return d.DecodeAmf0XmlDocument(r, false) + case AMF0_TYPED_OBJECT_MARKER: + return d.DecodeAmf0TypedObject(r, false) + case AMF0_ACMPLUS_OBJECT_MARKER: + return d.DecodeAmf3(r) + } + + return nil, Error("decode amf0: unsupported type %d", marker) +} + +// marker: 1 byte 0x00 +// format: 8 byte big endian float64 +func (d *Decoder) DecodeAmf0Number(r io.Reader, decodeMarker bool) (result float64, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_NUMBER_MARKER); err != nil { + return + } + + err = binary.Read(r, binary.BigEndian, &result) + if err != nil { + return float64(0), Error("amf0 decode: unable to read number: %s", err) + } + + return +} + +// marker: 1 byte 0x01 +// format: 1 byte, 0x00 = false, 0x01 = true +func (d *Decoder) DecodeAmf0Boolean(r io.Reader, decodeMarker bool) (result bool, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_BOOLEAN_MARKER); err != nil { + return + } + + var b byte + if b, err = ReadByte(r); err != nil { + return + } + + if b == AMF0_BOOLEAN_FALSE { + return false, nil + } else if b == AMF0_BOOLEAN_TRUE { + return true, nil + } + + return false, Error("decode amf0: unexpected value %v for boolean", b) +} + +// marker: 1 byte 0x02 +// format: +// - 2 byte big endian uint16 header to determine size +// - n (size) byte utf8 string +func (d *Decoder) DecodeAmf0String(r io.Reader, decodeMarker bool) (result string, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_STRING_MARKER); err != nil { + return + } + + var length uint16 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return "", Error("decode amf0: unable to decode string length: %s", err) + } + + var bytes = make([]byte, length) + if bytes, err = ReadBytes(r, int(length)); err != nil { + return "", Error("decode amf0: unable to decode string value: %s", err) + } + + return string(bytes), nil +} + +// marker: 1 byte 0x03 +// format: +// - loop encoded string followed by encoded value +// - terminated with empty string followed by 1 byte 0x09 +func (d *Decoder) DecodeAmf0Object(r io.Reader, decodeMarker bool) (Object, error) { + if err := AssertMarker(r, decodeMarker, AMF0_OBJECT_MARKER); err != nil { + return nil, err + } + + result := make(Object) + d.refCache = append(d.refCache, result) + + for { + key, err := d.DecodeAmf0String(r, 
false) + if err != nil { + return nil, err + } + + if key == "" { + if err = AssertMarker(r, true, AMF0_OBJECT_END_MARKER); err != nil { + return nil, Error("decode amf0: expected object end marker: %s", err) + } + + break + } + + value, err := d.DecodeAmf0(r) + if err != nil { + return nil, Error("decode amf0: unable to decode object value: %s", err) + } + + result[key] = value + } + + return result, nil + +} + +// marker: 1 byte 0x05 +// no additional data +func (d *Decoder) DecodeAmf0Null(r io.Reader, decodeMarker bool) (result interface{}, err error) { + err = AssertMarker(r, decodeMarker, AMF0_NULL_MARKER) + return +} + +// marker: 1 byte 0x06 +// no additional data +func (d *Decoder) DecodeAmf0Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) { + err = AssertMarker(r, decodeMarker, AMF0_UNDEFINED_MARKER) + return +} + +// marker: 1 byte 0x07 +// format: 2 byte big endian uint16 +/* +func (d *Decoder) DecodeAmf0Reference(r io.Reader, decodeMarker bool) (interface{}, error) { + if err := AssertMarker(r, decodeMarker, AMF0_REFERENCE_MARKER); err != nil { + return nil, err + } + + var err error + var ref uint16 + + err = binary.Read(r, binary.BigEndian, &ref) + if err != nil { + return nil, Error("decode amf0: unable to decode reference id: %s", err) + } + + if int(ref) > len(d.refCache) { + return nil, Error("decode amf0: bad reference %d (current length %d)", ref, len(d.refCache)) + } + + result := d.refCache[ref] + + return result, nil +} +*/ + +// marker: 1 byte 0x08 +// format: +// - 4 byte big endian uint32 with length of associative array +// - normal object format: +// - loop encoded string followed by encoded value +// - terminated with empty string followed by 1 byte 0x09 +func (d *Decoder) DecodeAmf0EcmaArray(r io.Reader, decodeMarker bool) (Object, error) { + if err := AssertMarker(r, decodeMarker, AMF0_ECMA_ARRAY_MARKER); err != nil { + return nil, err + } + + var length uint32 + err := binary.Read(r, binary.BigEndian, &length) + + result, err := d.DecodeAmf0Object(r, false) + if err != nil { + return nil, Error("decode amf0: unable to decode ecma array object: %s", err) + } + + return result, nil +} + +// marker: 1 byte 0x0a +// format: +// - 4 byte big endian uint32 to determine length of associative array +// - n (length) encoded values +func (d *Decoder) DecodeAmf0StrictArray(r io.Reader, decodeMarker bool) (result Array, err error) { + if err := AssertMarker(r, decodeMarker, AMF0_STRICT_ARRAY_MARKER); err != nil { + return nil, err + } + + var length uint32 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return nil, Error("decode amf0: unable to decode strict array length: %s", err) + } + + d.refCache = append(d.refCache, result) + + for i := uint32(0); i < length; i++ { + tmp, err := d.DecodeAmf0(r) + if err != nil { + return nil, Error("decode amf0: unable to decode strict array object: %s", err) + } + result = append(result, tmp) + } + + return result, nil +} + +// marker: 1 byte 0x0b +// format: +// - normal number format: +// - 8 byte big endian float64 +// - 2 byte unused +func (d *Decoder) DecodeAmf0Date(r io.Reader, decodeMarker bool) (result float64, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_DATE_MARKER); err != nil { + return + } + + if result, err = d.DecodeAmf0Number(r, false); err != nil { + return float64(0), Error("decode amf0: unable to decode float in date: %s", err) + } + + if _, err = ReadBytes(r, 2); err != nil { + return float64(0), Error("decode amf0: unable to read 2 trail bytes in 
date: %s", err) + } + + return +} + +// marker: 1 byte 0x0c +// format: +// - 4 byte big endian uint32 header to determine size +// - n (size) byte utf8 string +func (d *Decoder) DecodeAmf0LongString(r io.Reader, decodeMarker bool) (result string, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_LONG_STRING_MARKER); err != nil { + return + } + + var length uint32 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return "", Error("decode amf0: unable to decode long string length: %s", err) + } + + var bytes = make([]byte, length) + if bytes, err = ReadBytes(r, int(length)); err != nil { + return "", Error("decode amf0: unable to decode long string value: %s", err) + } + + return string(bytes), nil +} + +// marker: 1 byte 0x0d +// no additional data +func (d *Decoder) DecodeAmf0Unsupported(r io.Reader, decodeMarker bool) (result interface{}, err error) { + err = AssertMarker(r, decodeMarker, AMF0_UNSUPPORTED_MARKER) + return +} + +// marker: 1 byte 0x0f +// format: +// - normal long string format +// - 4 byte big endian uint32 header to determine size +// - n (size) byte utf8 string +func (d *Decoder) DecodeAmf0XmlDocument(r io.Reader, decodeMarker bool) (result string, err error) { + if err = AssertMarker(r, decodeMarker, AMF0_XML_DOCUMENT_MARKER); err != nil { + return + } + + return d.DecodeAmf0LongString(r, false) +} + +// marker: 1 byte 0x10 +// format: +// - normal string format: +// - 2 byte big endian uint16 header to determine size +// - n (size) byte utf8 string +// - normal object format: +// - loop encoded string followed by encoded value +// - terminated with empty string followed by 1 byte 0x09 +func (d *Decoder) DecodeAmf0TypedObject(r io.Reader, decodeMarker bool) (TypedObject, error) { + result := *new(TypedObject) + + err := AssertMarker(r, decodeMarker, AMF0_TYPED_OBJECT_MARKER) + if err != nil { + return result, err + } + + d.refCache = append(d.refCache, result) + + result.Type, err = d.DecodeAmf0String(r, false) + if err != nil { + return result, Error("decode amf0: typed object unable to determine type: %s", err) + } + + result.Object, err = d.DecodeAmf0Object(r, false) + if err != nil { + return result, Error("decode amf0: typed object unable to determine object: %s", err) + } + + return result, nil +} diff --git a/protocol/amf/decoder_amf0_test.go b/protocol/amf/decoder_amf0_test.go new file mode 100755 index 0000000..2ae6223 --- /dev/null +++ b/protocol/amf/decoder_amf0_test.go @@ -0,0 +1,588 @@ +package amf + +import ( + "bytes" + "testing" +) + +func TestDecodeAmf0Number(t *testing.T) { + buf := bytes.NewReader([]byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}) + expect := float64(1.2) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test number interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Number(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test number interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0Number(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0BooleanTrue(t *testing.T) { + buf := bytes.NewReader([]byte{0x01, 0x01}) + expect := true + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + 
t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test boolean interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Boolean(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test boolean interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0Boolean(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0BooleanFalse(t *testing.T) { + buf := bytes.NewReader([]byte{0x01, 0x00}) + expect := false + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test boolean interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Boolean(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test boolean interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0Boolean(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0String(t *testing.T) { + buf := bytes.NewReader([]byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f}) + expect := "foo" + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test string interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0String(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test string interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0String(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0Object(t *testing.T) { + buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + obj, ok := got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } + + // Test object interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Object(buf, true) + if err != nil { + t.Errorf("%s", err) + } + obj, ok = got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } + + // Test object interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0Object(buf, false) + if err != nil { + t.Errorf("%s", err) + } + obj, ok = got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } +} + +func TestDecodeAmf0Null(t *testing.T) { + buf := bytes.NewReader([]byte{0x05}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } + + // Test null interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Null(buf, true) + if 
err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } +} + +func TestDecodeAmf0Undefined(t *testing.T) { + buf := bytes.NewReader([]byte{0x06}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } + + // Test undefined interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Undefined(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } +} + +/* +func TestDecodeReference(t *testing.T) { + buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x00, 0x00, 0x00, 0x00, 0x09}) + + dec := &Decoder{} + + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + obj, ok := got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + + _, ok2 := obj["foo"].(Object) + if ok2 != true { + t.Errorf("expected foo value to cast to object") + } +} +*/ + +func TestDecodeAmf0EcmaArray(t *testing.T) { + buf := bytes.NewReader([]byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + obj, ok := got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } + + // Test ecma array interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0EcmaArray(buf, true) + if err != nil { + t.Errorf("%s", err) + } + obj, ok = got.(Object) + if ok != true { + t.Errorf("expected result to cast to object") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } + + // Test ecma array interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0EcmaArray(buf, false) + if err != nil { + t.Errorf("%s", err) + } + obj, ok = got.(Object) + if ok != true { + t.Errorf("expected result to cast to ecma array") + } + if obj["foo"] != "bar" { + t.Errorf("expected {'foo'='bar'}, got %v", obj) + } +} + +func TestDecodeAmf0StrictArray(t *testing.T) { + buf := bytes.NewReader([]byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + arr, ok := got.(Array) + if ok != true { + t.Errorf("expected result to cast to strict array") + } + if arr[0] != float64(5) { + t.Errorf("expected array[0] to be 5, got %v", arr[0]) + } + if arr[1] != "foo" { + t.Errorf("expected array[1] to be 'foo', got %v", arr[1]) + } + if arr[2] != nil { + t.Errorf("expected array[2] to be nil, got %v", arr[2]) + } + + // Test strict array interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0StrictArray(buf, true) + if err != nil { + t.Errorf("%s", err) + } + arr, ok = got.(Array) + if ok != true { + t.Errorf("expected result to cast to strict array") + } + if arr[0] != float64(5) { + t.Errorf("expected array[0] to be 5, got %v", arr[0]) + } + if arr[1] != "foo" { + t.Errorf("expected array[1] to be 'foo', got %v", arr[1]) + } + if arr[2] != nil { + t.Errorf("expected array[2] to be nil, got %v", arr[2]) + } + + // Test strict array interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0StrictArray(buf, false) 
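
The strict-array bytes used in this test follow directly from the AMF0 layout implemented by the decoder above: a 0x0a marker, a 4-byte big-endian element count, then each element in plain AMF0 encoding. A standalone sketch (not part of the patch, standard library only) that rebuilds the exact test payload by hand:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	buf := new(bytes.Buffer)

	buf.WriteByte(0x0a)                            // strict array marker
	binary.Write(buf, binary.BigEndian, uint32(3)) // 4-byte big-endian element count

	buf.WriteByte(0x00)                             // number marker
	binary.Write(buf, binary.BigEndian, float64(5)) // 8-byte big-endian float64

	buf.WriteByte(0x02)                            // string marker
	binary.Write(buf, binary.BigEndian, uint16(3)) // 2-byte big-endian length
	buf.WriteString("foo")

	buf.WriteByte(0x05) // null marker

	fmt.Printf("% x\n", buf.Bytes())
	// 0a 00 00 00 03 00 40 14 00 00 00 00 00 00 02 00 03 66 6f 6f 05
}
```
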
+ if err != nil { + t.Errorf("%s", err) + } + arr, ok = got.(Array) + if ok != true { + t.Errorf("expected result to cast to strict array") + } + if arr[0] != float64(5) { + t.Errorf("expected array[0] to be 5, got %v", arr[0]) + } + if arr[1] != "foo" { + t.Errorf("expected array[1] to be 'foo', got %v", arr[1]) + } + if arr[2] != nil { + t.Errorf("expected array[2] to be nil, got %v", arr[2]) + } +} + +func TestDecodeAmf0Date(t *testing.T) { + buf := bytes.NewReader([]byte{0x0b, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + expect := float64(5) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test date interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Date(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test date interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0Date(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0LongString(t *testing.T) { + buf := bytes.NewReader([]byte{0x0c, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f}) + expect := "foo" + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test long string interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0LongString(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test long string interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0LongString(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0Unsupported(t *testing.T) { + buf := bytes.NewReader([]byte{0x0d}) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } + + // Test unsupported interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0Unsupported(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } +} + +func TestDecodeAmf0XmlDocument(t *testing.T) { + buf := bytes.NewReader([]byte{0x0f, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f}) + expect := "foo" + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test long string interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0XmlDocument(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + // Test long string interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0XmlDocument(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf0TypedObject(t *testing.T) { + + buf := bytes.NewReader([]byte{ + 0x10, 0x00, 0x0F, 'o', 'r', 'g', + '.', 'a', 'm', 'f', '.', 'A', + 'S', 'C', 'l', 'a', 's', 's', + 0x00, 0x03, 'b', 'a', 'z', 0x05, + 0x00, 0x03, 'f', 'o', 'o', 0x02, + 
0x00, 0x03, 'b', 'a', 'r', 0x00, + 0x00, 0x09, + }) + + dec := &Decoder{} + + // Test main interface + got, err := dec.DecodeAmf0(buf) + if err != nil { + t.Errorf("%s", err) + } + tobj, ok := got.(TypedObject) + if ok != true { + t.Errorf("expected result to cast to typed object, got %+v", tobj) + } + if tobj.Type != "org.amf.ASClass" { + t.Errorf("expected typed object type to be 'class', got %v", tobj.Type) + } + if tobj.Object["foo"] != "bar" { + t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"]) + } + if tobj.Object["baz"] != nil { + t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"]) + } + + // Test typed object interface with marker + buf.Seek(0, 0) + got, err = dec.DecodeAmf0TypedObject(buf, true) + if err != nil { + t.Errorf("%s", err) + } + tobj, ok = got.(TypedObject) + if ok != true { + t.Errorf("expected result to cast to typed object, got %+v", tobj) + } + if tobj.Type != "org.amf.ASClass" { + t.Errorf("expected typed object type to be 'class', got %v", tobj.Type) + } + if tobj.Object["foo"] != "bar" { + t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"]) + } + if tobj.Object["baz"] != nil { + t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"]) + } + + // Test typed object interface without marker + buf.Seek(1, 0) + got, err = dec.DecodeAmf0TypedObject(buf, false) + if err != nil { + t.Errorf("%s", err) + } + tobj, ok = got.(TypedObject) + if ok != true { + t.Errorf("expected result to cast to typed object, got %+v", tobj) + } + if tobj.Type != "org.amf.ASClass" { + t.Errorf("expected typed object type to be 'class', got %v", tobj.Type) + } + if tobj.Object["foo"] != "bar" { + t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"]) + } + if tobj.Object["baz"] != nil { + t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"]) + } +} diff --git a/protocol/amf/decoder_amf3.go b/protocol/amf/decoder_amf3.go new file mode 100755 index 0000000..7c8260a --- /dev/null +++ b/protocol/amf/decoder_amf3.go @@ -0,0 +1,496 @@ +package amf + +import ( + "encoding/binary" + "io" + "time" +) + +// amf3 polymorphic router +func (d *Decoder) DecodeAmf3(r io.Reader) (interface{}, error) { + marker, err := ReadMarker(r) + if err != nil { + return nil, err + } + + switch marker { + case AMF3_UNDEFINED_MARKER: + return d.DecodeAmf3Undefined(r, false) + case AMF3_NULL_MARKER: + return d.DecodeAmf3Null(r, false) + case AMF3_FALSE_MARKER: + return d.DecodeAmf3False(r, false) + case AMF3_TRUE_MARKER: + return d.DecodeAmf3True(r, false) + case AMF3_INTEGER_MARKER: + return d.DecodeAmf3Integer(r, false) + case AMF3_DOUBLE_MARKER: + return d.DecodeAmf3Double(r, false) + case AMF3_STRING_MARKER: + return d.DecodeAmf3String(r, false) + case AMF3_XMLDOC_MARKER: + return d.DecodeAmf3Xml(r, false) + case AMF3_DATE_MARKER: + return d.DecodeAmf3Date(r, false) + case AMF3_ARRAY_MARKER: + return d.DecodeAmf3Array(r, false) + case AMF3_OBJECT_MARKER: + return d.DecodeAmf3Object(r, false) + case AMF3_XMLSTRING_MARKER: + return d.DecodeAmf3Xml(r, false) + case AMF3_BYTEARRAY_MARKER: + return d.DecodeAmf3ByteArray(r, false) + } + + return nil, Error("decode amf3: unsupported type %d", marker) +} + +// marker: 1 byte 0x00 +// no additional data +func (d *Decoder) DecodeAmf3Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) { + err = AssertMarker(r, decodeMarker, AMF3_UNDEFINED_MARKER) + return +} + +// marker: 1 byte 0x01 +// 
no additional data +func (d *Decoder) DecodeAmf3Null(r io.Reader, decodeMarker bool) (result interface{}, err error) { + err = AssertMarker(r, decodeMarker, AMF3_NULL_MARKER) + return +} + +// marker: 1 byte 0x02 +// no additional data +func (d *Decoder) DecodeAmf3False(r io.Reader, decodeMarker bool) (result bool, err error) { + err = AssertMarker(r, decodeMarker, AMF3_FALSE_MARKER) + result = false + return +} + +// marker: 1 byte 0x03 +// no additional data +func (d *Decoder) DecodeAmf3True(r io.Reader, decodeMarker bool) (result bool, err error) { + err = AssertMarker(r, decodeMarker, AMF3_TRUE_MARKER) + result = true + return +} + +// marker: 1 byte 0x04 +func (d *Decoder) DecodeAmf3Integer(r io.Reader, decodeMarker bool) (result int32, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_INTEGER_MARKER); err != nil { + return + } + + var u29 uint32 + u29, err = d.decodeU29(r) + if err != nil { + return + } + + result = int32(u29) + if result > 0xfffffff { + result = int32(u29 - 0x20000000) + } + + return +} + +// marker: 1 byte 0x05 +func (d *Decoder) DecodeAmf3Double(r io.Reader, decodeMarker bool) (result float64, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_DOUBLE_MARKER); err != nil { + return + } + + err = binary.Read(r, binary.BigEndian, &result) + if err != nil { + return float64(0), Error("amf3 decode: unable to read double: %s", err) + } + + return +} + +// marker: 1 byte 0x06 +// format: +// - u29 reference int. if reference, no more data. if not reference, +// length value of bytes to read to complete string. +func (d *Decoder) DecodeAmf3String(r io.Reader, decodeMarker bool) (result string, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_STRING_MARKER); err != nil { + return + } + + var isRef bool + var refVal uint32 + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return "", Error("amf3 decode: unable to decode string reference and length: %s", err) + } + + if isRef { + result = d.stringRefs[refVal] + return + } + + buf := make([]byte, refVal) + _, err = r.Read(buf) + if err != nil { + return "", Error("amf3 decode: unable to read string: %s", err) + } + + result = string(buf) + if result != "" { + d.stringRefs = append(d.stringRefs, result) + } + + return +} + +// marker: 1 byte 0x08 +// format: +// - u29 reference int, if reference, no more data +// - timestamp double +func (d *Decoder) DecodeAmf3Date(r io.Reader, decodeMarker bool) (result time.Time, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_DATE_MARKER); err != nil { + return + } + + var isRef bool + var refVal uint32 + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return result, Error("amf3 decode: unable to decode date reference and length: %s", err) + } + + if isRef { + res, ok := d.objectRefs[refVal].(time.Time) + if ok != true { + return result, Error("amf3 decode: unable to extract time from date object references") + } + + return res, err + } + + var u64 float64 + err = binary.Read(r, binary.BigEndian, &u64) + if err != nil { + return result, Error("amf3 decode: unable to read double: %s", err) + } + + result = time.Unix(int64(u64/1000), 0).UTC() + + d.objectRefs = append(d.objectRefs, result) + + return +} + +// marker: 1 byte 0x09 +// format: +// - u29 reference int. if reference, no more data. 
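
The string, date and array decoders above all start from decodeU29, defined further down in this file: a 1-to-4 byte varint whose low bit separates inline values from reference-table indexes. A minimal standalone sketch of that convention, fed the same 0x06 0x07 'f' 'o' 'o' bytes the tests use:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// decodeU29 reads an AMF3 U29: up to three bytes carry 7 bits each with the
// high bit as a continuation flag, a fourth byte contributes a full 8 bits.
func decodeU29(r io.ByteReader) (uint32, error) {
	var result uint32
	for i := 0; i < 3; i++ {
		b, err := r.ReadByte()
		if err != nil {
			return 0, err
		}
		result = (result << 7) | uint32(b&0x7F)
		if b&0x80 == 0 {
			return result, nil
		}
	}
	b, err := r.ReadByte()
	if err != nil {
		return 0, err
	}
	return (result << 8) | uint32(b), nil
}

func main() {
	// 0x06 is the AMF3 string marker; 0x07 is the U29 header for "foo":
	// low bit 1 means an inline string, 0x07>>1 = 3 bytes to read.
	payload := bytes.NewReader([]byte{0x06, 0x07, 'f', 'o', 'o'})
	marker, _ := payload.ReadByte()
	header, _ := decodeU29(payload)
	isRef := header&0x01 == 0
	length := header >> 1

	str := make([]byte, length)
	io.ReadFull(payload, str)
	fmt.Printf("marker=%#x isRef=%v length=%d value=%q\n", marker, isRef, length, str)
	// marker=0x6 isRef=false length=3 value="foo"
}
```
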
+// - string representing associative array if present +// - n values (length of u29) +func (d *Decoder) DecodeAmf3Array(r io.Reader, decodeMarker bool) (result Array, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_ARRAY_MARKER); err != nil { + return + } + + var isRef bool + var refVal uint32 + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return result, Error("amf3 decode: unable to decode array reference and length: %s", err) + } + + if isRef { + objRefId := refVal >> 1 + + res, ok := d.objectRefs[objRefId].(Array) + if ok != true { + return result, Error("amf3 decode: unable to extract array from object references") + } + + return res, err + } + + var key string + key, err = d.DecodeAmf3String(r, false) + if err != nil { + return result, Error("amf3 decode: unable to read key for array: %s", err) + } + + if key != "" { + return result, Error("amf3 decode: array key is not empty, can't handle associative array") + } + + for i := uint32(0); i < refVal; i++ { + tmp, err := d.DecodeAmf3(r) + if err != nil { + return result, Error("amf3 decode: array element could not be decoded: %s", err) + } + result = append(result, tmp) + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +// marker: 1 byte 0x09 +// format: oh dear god +func (d *Decoder) DecodeAmf3Object(r io.Reader, decodeMarker bool) (result interface{}, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_OBJECT_MARKER); err != nil { + return nil, err + } + + // decode the initial u29 + isRef, refVal, err := d.decodeReferenceInt(r) + if err != nil { + return nil, Error("amf3 decode: unable to decode object reference and length: %s", err) + } + + // if this is a object reference only, grab it and return it + if isRef { + objRefId := refVal >> 1 + + return d.objectRefs[objRefId], nil + } + + // each type has traits that are cached, if the peer sent a reference + // then we'll need to look it up and use it. + var trait Trait + + traitIsRef := (refVal & 0x01) == 0 + + if traitIsRef { + traitRef := refVal >> 1 + trait = d.traitRefs[traitRef] + + } else { + // build a new trait from what's left of the given u29 + trait = *NewTrait() + trait.Externalizable = (refVal & 0x02) != 0 + trait.Dynamic = (refVal & 0x04) != 0 + + var cls string + cls, err = d.DecodeAmf3String(r, false) + if err != nil { + return result, Error("amf3 decode: unable to read trait type for object: %s", err) + } + trait.Type = cls + + // traits have property keys, encoded as amf3 strings + propLength := refVal >> 3 + for i := uint32(0); i < propLength; i++ { + tmp, err := d.DecodeAmf3String(r, false) + if err != nil { + return result, Error("amf3 decode: unable to read trait property for object: %s", err) + } + trait.Properties = append(trait.Properties, tmp) + } + + d.traitRefs = append(d.traitRefs, trait) + } + + d.objectRefs = append(d.objectRefs, result) + + // objects can be externalizable, meaning that the system has no concrete understanding of + // their properties or how they are encoded. in that case, we need to find and delegate behavior + // to the right object. 
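
Before the externalizable branch below runs, the trait header has already been unpacked from a single U29. A small sketch that applies the same bit layout to 0x23, the header carried by the org.amf.ASClass test payload (bit 0: inline object, bit 1: inline traits, bit 2: externalizable, bit 3: dynamic, bits 4 and up: sealed property count):

```go
package main

import "fmt"

func main() {
	u29 := uint32(0x23) // trait header from the object test payload

	inlineObject := u29&0x01 != 0
	refVal := u29 >> 1
	inlineTraits := refVal&0x01 != 0
	externalizable := refVal&0x02 != 0
	dynamic := refVal&0x04 != 0
	sealedProps := refVal >> 3

	fmt.Println("inline object:  ", inlineObject)   // true
	fmt.Println("inline traits:  ", inlineTraits)   // true
	fmt.Println("externalizable: ", externalizable) // false
	fmt.Println("dynamic:        ", dynamic)        // false
	fmt.Println("sealed props:   ", sealedProps)    // 2 ("baz", "foo")
}
```
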
+ if trait.Externalizable { + switch trait.Type { + case "DSA": // AsyncMessageExt + result, err = d.decodeAsyncMessageExt(r) + if err != nil { + return result, Error("amf3 decode: unable to decode dsa: %s", err) + } + case "DSK": // AcknowledgeMessageExt + result, err = d.decodeAcknowledgeMessageExt(r) + if err != nil { + return result, Error("amf3 decode: unable to decode dsk: %s", err) + } + case "flex.messaging.io.ArrayCollection": + result, err = d.decodeArrayCollection(r) + if err != nil { + return result, Error("amf3 decode: unable to decode ac: %s", err) + } + + // store an extra reference to array collection container + d.objectRefs = append(d.objectRefs, result) + + default: + fn, ok := d.externalHandlers[trait.Type] + if ok { + result, err = fn(d, r) + if err != nil { + return result, Error("amf3 decode: unable to call external decoder for type %s: %s", trait.Type, err) + } + } else { + return result, Error("amf3 decode: unable to decode external type %s, no handler", trait.Type) + } + } + + return result, err + } + + var key string + var val interface{} + var obj Object + + obj = make(Object) + + // non-externalizable objects have property keys in traits, iterate through them + // and add the read values to the object + for _, key = range trait.Properties { + val, err = d.DecodeAmf3(r) + if err != nil { + return result, Error("amf3 decode: unable to decode object property: %s", err) + } + + obj[key] = val + } + + // if an object is dynamic, it can have extra key/value data at the end. in this case, + // read keys until we get an empty one. + if trait.Dynamic { + for { + key, err = d.DecodeAmf3String(r, false) + if err != nil { + return result, Error("amf3 decode: unable to decode dynamic key: %s", err) + } + if key == "" { + break + } + val, err = d.DecodeAmf3(r) + if err != nil { + return result, Error("amf3 decode: unable to decode dynamic value: %s", err) + } + + obj[key] = val + } + } + + result = obj + + return +} + +// marker: 1 byte 0x07 or 0x0b +// format: +// - u29 reference int. if reference, no more data. if not reference, +// length value of bytes to read to complete string. +func (d *Decoder) DecodeAmf3Xml(r io.Reader, decodeMarker bool) (result string, err error) { + if decodeMarker { + var marker byte + marker, err = ReadMarker(r) + if err != nil { + return "", err + } + + if (marker != AMF3_XMLDOC_MARKER) && (marker != AMF3_XMLSTRING_MARKER) { + return "", Error("decode assert marker failed: expected %v or %v, got %v", AMF3_XMLDOC_MARKER, AMF3_XMLSTRING_MARKER, marker) + } + } + + var isRef bool + var refVal uint32 + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return "", Error("amf3 decode: unable to decode xml reference and length: %s", err) + } + + if isRef { + var ok bool + buf := d.objectRefs[refVal] + result, ok = buf.(string) + if ok != true { + return "", Error("amf3 decode: cannot coerce object reference into xml string") + } + + return + } + + buf := make([]byte, refVal) + _, err = r.Read(buf) + if err != nil { + return "", Error("amf3 decode: unable to read xml string: %s", err) + } + + result = string(buf) + + if result != "" { + d.objectRefs = append(d.objectRefs, result) + } + + return +} + +// marker: 1 byte 0x0c +// format: +// - u29 reference int. if reference, no more data. if not reference, +// length value of bytes to read. 
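
The byte-array layout described in the comment above is the simplest use of the inline/reference convention: marker 0x0c, a U29 whose low bit marks an inline value and whose upper bits give the length, then the raw bytes. A sketch of the encode side, assuming the length stays below 64 so the U29 fits in one byte:

```go
package main

import (
	"bytes"
	"fmt"
)

// encodeByteArray writes an AMF3 bytearray for short payloads (length <= 63).
func encodeByteArray(val []byte) []byte {
	buf := new(bytes.Buffer)
	buf.WriteByte(0x0c)                     // bytearray marker
	buf.WriteByte(byte(len(val)<<1 | 0x01)) // U29: length<<1 with the inline bit set
	buf.Write(val)
	return buf.Bytes()
}

func main() {
	fmt.Printf("% x\n", encodeByteArray([]byte{0xde, 0xad, 0xbe, 0xef}))
	// 0c 09 de ad be ef
}
```
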
+func (d *Decoder) DecodeAmf3ByteArray(r io.Reader, decodeMarker bool) (result []byte, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_BYTEARRAY_MARKER); err != nil { + return + } + + var isRef bool + var refVal uint32 + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return result, Error("amf3 decode: unable to decode byte array reference and length: %s", err) + } + + if isRef { + var ok bool + result, ok = d.objectRefs[refVal].([]byte) + if ok != true { + return result, Error("amf3 decode: unable to convert object ref to bytes") + } + + return + } + + result = make([]byte, refVal) + _, err = r.Read(result) + if err != nil { + return result, Error("amf3 decode: unable to read bytearray: %s", err) + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) decodeU29(r io.Reader) (result uint32, err error) { + var b byte + + for i := 0; i < 3; i++ { + b, err = ReadByte(r) + if err != nil { + return + } + result = (result << 7) + uint32(b&0x7F) + if (b & 0x80) == 0 { + return + } + } + + b, err = ReadByte(r) + if err != nil { + return + } + + result = ((result << 8) + uint32(b)) + + return +} + +func (d *Decoder) decodeReferenceInt(r io.Reader) (isRef bool, refVal uint32, err error) { + u29, err := d.decodeU29(r) + if err != nil { + return false, 0, Error("amf3 decode: unable to decode reference int: %s", err) + } + + isRef = u29&0x01 == 0 + refVal = u29 >> 1 + + return +} diff --git a/protocol/amf/decoder_amf3_external.go b/protocol/amf/decoder_amf3_external.go new file mode 100755 index 0000000..35f4b7d --- /dev/null +++ b/protocol/amf/decoder_amf3_external.go @@ -0,0 +1,127 @@ +package amf + +import ( + "fmt" + "io" + "math" +) + +// Abstract external boilerplate +func (d *Decoder) decodeAbstractMessage(r io.Reader) (result Object, err error) { + result = make(Object) + + if err = d.decodeExternal(r, &result, + []string{"body", "clientId", "destination", "headers", "messageId", "timeStamp", "timeToLive"}, + []string{"clientIdBytes", "messageIdBytes"}); err != nil { + return result, Error("unable to decode abstract external: %s", err) + } + + return +} + +// DSA +func (d *Decoder) decodeAsyncMessageExt(r io.Reader) (result Object, err error) { + return d.decodeAsyncMessage(r) +} +func (d *Decoder) decodeAsyncMessage(r io.Reader) (result Object, err error) { + result, err = d.decodeAbstractMessage(r) + if err != nil { + return result, Error("unable to decode abstract for async: %s", err) + } + + if err = d.decodeExternal(r, &result, []string{"correlationId", "correlationIdBytes"}); err != nil { + return result, Error("unable to decode async external: %s", err) + } + + return +} + +// DSK +func (d *Decoder) decodeAcknowledgeMessageExt(r io.Reader) (result Object, err error) { + return d.decodeAcknowledgeMessage(r) +} +func (d *Decoder) decodeAcknowledgeMessage(r io.Reader) (result Object, err error) { + result, err = d.decodeAsyncMessage(r) + if err != nil { + return result, Error("unable to decode async for ack: %s", err) + } + + if err = d.decodeExternal(r, &result); err != nil { + return result, Error("unable to decode ack external: %s", err) + } + + return +} + +// flex.messaging.io.ArrayCollection +func (d *Decoder) decodeArrayCollection(r io.Reader) (interface{}, error) { + result, err := d.DecodeAmf3(r) + if err != nil { + return result, Error("cannot decode child of array collection: %s", err) + } + + return result, nil +} + +func (d *Decoder) decodeExternal(r io.Reader, obj *Object, fieldSets ...[]string) (err error) { + var 
flagSet []uint8 + var reservedPosition uint8 + var fieldNames []string + + flagSet, err = readFlags(r) + if err != nil { + return Error("unable to read flags: %s", err) + } + + for i, flags := range flagSet { + if i < len(fieldSets) { + fieldNames = fieldSets[i] + } else { + fieldNames = []string{} + } + + reservedPosition = uint8(len(fieldNames)) + + for p, field := range fieldNames { + flagBit := uint8(math.Exp2(float64(p))) + if (flags & flagBit) != 0 { + tmp, err := d.DecodeAmf3(r) + if err != nil { + return Error("unable to decode external field %s %d %d (%#v): %s", field, i, p, flagSet, err) + } + (*obj)[field] = tmp + } + } + + if (flags >> reservedPosition) != 0 { + for j := reservedPosition; j < 6; j++ { + if ((flags >> j) & 0x01) != 0 { + field := fmt.Sprintf("extra_%d_%d", i, j) + tmp, err := d.DecodeAmf3(r) + if err != nil { + return Error("unable to decode post-external field %d %d (%#v): %s", i, j, flagSet, err) + } + (*obj)[field] = tmp + } + } + } + } + + return +} + +func readFlags(r io.Reader) (result []uint8, err error) { + for { + flag, err := ReadByte(r) + if err != nil { + return result, Error("unable to read flags: %s", err) + } + + result = append(result, flag) + if (flag & 0x80) == 0 { + break + } + } + + return +} diff --git a/protocol/amf/decoder_amf3_test.go b/protocol/amf/decoder_amf3_test.go new file mode 100755 index 0000000..7833054 --- /dev/null +++ b/protocol/amf/decoder_amf3_test.go @@ -0,0 +1,220 @@ +package amf + +import ( + "bytes" + "testing" +) + +type u29TestCase struct { + value uint32 + expect []byte +} + +var u29TestCases = []u29TestCase{ + {1, []byte{0x01}}, + {2, []byte{0x02}}, + {127, []byte{0x7F}}, + {128, []byte{0x81, 0x00}}, + {255, []byte{0x81, 0x7F}}, + {256, []byte{0x82, 0x00}}, + {0x3FFF, []byte{0xFF, 0x7F}}, + {0x4000, []byte{0x81, 0x80, 0x00}}, + {0x7FFF, []byte{0x81, 0xFF, 0x7F}}, + {0x8000, []byte{0x82, 0x80, 0x00}}, + {0x1FFFFF, []byte{0xFF, 0xFF, 0x7F}}, + {0x200000, []byte{0x80, 0xC0, 0x80, 0x00}}, + {0x3FFFFF, []byte{0x80, 0xFF, 0xFF, 0xFF}}, + {0x400000, []byte{0x81, 0x80, 0x80, 0x00}}, + {0x0FFFFFFF, []byte{0xBF, 0xFF, 0xFF, 0xFF}}, +} + +func TestDecodeAmf3Undefined(t *testing.T) { + buf := bytes.NewReader([]byte{0x00}) + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } +} + +func TestDecodeAmf3Null(t *testing.T) { + buf := bytes.NewReader([]byte{0x01}) + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if got != nil { + t.Errorf("expect nil got %v", got) + } +} + +func TestDecodeAmf3False(t *testing.T) { + buf := bytes.NewReader([]byte{0x02}) + expect := false + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf3True(t *testing.T) { + buf := bytes.NewReader([]byte{0x03}) + expect := true + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeU29(t *testing.T) { + dec := new(Decoder) + + for _, tc := range u29TestCases { + buf := bytes.NewBuffer(tc.expect) + n, err := dec.decodeU29(buf) + if err != nil { + t.Errorf("DecodeAmf3Integer error: %s", err) + } + if n != tc.value { + t.Errorf("DecodeAmf3Integer expect n %x got %x", tc.value, n) + } + } +} + +func TestDecodeAmf3Integer(t 
*testing.T) { + dec := new(Decoder) + + buf := bytes.NewReader([]byte{0x04, 0xFF, 0xFF, 0x7F}) + expect := int32(2097151) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + buf.Seek(0, 0) + got, err = dec.DecodeAmf3Integer(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } + + buf.Seek(1, 0) + got, err = dec.DecodeAmf3Integer(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf3Double(t *testing.T) { + buf := bytes.NewReader([]byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}) + expect := float64(1.2) + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf3String(t *testing.T) { + buf := bytes.NewReader([]byte{0x06, 0x07, 'f', 'o', 'o'}) + expect := "foo" + + dec := new(Decoder) + + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("%s", err) + } + if expect != got { + t.Errorf("expect %v got %v", expect, got) + } +} + +func TestDecodeAmf3Array(t *testing.T) { + buf := bytes.NewReader([]byte{0x09, 0x13, 0x01, + 0x06, 0x03, '1', + 0x06, 0x03, '2', + 0x06, 0x03, '3', + 0x06, 0x03, '4', + 0x06, 0x03, '5', + 0x06, 0x03, '6', + 0x06, 0x03, '7', + 0x06, 0x03, '8', + 0x06, 0x03, '9', + }) + + dec := new(Decoder) + expect := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"} + got, err := dec.DecodeAmf3Array(buf, true) + if err != nil { + t.Errorf("err: %s", err) + } + + for i, v := range expect { + if got[i] != v { + t.Error("expected array element %d to be %v, got %v", i, v, got[i]) + } + } +} + +func TestDecodeAmf3Object(t *testing.T) { + buf := bytes.NewReader([]byte{ + 0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a', + 'm', 'f', '.', 'A', 'S', 'C', 'l', 'a', + 's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f', + 'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r', + }) + + dec := new(Decoder) + got, err := dec.DecodeAmf3(buf) + if err != nil { + t.Errorf("err: %s", err) + } + + to, ok := got.(Object) + if ok != true { + t.Error("unable to cast object as typed object") + } + + if to["foo"] != "bar" { + t.Error("expected foo to be bar, got: %+v", to["foo"]) + } + + if to["baz"] != nil { + t.Error("expected baz to be nil, got: %+v", to["baz"]) + } +} diff --git a/protocol/amf/encoder_amf0.go b/protocol/amf/encoder_amf0.go new file mode 100755 index 0000000..20b4c41 --- /dev/null +++ b/protocol/amf/encoder_amf0.go @@ -0,0 +1,308 @@ +package amf + +import ( + "encoding/binary" + "io" + "reflect" +) + +// amf0 polymorphic router +func (e *Encoder) EncodeAmf0(w io.Writer, val interface{}) (int, error) { + if val == nil { + return e.EncodeAmf0Null(w, true) + } + + v := reflect.ValueOf(val) + if !v.IsValid() { + return e.EncodeAmf0Null(w, true) + } + + switch v.Kind() { + case reflect.String: + str := v.String() + if len(str) <= AMF0_STRING_MAX { + return e.EncodeAmf0String(w, str, true) + } else { + return e.EncodeAmf0LongString(w, str, true) + } + case reflect.Bool: + return e.EncodeAmf0Boolean(w, v.Bool(), true) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return e.EncodeAmf0Number(w, float64(v.Int()), true) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return e.EncodeAmf0Number(w, float64(v.Uint()), true) + case 
reflect.Float32, reflect.Float64: + return e.EncodeAmf0Number(w, float64(v.Float()), true) + case reflect.Array, reflect.Slice: + length := v.Len() + arr := make(Array, length) + for i := 0; i < length; i++ { + arr[i] = v.Index(int(i)).Interface() + } + return e.EncodeAmf0StrictArray(w, arr, true) + case reflect.Map: + obj, ok := val.(Object) + if ok != true { + return 0, Error("encode amf0: unable to create object from map") + } + return e.EncodeAmf0Object(w, obj, true) + } + + if _, ok := val.(TypedObject); ok { + return 0, Error("encode amf0: unsupported type typed object") + } + + return 0, Error("encode amf0: unsupported type %s", v.Type()) +} + +// marker: 1 byte 0x00 +// format: 8 byte big endian float64 +func (e *Encoder) EncodeAmf0Number(w io.Writer, val float64, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_NUMBER_MARKER); err != nil { + return + } + n += 1 + } + + err = binary.Write(w, binary.BigEndian, &val) + if err != nil { + return + } + n += 8 + + return +} + +// marker: 1 byte 0x01 +// format: 1 byte, 0x00 = false, 0x01 = true +func (e *Encoder) EncodeAmf0Boolean(w io.Writer, val bool, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_BOOLEAN_MARKER); err != nil { + return + } + n += 1 + } + + var m int + buf := make([]byte, 1) + if val { + buf[0] = AMF0_BOOLEAN_TRUE + } else { + buf[0] = AMF0_BOOLEAN_FALSE + } + + m, err = w.Write(buf) + if err != nil { + return + } + n += m + + return +} + +// marker: 1 byte 0x02 +// format: +// - 2 byte big endian uint16 header to determine size +// - n (size) byte utf8 string +func (e *Encoder) EncodeAmf0String(w io.Writer, val string, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_STRING_MARKER); err != nil { + return + } + n += 1 + } + + var m int + length := uint16(len(val)) + err = binary.Write(w, binary.BigEndian, length) + if err != nil { + return n, Error("encode amf0: unable to encode string length: %s", err) + } + n += 2 + + m, err = w.Write([]byte(val)) + if err != nil { + return n, Error("encode amf0: unable to encode string value: %s", err) + } + n += m + + return +} + +// marker: 1 byte 0x03 +// format: +// - loop encoded string followed by encoded value +// - terminated with empty string followed by 1 byte 0x09 +func (e *Encoder) EncodeAmf0Object(w io.Writer, val Object, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_OBJECT_MARKER); err != nil { + return + } + n += 1 + } + + var m int + for k, v := range val { + m, err = e.EncodeAmf0String(w, k, false) + if err != nil { + return n, Error("encode amf0: unable to encode object key: %s", err) + } + n += m + + m, err = e.EncodeAmf0(w, v) + if err != nil { + return n, Error("encode amf0: unable to encode object value: %s", err) + } + n += m + } + + m, err = e.EncodeAmf0String(w, "", false) + if err != nil { + return n, Error("encode amf0: unable to encode object empty string: %s", err) + } + n += m + + err = WriteMarker(w, AMF0_OBJECT_END_MARKER) + if err != nil { + return n, Error("encode amf0: unable to object end marker: %s", err) + } + n += 1 + + return +} + +// marker: 1 byte 0x05 +// no additional data +func (e *Encoder) EncodeAmf0Null(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_NULL_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x06 +// no additional data +func (e *Encoder) EncodeAmf0Undefined(w 
io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_UNDEFINED_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x08 +// format: +// - 4 byte big endian uint32 with length of associative array +// - normal object format: +// - loop encoded string followed by encoded value +// - terminated with empty string followed by 1 byte 0x09 +func (e *Encoder) EncodeAmf0EcmaArray(w io.Writer, val Object, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_ECMA_ARRAY_MARKER); err != nil { + return + } + n += 1 + } + + var m int + length := uint32(len(val)) + err = binary.Write(w, binary.BigEndian, length) + if err != nil { + return n, Error("encode amf0: unable to encode ecma array length: %s", err) + } + n += 4 + + m, err = e.EncodeAmf0Object(w, val, false) + if err != nil { + return n, Error("encode amf0: unable to encode ecma array object: %s", err) + } + n += m + + return +} + +// marker: 1 byte 0x0a +// format: +// - 4 byte big endian uint32 to determine length of associative array +// - n (length) encoded values +func (e *Encoder) EncodeAmf0StrictArray(w io.Writer, val Array, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_STRICT_ARRAY_MARKER); err != nil { + return + } + n += 1 + } + + var m int + length := uint32(len(val)) + err = binary.Write(w, binary.BigEndian, length) + if err != nil { + return n, Error("encode amf0: unable to encode strict array length: %s", err) + } + n += 4 + + for _, v := range val { + m, err = e.EncodeAmf0(w, v) + if err != nil { + return n, Error("encode amf0: unable to encode strict array element: %s", err) + } + n += m + } + + return +} + +// marker: 1 byte 0x0c +// format: +// - 4 byte big endian uint32 header to determine size +// - n (size) byte utf8 string +func (e *Encoder) EncodeAmf0LongString(w io.Writer, val string, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_LONG_STRING_MARKER); err != nil { + return + } + n += 1 + } + + var m int + length := uint32(len(val)) + err = binary.Write(w, binary.BigEndian, length) + if err != nil { + return n, Error("encode amf0: unable to encode long string length: %s", err) + } + n += 4 + + m, err = w.Write([]byte(val)) + if err != nil { + return n, Error("encode amf0: unable to encode long string value: %s", err) + } + n += m + + return +} + +// marker: 1 byte 0x0d +// no additional data +func (e *Encoder) EncodeAmf0Unsupported(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF0_UNSUPPORTED_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x11 +func (e *Encoder) EncodeAmf0Amf3Marker(w io.Writer) error { + return WriteMarker(w, AMF0_ACMPLUS_OBJECT_MARKER) +} diff --git a/protocol/amf/encoder_amf0_test.go b/protocol/amf/encoder_amf0_test.go new file mode 100755 index 0000000..48ac277 --- /dev/null +++ b/protocol/amf/encoder_amf0_test.go @@ -0,0 +1,212 @@ +package amf + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestEncodeAmf0Number(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33} + + enc := new(Encoder) + + n, err := enc.EncodeAmf0(buf, float64(1.2)) + if err != nil { + t.Errorf("%s", err) + } + if n != 9 { + t.Errorf("expected to write 9 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: 
%+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0BooleanTrue(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x01, 0x01} + + enc := new(Encoder) + + n, err := enc.EncodeAmf0(buf, true) + if err != nil { + t.Errorf("%s", err) + } + if n != 2 { + t.Errorf("expected to write 2 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0BooleanFalse(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x01, 0x00} + + enc := new(Encoder) + + n, err := enc.EncodeAmf0(buf, false) + if err != nil { + t.Errorf("%s", err) + } + if n != 2 { + t.Errorf("expected to write 2 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0String(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f} + + enc := new(Encoder) + + n, err := enc.EncodeAmf0(buf, "foo") + if err != nil { + t.Errorf("%s", err) + } + if n != 6 { + t.Errorf("expected to write 6 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0Object(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09} + + enc := new(Encoder) + + obj := make(Object) + obj["foo"] = "bar" + + n, err := enc.EncodeAmf0(buf, obj) + if err != nil { + t.Errorf("%s", err) + } + if n != 15 { + t.Errorf("expected to write 15 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0EcmaArray(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09} + + enc := new(Encoder) + + obj := make(Object) + obj["foo"] = "bar" + + _, err := enc.EncodeAmf0EcmaArray(buf, obj, true) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0StrictArray(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05} + + enc := new(Encoder) + + arr := make(Array, 3) + arr[0] = float64(5) + arr[1] = "foo" + arr[2] = nil + + _, err := enc.EncodeAmf0StrictArray(buf, arr, true) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0Null(t *testing.T) { + buf := new(bytes.Buffer) + expect := []byte{0x05} + + enc := new(Encoder) + + n, err := enc.EncodeAmf0(buf, nil) + if err != nil { + t.Errorf("%s", err) + } + if n != 1 { + t.Errorf("expected to write 1 byte, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf0LongString(t *testing.T) { + buf := new(bytes.Buffer) + + testBytes := []byte("12345678") + + tbuf := new(bytes.Buffer) + for i := 0; i < 65536; i++ { + tbuf.Write(testBytes) + } + + enc := new(Encoder) + + _, err := enc.EncodeAmf0(buf, string(tbuf.Bytes())) + if err != nil { + 
t.Errorf("%s", err) + } + + mbuf := make([]byte, 1) + _, err = buf.Read(mbuf) + if err != nil { + t.Errorf("error reading header") + } + + if mbuf[0] != 0x0c { + t.Errorf("marker mismatch") + } + + var length uint32 + err = binary.Read(buf, binary.BigEndian, &length) + if err != nil { + t.Errorf("error reading buffer") + } + if length != (65536 * 8) { + t.Errorf("expected length to be %d, got %d", (65536 * 8), length) + } + + tmpBuf := make([]byte, 8) + counter := 0 + for buf.Len() > 0 { + n, err := buf.Read(tmpBuf) + if err != nil { + t.Fatalf("test long string result check, read data(%d) error: %s, n: %d", counter, err, n) + } + if n != 8 { + t.Fatalf("test long string result check, read data(%d) n: %d", counter, n) + } + if !bytes.Equal(testBytes, tmpBuf) { + t.Fatalf("test long string result check, read data % x", tmpBuf) + } + + counter++ + } +} diff --git a/protocol/amf/encoder_amf3.go b/protocol/amf/encoder_amf3.go new file mode 100755 index 0000000..9565c46 --- /dev/null +++ b/protocol/amf/encoder_amf3.go @@ -0,0 +1,431 @@ +package amf + +import ( + "encoding/binary" + "io" + "reflect" + "sort" + "time" +) + +// amf3 polymorphic router + +func (e *Encoder) EncodeAmf3(w io.Writer, val interface{}) (int, error) { + if val == nil { + return e.EncodeAmf3Null(w, true) + } + + v := reflect.ValueOf(val) + if !v.IsValid() { + return e.EncodeAmf3Null(w, true) + } + + switch v.Kind() { + case reflect.String: + return e.EncodeAmf3String(w, v.String(), true) + case reflect.Bool: + if v.Bool() { + return e.EncodeAmf3True(w, true) + } else { + return e.EncodeAmf3False(w, true) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + n := v.Int() + if n >= 0 && n <= AMF3_INTEGER_MAX { + return e.EncodeAmf3Integer(w, uint32(n), true) + } else { + return e.EncodeAmf3Double(w, float64(n), true) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + n := v.Uint() + if n <= AMF3_INTEGER_MAX { + return e.EncodeAmf3Integer(w, uint32(n), true) + } else { + return e.EncodeAmf3Double(w, float64(n), true) + } + case reflect.Int64: + return e.EncodeAmf3Double(w, float64(v.Int()), true) + case reflect.Uint64: + return e.EncodeAmf3Double(w, float64(v.Uint()), true) + case reflect.Float32, reflect.Float64: + return e.EncodeAmf3Double(w, float64(v.Float()), true) + case reflect.Array, reflect.Slice: + length := v.Len() + arr := make(Array, length) + for i := 0; i < length; i++ { + arr[i] = v.Index(int(i)).Interface() + } + return e.EncodeAmf3Array(w, arr, true) + case reflect.Map: + obj, ok := val.(Object) + if ok != true { + return 0, Error("encode amf3: unable to create object from map") + } + + to := *new(TypedObject) + to.Object = obj + + return e.EncodeAmf3Object(w, to, true) + } + + if tm, ok := val.(time.Time); ok { + return e.EncodeAmf3Date(w, tm, true) + } + + if to, ok := val.(TypedObject); ok { + return e.EncodeAmf3Object(w, to, true) + } + + return 0, Error("encode amf3: unsupported type %s", v.Type()) +} + +// marker: 1 byte 0x00 +// no additional data +func (e *Encoder) EncodeAmf3Undefined(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_UNDEFINED_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x01 +// no additional data +func (e *Encoder) EncodeAmf3Null(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_NULL_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x02 +// no additional data +func 
(e *Encoder) EncodeAmf3False(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_FALSE_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x03 +// no additional data +func (e *Encoder) EncodeAmf3True(w io.Writer, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_TRUE_MARKER); err != nil { + return + } + n += 1 + } + + return +} + +// marker: 1 byte 0x04 +func (e *Encoder) EncodeAmf3Integer(w io.Writer, val uint32, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_INTEGER_MARKER); err != nil { + return + } + n += 1 + } + + var m int + m, err = e.encodeAmf3Uint29(w, val) + if err != nil { + return + } + n += m + + return +} + +// marker: 1 byte 0x05 +func (e *Encoder) EncodeAmf3Double(w io.Writer, val float64, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_DOUBLE_MARKER); err != nil { + return + } + n += 1 + } + + err = binary.Write(w, binary.BigEndian, &val) + if err != nil { + return + } + n += 8 + + return +} + +// marker: 1 byte 0x06 +// format: +// - u29 reference int. if reference, no more data. if not reference, +// length value of bytes to read to complete string. +func (e *Encoder) EncodeAmf3String(w io.Writer, val string, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_STRING_MARKER); err != nil { + return + } + n += 1 + } + + var m int + + m, err = e.encodeAmf3Utf8(w, val) + if err != nil { + return + } + n += m + + return +} + +// marker: 1 byte 0x08 +// format: +// - u29 reference int, if reference, no more data +// - timestamp double +func (e *Encoder) EncodeAmf3Date(w io.Writer, val time.Time, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_DATE_MARKER); err != nil { + return + } + n += 1 + } + + if err = WriteMarker(w, 0x01); err != nil { + return n, Error("amf3 encode: cannot encode u29 for array: %s", err) + } + n += 1 + + u64 := float64(val.Unix()) * 1000.0 + err = binary.Write(w, binary.BigEndian, &u64) + if err != nil { + return n, Error("amf3 encode: unable to write date double: %s", err) + } + n += 8 + + return +} + +// marker: 1 byte 0x09 +// format: +// - u29 reference int. if reference, no more data. 
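
encodeAmf3Uint29, used throughout this file, is the inverse of decodeU29. A standalone sketch written against the same boundaries the u29TestCases table exercises (0x7F, 0x3FFF, 0x1FFFFF, 0x1FFFFFFF); values wider than 29 bits cannot be represented:

```go
package main

import "fmt"

// encodeU29 writes an AMF3 U29 varint: 7 bits per byte with a continuation
// flag for the first three bytes, a full 8 bits in an optional fourth byte.
func encodeU29(val uint32) ([]byte, error) {
	switch {
	case val <= 0x7F:
		return []byte{byte(val)}, nil
	case val <= 0x3FFF:
		return []byte{byte(val>>7 | 0x80), byte(val & 0x7F)}, nil
	case val <= 0x1FFFFF:
		return []byte{byte(val>>14 | 0x80), byte(val>>7&0x7F | 0x80), byte(val & 0x7F)}, nil
	case val <= 0x1FFFFFFF:
		return []byte{byte(val>>22 | 0x80), byte(val>>15&0x7F | 0x80), byte(val>>8&0x7F | 0x80), byte(val)}, nil
	}
	return nil, fmt.Errorf("value %d does not fit in 29 bits", val)
}

func main() {
	for _, v := range []uint32{1, 127, 128, 0x3FFF, 0x4000, 0x200000, 0x0FFFFFFF} {
		b, _ := encodeU29(v)
		fmt.Printf("%#x -> % x\n", v, b)
	}
	// 0x1 -> 01, 0x7f -> 7f, 0x80 -> 81 00, 0x3fff -> ff 7f,
	// 0x4000 -> 81 80 00, 0x200000 -> 80 c0 80 00, 0xfffffff -> bf ff ff ff
}
```
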
+// - string representing associative array if present +// - n values (length of u29) +func (e *Encoder) EncodeAmf3Array(w io.Writer, val Array, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_ARRAY_MARKER); err != nil { + return + } + n += 1 + } + + var m int + length := uint32(len(val)) + u29 := uint32(length<<1) | 0x01 + + m, err = e.encodeAmf3Uint29(w, u29) + if err != nil { + return n, Error("amf3 encode: cannot encode u29 for array: %s", err) + } + n += m + + m, err = e.encodeAmf3Utf8(w, "") + if err != nil { + return n, Error("amf3 encode: cannot encode empty string for array: %s", err) + } + n += m + + for _, v := range val { + m, err := e.EncodeAmf3(w, v) + if err != nil { + return n, Error("amf3 encode: cannot encode array element: %s", err) + } + n += m + } + + return +} + +// marker: 1 byte 0x0a +// format: ugh +func (e *Encoder) EncodeAmf3Object(w io.Writer, val TypedObject, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_OBJECT_MARKER); err != nil { + return + } + n += 1 + } + + m := 0 + + trait := *NewTrait() + trait.Type = val.Type + trait.Dynamic = false + trait.Externalizable = false + + for k, _ := range val.Object { + trait.Properties = append(trait.Properties, k) + } + + sort.Strings(trait.Properties) + + var u29 uint32 = 0x03 + if trait.Dynamic { + u29 |= 0x02 << 2 + } + + if trait.Externalizable { + u29 |= 0x01 << 2 + } + + u29 |= uint32(len(trait.Properties)) << 4 + + m, err = e.encodeAmf3Uint29(w, u29) + if err != nil { + return n, Error("amf3 encode: cannot encode trait header for object: %s", err) + } + n += m + + m, err = e.encodeAmf3Utf8(w, trait.Type) + if err != nil { + return n, Error("amf3 encode: cannot encode trait type for object: %s", err) + } + n += m + + for _, prop := range trait.Properties { + m, err = e.encodeAmf3Utf8(w, prop) + if err != nil { + return n, Error("amf3 encode: cannot encode trait property for object: %s", err) + } + n += m + } + + if trait.Externalizable { + return n, Error("amf3 encode: cannot encode externalizable object") + } + + for _, prop := range trait.Properties { + m, err = e.EncodeAmf3(w, val.Object[prop]) + if err != nil { + return n, Error("amf3 encode: cannot encode sealed object value: %s", err) + } + n += m + } + + if trait.Dynamic { + for k, v := range val.Object { + var foundProp bool = false + for _, prop := range trait.Properties { + if prop == k { + foundProp = true + break + } + } + + if foundProp != true { + m, err = e.encodeAmf3Utf8(w, k) + if err != nil { + return n, Error("amf3 encode: cannot encode dynamic object property key: %s", err) + } + n += m + + m, err = e.EncodeAmf3(w, v) + if err != nil { + return n, Error("amf3 encode: cannot encode dynamic object value: %s", err) + } + n += m + } + + m, err = e.encodeAmf3Utf8(w, "") + if err != nil { + return n, Error("amf3 encode: cannot encode dynamic object ending marker string: %s", err) + } + n += m + } + } + + return +} + +// marker: 1 byte 0x0c +// format: +// - u29 reference int. if reference, no more data. if not reference, +// length value of bytes to read . 
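
EncodeAmf3Array above writes a dense array as a marker, a U29 count, an empty UTF-8 key closing the unused associative portion, then the elements. A sketch that reproduces the nine-string test payload by hand, valid only while counts and string lengths fit in a single U29 byte:

```go
package main

import (
	"bytes"
	"fmt"
)

func main() {
	elems := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}

	buf := new(bytes.Buffer)
	buf.WriteByte(0x09)                       // array marker
	buf.WriteByte(byte(len(elems)<<1 | 0x01)) // U29: dense count with the inline bit set
	buf.WriteByte(0x01)                       // empty key: no associative part
	for _, s := range elems {
		buf.WriteByte(0x06)                   // string marker
		buf.WriteByte(byte(len(s)<<1 | 0x01)) // U29 string header
		buf.WriteString(s)
	}

	fmt.Printf("% x\n", buf.Bytes())
	// 09 13 01 06 03 31 06 03 32 ... 06 03 39
}
```
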
+func (e *Encoder) EncodeAmf3ByteArray(w io.Writer, val []byte, encodeMarker bool) (n int, err error) { + if encodeMarker { + if err = WriteMarker(w, AMF3_BYTEARRAY_MARKER); err != nil { + return + } + n += 1 + } + + var m int + + length := uint32(len(val)) + u29 := (length << 1) | 1 + + m, err = e.encodeAmf3Uint29(w, u29) + if err != nil { + return n, Error("amf3 encode: cannot encode u29 for bytearray: %s", err) + } + n += m + + m, err = w.Write(val) + if err != nil { + return n, Error("encode amf3: unable to encode bytearray value: %s", err) + } + n += m + + return +} + +func (e *Encoder) encodeAmf3Utf8(w io.Writer, val string) (n int, err error) { + length := uint32(len(val)) + u29 := uint32(length<<1) | 0x01 + + var m int + m, err = e.encodeAmf3Uint29(w, u29) + if err != nil { + return n, Error("amf3 encode: cannot encode u29 for string: %s", err) + } + n += m + + m, err = w.Write([]byte(val)) + if err != nil { + return n, Error("encode amf3: unable to encode string value: %s", err) + } + n += m + + return +} + +func (e *Encoder) encodeAmf3Uint29(w io.Writer, val uint32) (n int, err error) { + if val <= 0x0000007F { + err = WriteByte(w, byte(val)) + if err == nil { + n += 1 + } + } else if val <= 0x00003FFF { + n, err = w.Write([]byte{byte(val>>7 | 0x80), byte(val & 0x7F)}) + } else if val <= 0x001FFFFF { + n, err = w.Write([]byte{byte(val>>14 | 0x80), byte(val>>7&0x7F | 0x80), byte(val & 0x7F)}) + } else if val <= 0x1FFFFFFF { + n, err = w.Write([]byte{byte(val>>22 | 0x80), byte(val>>15&0x7F | 0x80), byte(val>>8&0x7F | 0x80), byte(val)}) + } else { + return n, Error("amf3 encode: cannot encode u29 with value %d (out of range)", val) + } + + return +} diff --git a/protocol/amf/encoder_amf3_test.go b/protocol/amf/encoder_amf3_test.go new file mode 100755 index 0000000..a866abb --- /dev/null +++ b/protocol/amf/encoder_amf3_test.go @@ -0,0 +1,199 @@ +package amf + +import ( + "bytes" + "testing" +) + +func TestEncodeAmf3EmptyString(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x01} + + _, err := enc.EncodeAmf3String(buf, "", false) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Undefined(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x00} + + _, err := enc.EncodeAmf3Undefined(buf, true) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Null(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x01} + + _, err := enc.EncodeAmf3(buf, nil) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3False(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x02} + + _, err := enc.EncodeAmf3(buf, false) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3True(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x03} + + _, err := enc.EncodeAmf3(buf, true) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, 
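encodeAmf3Uint29 is a writer only; a decoder counterpart makes the variable-length format easier to see. This is a sketch, not part of the commit: up to three leading bytes contribute 7 bits each with the high bit as a continuation flag, and a fourth byte, if reached, contributes a full 8 bits.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// decodeU29 reverses encodeAmf3Uint29: accumulate 7 bits per byte while the
// continuation bit is set, and take all 8 bits of the fourth byte.
func decodeU29(r io.ByteReader) (uint32, error) {
	var n uint32
	for i := 0; i < 4; i++ {
		b, err := r.ReadByte()
		if err != nil {
			return 0, err
		}
		if i == 3 {
			return n<<8 | uint32(b), nil
		}
		n = n<<7 | uint32(b&0x7f)
		if b&0x80 == 0 {
			return n, nil
		}
	}
	return n, nil
}

func main() {
	// The 4-byte form from the integer test below: 0x80 0xFF 0xFF 0xFF -> 4194303.
	v, err := decodeU29(bytes.NewReader([]byte{0x80, 0xFF, 0xFF, 0xFF}))
	fmt.Println(v, err) // 4194303 <nil>
}
```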
got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Integer(t *testing.T) { + enc := new(Encoder) + + for _, tc := range u29TestCases { + buf := new(bytes.Buffer) + _, err := enc.EncodeAmf3Integer(buf, tc.value, false) + if err != nil { + t.Errorf("EncodeAmf3Integer error: %s", err) + } + got := buf.Bytes() + if !bytes.Equal(tc.expect, got) { + t.Errorf("EncodeAmf3Integer expect n %x got %x", tc.value, got) + } + } + + buf := new(bytes.Buffer) + expect := []byte{0x04, 0x80, 0xFF, 0xFF, 0xFF} + + n, err := enc.EncodeAmf3(buf, uint32(4194303)) + if err != nil { + t.Errorf("%s", err) + } + if n != 5 { + t.Errorf("expected to write 5 bytes, actual %d", n) + } + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Double(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33} + + _, err := enc.EncodeAmf3(buf, float64(1.2)) + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3String(t *testing.T) { + enc := new(Encoder) + + buf := new(bytes.Buffer) + expect := []byte{0x06, 0x07, 'f', 'o', 'o'} + + _, err := enc.EncodeAmf3(buf, "foo") + if err != nil { + t.Errorf("%s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Array(t *testing.T) { + enc := new(Encoder) + buf := new(bytes.Buffer) + expect := []byte{0x09, 0x13, 0x01, + 0x06, 0x03, '1', + 0x06, 0x03, '2', + 0x06, 0x03, '3', + 0x06, 0x03, '4', + 0x06, 0x03, '5', + 0x06, 0x03, '6', + 0x06, 0x03, '7', + 0x06, 0x03, '8', + 0x06, 0x03, '9', + } + + arr := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"} + _, err := enc.EncodeAmf3(buf, arr) + if err != nil { + t.Errorf("err: %s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes()) + } +} + +func TestEncodeAmf3Object(t *testing.T) { + enc := new(Encoder) + buf := new(bytes.Buffer) + expect := []byte{ + 0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a', + 'm', 'f', '.', 'A', 'S', 'C', 'l', 'a', + 's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f', + 'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r', + } + + to := *NewTypedObject() + to.Type = "org.amf.ASClass" + to.Object["foo"] = "bar" + to.Object["baz"] = nil + + _, err := enc.EncodeAmf3(buf, to) + if err != nil { + t.Errorf("err: %s", err) + } + + if bytes.Compare(buf.Bytes(), expect) != 0 { + t.Errorf("expected buffer:\n%#v\ngot:\n%#v", expect, buf.Bytes()) + } +} diff --git a/protocol/amf/metadata.go b/protocol/amf/metadata.go new file mode 100755 index 0000000..722984d --- /dev/null +++ b/protocol/amf/metadata.go @@ -0,0 +1,70 @@ +package amf + +import ( + "bytes" + "fmt" + "log" +) + +const ( + ADD = 0x0 + DEL = 0x3 +) + +const ( + SetDataFrame string = "@setDataFrame" + OnMetaData string = "onMetaData" +) + +var setFrameFrame []byte + +func init() { + b := bytes.NewBuffer(nil) + encoder := &Encoder{} + if _, err := encoder.Encode(b, SetDataFrame, AMF0); err != nil { + log.Fatal(err) + } + setFrameFrame = b.Bytes() +} + +func MetaDataReform(p []byte, flag uint8) ([]byte, error) { + r := bytes.NewReader(p) + decoder := &Decoder{} + switch flag { + case ADD: + v, err := decoder.Decode(r, AMF0) + if err != nil { + return nil, err + } + switch v.(type) { + case string: + vv := 
v.(string) + if vv != SetDataFrame { + tmplen := len(setFrameFrame) + b := make([]byte, tmplen+len(p)) + copy(b, setFrameFrame) + copy(b[tmplen:], p) + p = b + } + default: + return nil, fmt.Errorf("setFrameFrame error") + } + case DEL: + v, err := decoder.Decode(r, AMF0) + if err != nil { + return nil, err + } + switch v.(type) { + case string: + vv := v.(string) + if vv == SetDataFrame { + p = p[len(setFrameFrame):] + } + default: + return nil, fmt.Errorf("metadata error") + } + default: + return nil, fmt.Errorf("invalid flag:%d", flag) + } + return p, nil +} diff --git a/protocol/amf/util.go b/protocol/amf/util.go new file mode 100755 index 0000000..c94e2aa --- /dev/null +++ b/protocol/amf/util.go @@ -0,0 +1,92 @@ +package amf + +import ( + "encoding/json" + "errors" + "fmt" + "io" +) + +func DumpBytes(label string, buf []byte, size int) { + fmt.Printf("Dumping %s (%d bytes):\n", label, size) + for i := 0; i < size; i++ { + fmt.Printf("0x%02x ", buf[i]) + } + fmt.Printf("\n") +} + +func Dump(label string, val interface{}) error { + json, err := json.MarshalIndent(val, "", " ") + if err != nil { + return Error("Error dumping %s: %s", label, err) + } + + fmt.Printf("Dumping %s:\n%s\n", label, json) + return nil +} + +func Error(f string, v ...interface{}) error { + return errors.New(fmt.Sprintf(f, v...)) +} + +func WriteByte(w io.Writer, b byte) (err error) { + bytes := make([]byte, 1) + bytes[0] = b + + _, err = WriteBytes(w, bytes) + + return +} + +func WriteBytes(w io.Writer, bytes []byte) (int, error) { + return w.Write(bytes) +} + +func ReadByte(r io.Reader) (byte, error) { + bytes, err := ReadBytes(r, 1) + if err != nil { + return 0x00, err + } + + return bytes[0], nil +} + +func ReadBytes(r io.Reader, n int) ([]byte, error) { + bytes := make([]byte, n) + + m, err := r.Read(bytes) + if err != nil { + return bytes, err + } + + if m != n { + return bytes, fmt.Errorf("decode read bytes failed: expected %d got %d", m, n) + } + + return bytes, nil +} + +func WriteMarker(w io.Writer, m byte) error { + return WriteByte(w, m) +} + +func ReadMarker(r io.Reader) (byte, error) { + return ReadByte(r) +} + +func AssertMarker(r io.Reader, checkMarker bool, m byte) error { + if checkMarker == false { + return nil + } + + marker, err := ReadMarker(r) + if err != nil { + return err + } + + if marker != m { + return Error("decode assert marker failed: expected %v got %v", m, marker) + } + + return nil +} diff --git a/protocol/dash/dash.go b/protocol/dash/dash.go new file mode 100755 index 0000000..3f19715 --- /dev/null +++ b/protocol/dash/dash.go @@ -0,0 +1 @@ +package dash diff --git a/protocol/hls/align.go b/protocol/hls/align.go new file mode 100755 index 0000000..26506b2 --- /dev/null +++ b/protocol/hls/align.go @@ -0,0 +1,29 @@ +package hls + +const ( + syncms = 2 // ms +) + +type align struct { + frameNum uint64 + frameBase uint64 +} + +func (self *align) align(dts *uint64, inc uint32) { + aFrameDts := *dts + estPts := self.frameBase + self.frameNum*uint64(inc) + var dPts uint64 + if estPts >= aFrameDts { + dPts = estPts - aFrameDts + } else { + dPts = aFrameDts - estPts + } + + if dPts <= uint64(syncms)*h264_default_hz { + self.frameNum++ + *dts = estPts + return + } + self.frameNum = 1 + self.frameBase = aFrameDts +} diff --git a/protocol/hls/audio_cache.go b/protocol/hls/audio_cache.go new file mode 100755 index 0000000..88dbd2f --- /dev/null +++ b/protocol/hls/audio_cache.go @@ -0,0 +1,44 @@ +package hls + +import "bytes" + +const ( + cache_max_frames byte = 6 + audio_cache_len int = 10 * 
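One caveat about ReadBytes in util.go: a single `r.Read` call is allowed to return fewer than `n` bytes without an error (typical on sockets), which would trip the length check even though more data is on the way. A sketch of the same helper built on `io.ReadFull`, which keeps reading until the buffer is filled:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// readBytesFull fills the buffer or fails: io.ReadFull loops until n bytes
// arrive and reports io.ErrUnexpectedEOF if the stream ends early.
func readBytesFull(r io.Reader, n int) ([]byte, error) {
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	return buf, nil
}

func main() {
	b, err := readBytesFull(bytes.NewReader([]byte{0x0a, 0x0b, 0x0c}), 2)
	fmt.Println(b, err) // [10 11] <nil>
}
```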
1024 +) + +type audioCache struct { + soundFormat byte + num byte + offset int + pts uint64 + buf *bytes.Buffer +} + +func newAudioCache() *audioCache { + return &audioCache{ + buf: bytes.NewBuffer(make([]byte, audio_cache_len)), + } +} + +func (self *audioCache) Cache(src []byte, pts uint64) bool { + if self.num == 0 { + self.offset = 0 + self.pts = pts + self.buf.Reset() + } + self.buf.Write(src) + self.offset += len(src) + self.num++ + + return false +} + +func (self *audioCache) GetFrame() (int, uint64, []byte) { + self.num = 0 + return self.offset, self.pts, self.buf.Bytes() +} + +func (self *audioCache) CacheNum() byte { + return self.num +} diff --git a/protocol/hls/hls.go b/protocol/hls/hls.go new file mode 100755 index 0000000..70adafc --- /dev/null +++ b/protocol/hls/hls.go @@ -0,0 +1,413 @@ +package hls + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/http" + "path" + "strconv" + "strings" + "time" + + "github.com/gwuhaolin/livego/utils/cmap" + "github.com/golang/glog" + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/container/flv" + "github.com/gwuhaolin/livego/container/ts" + "github.com/gwuhaolin/livego/parser" + "log" +) + +const ( + duration = 3000 +) + +var ( + ErrNoPublisher = errors.New("No publisher") + ErrInvalidReq = errors.New("invalid req url path") + ErrNoSupportVideoCodec = errors.New("no support video codec") + ErrNoSupportAudioCodec = errors.New("no support audio codec") +) + +var crossdomainxml = []byte(` + + + +`) + +type Server struct { + l net.Listener + conns cmap.ConcurrentMap +} + +func NewServer() *Server { + ret := &Server{ + conns: cmap.New(), + } + go ret.checkStop() + return ret +} + +func (self *Server) Serve(l net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + self.handle(w, r) + }) + self.l = l + http.Serve(l, mux) + return nil +} + +func (self *Server) GetWriter(info av.Info) av.WriteCloser { + var s *Source + ok := self.conns.Has(info.Key) + if !ok { + log.Println("new hls source") + s = NewSource(info) + self.conns.Set(info.Key, s) + } else { + v, _ := self.conns.Get(info.Key) + s = v.(*Source) + } + return s +} + +func (self *Server) getConn(key string) *Source { + v, ok := self.conns.Get(key) + if !ok { + return nil + } + return v.(*Source) +} + +func (self *Server) checkStop() { + for { + <-time.After(5 * time.Second) + for item := range self.conns.IterBuffered() { + v := item.Val.(*Source) + if !v.Alive() { + log.Println("check stop and remove: ", v.Info()) + self.conns.Remove(item.Key) + } + } + } +} + +func (self *Server) handle(w http.ResponseWriter, r *http.Request) { + if path.Base(r.URL.Path) == "crossdomain.xml" { + w.Header().Set("Content-Type", "application/xml") + w.Write(crossdomainxml) + return + } + switch path.Ext(r.URL.Path) { + case ".m3u8": + key, _ := self.parseM3u8(r.URL.Path) + conn := self.getConn(key) + if conn == nil { + http.Error(w, ErrNoPublisher.Error(), http.StatusForbidden) + return + } + tsCache := conn.GetCacheInc() + body, err := tsCache.GenM3U8PlayList() + if err != nil { + log.Println("GenM3U8PlayList error: ", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Content-Type", "application/x-mpegURL") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.Write(body) + case ".ts": + key, _ := self.parseTs(r.URL.Path) + conn := self.getConn(key) + if conn == nil { + 
http.Error(w, ErrNoPublisher.Error(), http.StatusForbidden) + return + } + tsCache := conn.GetCacheInc() + item, err := tsCache.GetItem(r.URL.Path) + if err != nil { + log.Println("GetItem error: ", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Content-Type", "video/mp2ts") + w.Header().Set("Content-Length", strconv.Itoa(len(item.Data))) + w.Write(item.Data) + } +} + +func (self *Server) parseM3u8(pathstr string) (key string, err error) { + pathstr = strings.TrimLeft(pathstr, "/") + key = strings.TrimRight(pathstr, path.Ext(pathstr)) + return +} + +func (self *Server) parseTs(pathstr string) (key string, err error) { + pathstr = strings.TrimLeft(pathstr, "/") + paths := strings.SplitN(pathstr, "/", 3) + if len(paths) != 3 { + err = fmt.Errorf("invalid path=%s", pathstr) + return + } + key = paths[0] + "/" + paths[1] + + return +} + +const ( + videoHZ = 90000 + aacSampleLen = 1024 + maxQueueNum = 512 + + h264_default_hz uint64 = 90 +) + +type Source struct { + av.RWBaser + seq int + info av.Info + bwriter *bytes.Buffer + btswriter *bytes.Buffer + demuxer *flv.Demuxer + muxer *ts.Muxer + pts, dts uint64 + stat *status + align *align + cache *audioCache + tsCache *TSCacheItem + tsparser *parser.CodecParser + closed bool + packetQueue chan av.Packet +} + +func NewSource(info av.Info) *Source { + info.Inter = true + s := &Source{ + info: info, + align: &align{}, + stat: newStatus(), + RWBaser: av.NewRWBaser(time.Second * 10), + cache: newAudioCache(), + demuxer: flv.NewDemuxer(), + muxer: ts.NewMuxer(), + tsCache: NewTSCacheItem(info.Key), + tsparser: parser.NewCodecParser(), + bwriter: bytes.NewBuffer(make([]byte, 100*1024)), + packetQueue: make(chan av.Packet, maxQueueNum), + } + go func() { + err := s.SendPacket() + if err != nil { + log.Println("send packet error: ", err) + s.closed = true + } + }() + return s +} + +func (self *Source) GetCacheInc() *TSCacheItem { + return self.tsCache +} + +func (self *Source) DropPacket(pktQue chan av.Packet, info av.Info) { + glog.Errorf("[%v] packet queue max!!!", info) + for i := 0; i < maxQueueNum-84; i++ { + tmpPkt, ok := <-pktQue + // try to don't drop audio + if ok && tmpPkt.IsAudio { + if len(pktQue) > maxQueueNum-2 { + <-pktQue + } else { + pktQue <- tmpPkt + } + } + + if ok && tmpPkt.IsVideo { + videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader) + // dont't drop sps config and dont't drop key frame + if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) { + pktQue <- tmpPkt + } + if len(pktQue) > maxQueueNum-10 { + <-pktQue + } + } + + } + log.Println("packet queue len: ", len(pktQue)) +} + +func (self *Source) Write(p av.Packet) error { + self.SetPreTime() + if len(self.packetQueue) >= maxQueueNum-24 { + self.DropPacket(self.packetQueue, self.info) + } else { + self.packetQueue <- p + } + return nil +} + +func (self *Source) SendPacket() error { + defer func() { + glog.Infof("[%v] hls sender stop", self.info) + if r := recover(); r != nil { + log.Println("hls SendPacket panic: ", r) + } + }() + glog.Infof("[%v] hls sender start", self.info) + for { + if self.closed { + return errors.New("closed") + } + + p, ok := <-self.packetQueue + if ok { + if p.IsMetadata { + continue + } + + err := self.demuxer.Demux(&p) + if err == flv.ErrAvcEndSEQ { + log.Println(err) + continue + } else { + if err != nil { + log.Println(err) + return err + } + } + compositionTime, isSeq, err := self.parse(&p) + if err != nil { + log.Println(err) + } + if err != nil || isSeq { 
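The two path parsers (parseM3u8 and parseTs) turn request paths into cache keys: `/app/stream.m3u8` and `/app/stream/<segment>.ts` both map to `app/stream`. A standalone sketch of the same mapping; note it uses `strings.TrimSuffix` where parseM3u8 uses `strings.TrimRight`, because TrimRight treats its second argument as a character set and would also eat a trailing `m`, `u`, `3` or `8` from the stream name itself (for example `live/stream.m3u8` comes back as `live/strea`):

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// parsePlaylistKey maps "/live/stream.m3u8" to the cache key "live/stream".
func parsePlaylistKey(p string) string {
	p = strings.TrimLeft(p, "/")
	return strings.TrimSuffix(p, path.Ext(p))
}

// parseSegmentKey maps "/live/stream/1699999999.ts" to "live/stream".
func parseSegmentKey(p string) (string, error) {
	p = strings.TrimLeft(p, "/")
	parts := strings.SplitN(p, "/", 3)
	if len(parts) != 3 {
		return "", fmt.Errorf("invalid path=%s", p)
	}
	return parts[0] + "/" + parts[1], nil
}

func main() {
	fmt.Println(parsePlaylistKey("/live/stream.m3u8")) // live/stream
	key, _ := parseSegmentKey("/live/stream/1699999999.ts")
	fmt.Println(key) // live/stream
}
```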
+ continue + } + if self.btswriter != nil { + self.stat.update(p.IsVideo, p.TimeStamp) + self.calcPtsDts(p.IsVideo, p.TimeStamp, uint32(compositionTime)) + self.tsMux(&p) + } + } else { + return errors.New("closed") + } + } +} + +func (self *Source) Info() (ret av.Info) { + return self.info +} + +func (self *Source) cleanup() { + close(self.packetQueue) + self.bwriter = nil + self.btswriter = nil + self.cache = nil + self.tsCache = nil +} + +func (self *Source) Close(err error) { + log.Println("hls source closed: ", self.info) + if !self.closed { + self.cleanup() + } + self.closed = true +} + +func (self *Source) cut() { + newf := true + if self.btswriter == nil { + self.btswriter = bytes.NewBuffer(nil) + } else if self.btswriter != nil && self.stat.durationMs() >= duration { + self.flushAudio() + + self.seq++ + filename := fmt.Sprintf("/%s/%d.ts", self.info.Key, time.Now().Unix()) + item := NewTSItem(filename, int(self.stat.durationMs()), self.seq, self.btswriter.Bytes()) + self.tsCache.SetItem(filename, item) + + self.btswriter.Reset() + self.stat.resetAndNew() + } else { + newf = false + } + if newf { + self.btswriter.Write(self.muxer.PAT()) + self.btswriter.Write(self.muxer.PMT(av.SOUND_AAC, true)) + } +} + +func (self *Source) parse(p *av.Packet) (int32, bool, error) { + var compositionTime int32 + var ah av.AudioPacketHeader + var vh av.VideoPacketHeader + if p.IsVideo { + vh = p.Header.(av.VideoPacketHeader) + if vh.CodecID() != av.VIDEO_H264 { + return compositionTime, false, ErrNoSupportVideoCodec + } + compositionTime = vh.CompositionTime() + if vh.IsKeyFrame() && vh.IsSeq() { + return compositionTime, true, self.tsparser.Parse(p, self.bwriter) + } + } else { + ah = p.Header.(av.AudioPacketHeader) + if ah.SoundFormat() != av.SOUND_AAC { + return compositionTime, false, ErrNoSupportAudioCodec + } + if ah.AACPacketType() == av.AAC_SEQHDR { + return compositionTime, true, self.tsparser.Parse(p, self.bwriter) + } + } + self.bwriter.Reset() + if err := self.tsparser.Parse(p, self.bwriter); err != nil { + return compositionTime, false, err + } + p.Data = self.bwriter.Bytes() + + if p.IsVideo && vh.IsKeyFrame() { + self.cut() + } + return compositionTime, false, nil +} + +func (self *Source) calcPtsDts(isVideo bool, ts, compositionTs uint32) { + self.dts = uint64(ts) * h264_default_hz + if isVideo { + self.pts = self.dts + uint64(compositionTs)*h264_default_hz + } else { + sampleRate, _ := self.tsparser.SampleRate() + self.align.align(&self.dts, uint32(videoHZ*aacSampleLen/sampleRate)) + self.pts = self.dts + } +} +func (self *Source) flushAudio() error { + return self.muxAudio(1) +} + +func (self *Source) muxAudio(limit byte) error { + if self.cache.CacheNum() < limit { + return nil + } + var p av.Packet + _, pts, buf := self.cache.GetFrame() + p.Data = buf + p.TimeStamp = uint32(pts / h264_default_hz) + return self.muxer.Mux(&p, self.btswriter) +} + +func (self *Source) tsMux(p *av.Packet) error { + if p.IsVideo { + return self.muxer.Mux(p, self.btswriter) + } else { + self.cache.Cache(p.Data, self.pts) + return self.muxAudio(cache_max_frames) + } +} diff --git a/protocol/hls/status.go b/protocol/hls/status.go new file mode 100755 index 0000000..fee80ef --- /dev/null +++ b/protocol/hls/status.go @@ -0,0 +1,43 @@ +package hls + +import "time" + +type status struct { + hasVideo bool + seqId int64 + createdAt time.Time + segBeginAt time.Time + hasSetFirstTs bool + firstTimestamp int64 + lastTimestamp int64 +} + +func newStatus() *status { + return &status{ + seqId: 0, + hasSetFirstTs: 
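calcPtsDts above moves FLV millisecond timestamps onto the 90 kHz MPEG-TS clock: video DTS is the tag timestamp times 90, video PTS adds the composition-time offset, and audio advances by one AAC frame (1024 samples) per packet before the `align` smoothing is applied. A small sketch of that arithmetic:

```go
package main

import "fmt"

const (
	mpegTicksPerMs   = 90 // MPEG-TS runs on a 90 kHz clock; FLV timestamps are in ms
	videoHz          = 90000
	aacSamplesPerFrm = 1024
)

// videoPtsDts: dts is the FLV timestamp in 90 kHz ticks, pts adds the
// composition-time offset converted the same way.
func videoPtsDts(tsMs, compositionMs uint32) (pts, dts uint64) {
	dts = uint64(tsMs) * mpegTicksPerMs
	pts = dts + uint64(compositionMs)*mpegTicksPerMs
	return
}

// aacFrameTicks: each AAC frame carries 1024 samples, so the per-frame
// increment in 90 kHz ticks is 90000*1024/sampleRate.
func aacFrameTicks(sampleRate int) uint32 {
	return uint32(videoHz * aacSamplesPerFrm / sampleRate)
}

func main() {
	pts, dts := videoPtsDts(1000, 40)
	fmt.Println(pts, dts)             // 93600 90000
	fmt.Println(aacFrameTicks(44100)) // 2089
}
```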
false, + segBeginAt: time.Now(), + } +} + +func (t *status) update(isVideo bool, timestamp uint32) { + if isVideo { + t.hasVideo = true + } + if !t.hasSetFirstTs { + t.hasSetFirstTs = true + t.firstTimestamp = int64(timestamp) + } + t.lastTimestamp = int64(timestamp) +} + +func (t *status) resetAndNew() { + t.seqId++ + t.hasVideo = false + t.createdAt = time.Now() + t.hasSetFirstTs = false +} + +func (t *status) durationMs() int64 { + return t.lastTimestamp - t.firstTimestamp +} diff --git a/protocol/hls/ts_cache.go b/protocol/hls/ts_cache.go new file mode 100755 index 0000000..99ccd61 --- /dev/null +++ b/protocol/hls/ts_cache.go @@ -0,0 +1,127 @@ +package hls + +import ( + "bytes" + "container/list" + "errors" + "fmt" + "sync" +) + +type TSCache struct { + entrys map[string]*TSCacheItem +} + +func NewTSCache() *TSCache { + return &TSCache{ + entrys: make(map[string]*TSCacheItem), + } +} + +func (self *TSCache) Set(key string, e *TSCacheItem) { + v, ok := self.entrys[key] + if !ok { + self.entrys[key] = e + } + if v.ID() != e.ID() { + self.entrys[key] = e + } +} + +func (self *TSCache) Get(key string) *TSCacheItem { + v := self.entrys[key] + return v +} + +const ( + maxTSCacheNum = 3 +) + +var ( + ErrNoKey = errors.New("No key for cache") +) + +type TSCacheItem struct { + id string + num int + lock sync.RWMutex + ll *list.List + lm map[string]TSItem +} + +func NewTSCacheItem(id string) *TSCacheItem { + return &TSCacheItem{ + id: id, + ll: list.New(), + num: maxTSCacheNum, + lm: make(map[string]TSItem), + } +} + +func (self *TSCacheItem) ID() string { + return self.id +} + +// TODO: found data race, fix it +func (self *TSCacheItem) GenM3U8PlayList() ([]byte, error) { + var seq int + var getSeq bool + var maxDuration int + m3u8body := bytes.NewBuffer(nil) + for e := self.ll.Front(); e != nil; e = e.Next() { + key := e.Value.(string) + v, ok := self.lm[key] + if ok { + if v.Duration > maxDuration { + maxDuration = v.Duration + } + if !getSeq { + getSeq = true + seq = v.SeqNum + } + fmt.Fprintf(m3u8body, "#EXTINF:%.3f,\n%s\n", float64(v.Duration)/float64(1000), v.Name) + } + } + w := bytes.NewBuffer(nil) + fmt.Fprintf(w, + "#EXTM3U\n#EXT-X-VERSION:3\n#EXT-X-ALLOW-CACHE:NO\n#EXT-X-TARGETDURATION:%d\n#EXT-X-MEDIA-SEQUENCE:%d\n\n", + maxDuration/1000+1, seq) + w.Write(m3u8body.Bytes()) + return w.Bytes(), nil +} + +func (self *TSCacheItem) SetItem(key string, item TSItem) { + if self.ll.Len() == self.num { + e := self.ll.Front() + self.ll.Remove(e) + k := e.Value.(string) + delete(self.lm, k) + } + self.lm[key] = item + self.ll.PushBack(key) +} + +func (self *TSCacheItem) GetItem(key string) (TSItem, error) { + item, ok := self.lm[key] + if !ok { + return item, ErrNoKey + } + return item, nil +} + +type TSItem struct { + Name string + SeqNum int + Duration int + Data []byte +} + +func NewTSItem(name string, duration, seqNum int, b []byte) TSItem { + var item TSItem + item.Name = name + item.SeqNum = seqNum + item.Duration = duration + item.Data = make([]byte, len(b)) + copy(item.Data, b) + return item +} diff --git a/protocol/httpflv/http_flv.go b/protocol/httpflv/http_flv.go new file mode 100755 index 0000000..850b273 --- /dev/null +++ b/protocol/httpflv/http_flv.go @@ -0,0 +1,274 @@ +package httpflv + +import ( + "encoding/json" + "net" + "net/http" + "strings" + "time" + + "errors" + + "github.com/golang/glog" + "github.com/gwuhaolin/livego/utils/uid" + "github.com/gwuhaolin/livego/protocol/amf" + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/utils/pio" + "log" + 
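GenM3U8PlayList in ts_cache.go renders the live window kept by TSCacheItem (the last three segments) into a playlist. A standalone sketch with the same layout, so the output format is visible without running the server:

```go
package main

import (
	"bytes"
	"fmt"
)

type segment struct {
	name       string
	durationMs int
	seq        int
}

// buildPlaylist mirrors GenM3U8PlayList: a sliding window of recent segments,
// TARGETDURATION rounded up from the longest one, and the media sequence
// taken from the first segment still in the window.
func buildPlaylist(segs []segment) []byte {
	body := &bytes.Buffer{}
	maxDur, seq := 0, 0
	for i, s := range segs {
		if s.durationMs > maxDur {
			maxDur = s.durationMs
		}
		if i == 0 {
			seq = s.seq
		}
		fmt.Fprintf(body, "#EXTINF:%.3f,\n%s\n", float64(s.durationMs)/1000, s.name)
	}
	out := &bytes.Buffer{}
	fmt.Fprintf(out,
		"#EXTM3U\n#EXT-X-VERSION:3\n#EXT-X-ALLOW-CACHE:NO\n#EXT-X-TARGETDURATION:%d\n#EXT-X-MEDIA-SEQUENCE:%d\n\n",
		maxDur/1000+1, seq)
	out.Write(body.Bytes())
	return out.Bytes()
}

func main() {
	fmt.Print(string(buildPlaylist([]segment{
		{"/live/movie/1.ts", 3000, 4},
		{"/live/movie/2.ts", 2980, 5},
		{"/live/movie/3.ts", 3050, 6},
	})))
}
```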
"github.com/gwuhaolin/livego/protocol/rtmp" +) + +type Server struct { + handler av.Handler +} + +type stream struct { + Key string `json:"key"` + Id string `json:"id"` +} + +type streams struct { + Publishers []stream `json:"publishers"` + Players []stream `json:"players"` +} + +func NewServer(h av.Handler) *Server { + return &Server{ + handler: h, + } +} + +func (self *Server) Serve(l net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + self.handleConn(w, r) + }) + mux.HandleFunc("/streams", func(w http.ResponseWriter, r *http.Request) { + self.getStream(w, r) + }) + http.Serve(l, mux) + return nil +} + +func (s *Server) getStream(w http.ResponseWriter, r *http.Request) { + rtmpStream := s.handler.(*rtmp.RtmpStream) + if rtmpStream == nil { + return + } + msgs := new(streams) + for item := range rtmpStream.GetStreams().IterBuffered() { + if s, ok := item.Val.(*rtmp.Stream); ok { + if s.GetReader() != nil { + msg := stream{item.Key, s.GetReader().Info().UID} + msgs.Publishers = append(msgs.Publishers, msg) + } + } + } + + for item := range rtmpStream.GetStreams().IterBuffered() { + ws := item.Val.(*rtmp.Stream).GetWs() + for s := range ws.IterBuffered() { + if pw, ok := s.Val.(*rtmp.PackWriterCloser); ok { + if pw.GetWriter() != nil { + msg := stream{item.Key, pw.GetWriter().Info().UID} + msgs.Players = append(msgs.Players, msg) + } + } + } + } + resp, _ := json.Marshal(msgs) + w.Header().Set("Content-Type", "application/json") + w.Write(resp) + +} + +func (self *Server) handleConn(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + log.Println("http flv handleConn panic: ", r) + } + }() + + url := r.URL.String() + u := r.URL.Path + if pos := strings.LastIndex(u, "."); pos < 0 || u[pos:] != ".flv" { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + path := strings.TrimSuffix(strings.TrimLeft(u, "/"), ".flv") + paths := strings.SplitN(path, "/", 2) + log.Println("url:", u, "path:", path, "paths:", paths) + + if len(paths) != 2 { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + + w.Header().Set("Access-Control-Allow-Origin", "*") + writer := NewFLVWriter(paths[0], paths[1], url, w) + + self.handler.HandleWriter(writer) + writer.Wait() +} + +const ( + headerLen = 11 + maxQueueNum = 1024 +) + +type FLVWriter struct { + Uid string + av.RWBaser + app, title, url string + buf []byte + closed bool + closedChan chan struct{} + ctx http.ResponseWriter + packetQueue chan av.Packet +} + +func NewFLVWriter(app, title, url string, ctx http.ResponseWriter) *FLVWriter { + ret := &FLVWriter{ + Uid: uid.NEWID(), + app: app, + title: title, + url: url, + ctx: ctx, + RWBaser: av.NewRWBaser(time.Second * 10), + closedChan: make(chan struct{}), + buf: make([]byte, headerLen), + packetQueue: make(chan av.Packet, maxQueueNum), + } + + ret.ctx.Write([]byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00, 0x09}) + pio.PutI32BE(ret.buf[:4], 0) + ret.ctx.Write(ret.buf[:4]) + go func() { + err := ret.SendPacket() + if err != nil { + log.Println("SendPacket error:", err) + ret.closed = true + } + }() + return ret +} + +func (self *FLVWriter) DropPacket(pktQue chan av.Packet, info av.Info) { + glog.Errorf("[%v] packet queue max!!!", info) + for i := 0; i < maxQueueNum-84; i++ { + tmpPkt, ok := <-pktQue + if ok && tmpPkt.IsVideo { + videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader) + // dont't drop sps config and dont't drop key frame + if ok && (videoPkt.IsSeq() || 
videoPkt.IsKeyFrame()) { + log.Println("insert keyframe to queue") + pktQue <- tmpPkt + } + + if len(pktQue) > maxQueueNum-10 { + <-pktQue + } + // drop other packet + <-pktQue + } + // try to don't drop audio + if ok && tmpPkt.IsAudio { + log.Println("insert audio to queue") + pktQue <- tmpPkt + } + } + log.Println("packet queue len: ", len(pktQue)) +} + +func (self *FLVWriter) Write(p av.Packet) error { + if !self.closed { + if len(self.packetQueue) >= maxQueueNum-24 { + self.DropPacket(self.packetQueue, self.Info()) + } else { + self.packetQueue <- p + } + return nil + } else { + return errors.New("closed") + } + +} + +// func (self *FLVWriter) Write(p av.Packet) error { +func (self *FLVWriter) SendPacket() error { + for { + p, ok := <-self.packetQueue + if ok { + self.RWBaser.SetPreTime() + h := self.buf[:headerLen] + typeID := av.TAG_VIDEO + if !p.IsVideo { + if p.IsMetadata { + var err error + typeID = av.TAG_SCRIPTDATAAMF0 + p.Data, err = amf.MetaDataReform(p.Data, amf.DEL) + if err != nil { + return err + } + } else { + typeID = av.TAG_AUDIO + } + } + dataLen := len(p.Data) + timestamp := p.TimeStamp + timestamp += self.BaseTimeStamp() + self.RWBaser.RecTimeStamp(timestamp, uint32(typeID)) + + preDataLen := dataLen + headerLen + timestampbase := timestamp & 0xffffff + timestampExt := timestamp >> 24 & 0xff + + pio.PutU8(h[0:1], uint8(typeID)) + pio.PutI24BE(h[1:4], int32(dataLen)) + pio.PutI24BE(h[4:7], int32(timestampbase)) + pio.PutU8(h[7:8], uint8(timestampExt)) + + if _, err := self.ctx.Write(h); err != nil { + return err + } + + if _, err := self.ctx.Write(p.Data); err != nil { + return err + } + + pio.PutI32BE(h[:4], int32(preDataLen)) + if _, err := self.ctx.Write(h[:4]); err != nil { + return err + } + } else { + return errors.New("closed") + } + + } + + return nil +} + +func (self *FLVWriter) Wait() { + select { + case <-self.closedChan: + return + } +} + +func (self *FLVWriter) Close(error) { + log.Println("http flv closed") + if !self.closed { + close(self.packetQueue) + close(self.closedChan) + } + self.closed = true +} + +func (self *FLVWriter) Info() (ret av.Info) { + ret.UID = self.Uid + ret.URL = self.url + ret.Key = self.app + "/" + self.title + ret.Inter = true + return +} diff --git a/protocol/httpopera/http_opera.go b/protocol/httpopera/http_opera.go new file mode 100755 index 0000000..34e6b5b --- /dev/null +++ b/protocol/httpopera/http_opera.go @@ -0,0 +1,232 @@ +package httpopera + +import ( + "encoding/json" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + + "github.com/golang/glog" + "github.com/gwuhaolin/livego/utils/uid" + "github.com/gwuhaolin/livego/av" + "log" + "github.com/gwuhaolin/livego/protocol/rtmp" +) + +type Response struct { + w http.ResponseWriter + Status int `json:"status"` + Message string `json:"message"` +} + +func (r *Response) SendJson() (int, error) { + resp, _ := json.Marshal(r) + r.w.Header().Set("Content-Type", "application/json") + return r.w.Write(resp) +} + +type Operation struct { + Method string `json:"method"` + URL string `json:"url"` + Stop bool `json:"stop"` +} + +type OperationChange struct { + Method string `json:"method"` + SourceURL string `json:"source_url"` + TargetURL string `json:"target_url"` + Stop bool `json:"stop"` +} + +type Server struct { + handler av.Handler +} + +func NewServer(h av.Handler) *Server { + return &Server{ + handler: h, + } +} + +func (s *Server) Serve(l net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc("/rtmp/operation", func(w http.ResponseWriter, r *http.Request) 
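SendPacket below frames every queued packet as an FLV tag before writing it to the HTTP response: an 11-byte tag header, the payload, then a 4-byte PreviousTagSize. A minimal sketch of that framing (the stream-id bytes 8 to 10 stay zero, and the upper timestamp byte goes into the extension field, exactly as in that loop):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// flvTag frames one FLV tag: type, 24-bit data size, 24-bit timestamp plus an
// extension byte, a 3-byte stream id left at zero, the payload, and finally
// the 4-byte PreviousTagSize (payload length + 11).
func flvTag(tagType byte, timestamp uint32, data []byte) []byte {
	h := make([]byte, 11)
	h[0] = tagType
	h[1], h[2], h[3] = byte(len(data)>>16), byte(len(data)>>8), byte(len(data))
	h[4], h[5], h[6] = byte(timestamp>>16), byte(timestamp>>8), byte(timestamp)
	h[7] = byte(timestamp >> 24) // extended timestamp byte
	// h[8:11] is the always-zero stream id

	out := append(h, data...)
	prev := make([]byte, 4)
	binary.BigEndian.PutUint32(prev, uint32(len(data)+11))
	return append(out, prev...)
}

func main() {
	tag := flvTag(9, 40, bytes.Repeat([]byte{0x00}, 16)) // a 16-byte video payload
	fmt.Println(len(tag))                                // 11 + 16 + 4 = 31
}
```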
{ + s.handleOpera(w, r) + }) + // mux.HandleFunc("/rtmp/operation/change", func(w http.ResponseWriter, r *http.Request) { + // s.handleOperaChange(w, r) + // }) + http.Serve(l, mux) + return nil +} + +// handleOpera, 拉流和推流的http api +// @Path: /rtmp/operation +// @Method: POST +// @Param: json +// method string, "push" or "pull" +// url string +// stop bool + +// @Example, +// curl -v -H "Content-Type: application/json" -X POST --data \ +// '{"method":"pull","url":"rtmp://127.0.0.1:1935/live/test"}' \ +// http://localhost:8087/rtmp/operation +func (s *Server) handleOpera(w http.ResponseWriter, r *http.Request) { + rep := &Response{ + w: w, + } + + if r.Method != "POST" { + rep.Status = 14000 + rep.Message = "bad request method" + rep.SendJson() + return + } else { + result, err := ioutil.ReadAll(r.Body) + if err != nil { + rep.Status = 15000 + rep.Message = "read request body error" + rep.SendJson() + return + } + r.Body.Close() + glog.Infof("post body: %s\n", result) + + var op Operation + err = json.Unmarshal(result, &op) + if err != nil { + rep.Status = 12000 + rep.Message = "parse json body failed" + rep.SendJson() + return + } + + switch op.Method { + case "push": + s.Push(op.URL, op.Stop) + case "pull": + s.Pull(op.URL, op.Stop) + } + + rep.Status = 10000 + rep.Message = op.Method + " " + op.URL + " success" + rep.SendJson() + } +} + +func (s *Server) Push(uri string, stop bool) error { + rtmpClient := rtmp.NewRtmpClient(s.handler, nil) + return rtmpClient.Dial(uri, av.PUBLISH) + // return nil +} + +func (s *Server) Pull(uri string, stop bool) error { + rtmpClient := rtmp.NewRtmpClient(s.handler, nil) + return rtmpClient.Dial(uri, av.PLAY) + // return nil +} + +// TODO: +// handleOperaChange, 拉流和推流的http api,支持自定义路径 +// @Path: /rtmp/operation/change +// @Method: POST +// @Param: json +// method string, "push" or "pull" +// url string +// stop bool + +// @Example, +// curl -v -H "Content-Type: application/json" -X POST --data \ +// '{"method":"pull","url":"rtmp://127.0.0.1:1935/live/test"}' \ +// http://localhost:8087/rtmp/operation +// func (s *Server) handleOperaChange(w http.ResponseWriter, r *http.Request) { +// rep := &Response{ +// w: w, +// } + +// if r.Method != "POST" { +// rep.Status = 14000 +// rep.Message = "bad request method" +// rep.SendJson() +// return +// } else { +// result, err := ioutil.ReadAll(r.Body) +// if err != nil { +// rep.Status = 15000 +// rep.Message = "read request body error" +// rep.SendJson() +// return +// } +// r.Body.Close() +// glog.Infof("post body: %s\n", result) + +// var op OperationChange +// err = json.Unmarshal(result, &op) +// if err != nil { +// rep.Status = 12000 +// rep.Message = "parse json body failed" +// rep.SendJson() +// return +// } + +// switch op.Method { +// case "push": +// s.PushChange(op.SourceURL, op.TargetURL, op.Stop) + +// case "pull": +// s.PullChange(op.SourceURL, op.TargetURL, op.Stop) +// } + +// rep.Status = 10000 +// rep.Message = op.Method + " from" + op.SourceURL + "to " + op.TargetURL + " success" +// rep.SendJson() +// } +// } + +// pushChange suri to turi +// func (s *Server) PushChange(suri, turi string, stop bool) error { +// if !stop { +// sinfo := parseURL(suri) +// tinfo := parseURL(turi) +// rtmpClient := rtmp.NewRtmpClient(s.handler, nil) +// return rtmpClient.Dial(turi, av.PUBLISH) +// } else { +// sinfo := parseURL(suri) +// tinfo := parseURL(turi) +// s.delStream(sinfo.Key, true) +// return nil +// } +// return nil +// } + +// pullChange +// func (s *Server) PullChange(suri, turi string, stop bool) 
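handleOpera above accepts a JSON body with `method` ("push" or "pull"), `url`, and `stop`, and answers with a status/message JSON. A sketch of calling it from Go rather than curl; the host and port (localhost:8087) are taken from the curl example in the comments and are an assumption, not something this file configures:

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	body := []byte(`{"method":"pull","url":"rtmp://127.0.0.1:1935/live/test"}`)
	resp, err := http.Post("http://localhost:8087/rtmp/operation",
		"application/json", bytes.NewReader(body))
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```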
error { +// if !stop { +// rtmpStreams, ok := s.handler.(*rtmp.RtmpStream) +// if ok { +// streams := rtmpStreams.GetStreams() +// rtmpClient := rtmp.NewRtmpClient(s.handler, nil) +// return rtmpClient.Dial(turi, av.PLAY) +// } +// } else { +// info := parseURL(suri) +// s.delStream(info.Key, false) +// return nil +// } +// return nil +// } + +func parseURL(URL string) (ret av.Info) { + ret.UID = uid.NEWID() + ret.URL = URL + _url, err := url.Parse(URL) + if err != nil { + log.Println(err) + } + ret.Key = strings.TrimLeft(_url.Path, "/") + ret.Inter = true + return +} diff --git a/protocol/kcpts/kcp_ts.go b/protocol/kcpts/kcp_ts.go new file mode 100755 index 0000000..7a87ef8 --- /dev/null +++ b/protocol/kcpts/kcp_ts.go @@ -0,0 +1 @@ +package kcpts diff --git a/protocol/private/protocol.go b/protocol/private/protocol.go new file mode 100755 index 0000000..735e4dc --- /dev/null +++ b/protocol/private/protocol.go @@ -0,0 +1 @@ +package private diff --git a/protocol/rtmp/cache/cache.go b/protocol/rtmp/cache/cache.go new file mode 100755 index 0000000..bbe9731 --- /dev/null +++ b/protocol/rtmp/cache/cache.go @@ -0,0 +1,79 @@ +package cache + +import ( + "flag" + "github.com/gwuhaolin/livego/av" +) + +var ( + gopNum = flag.Int("gopNum", 1, "gop num") +) + +type Cache struct { + gop *GopCache + videoSeq *SpecialCache + audioSeq *SpecialCache + metadata *SpecialCache +} + +func NewCache() *Cache { + return &Cache{ + gop: NewGopCache(*gopNum), + videoSeq: NewSpecialCache(), + audioSeq: NewSpecialCache(), + metadata: NewSpecialCache(), + } +} + +func (self *Cache) Write(p av.Packet) { + if p.IsMetadata { + self.metadata.Write(p) + return + } else { + if !p.IsVideo { + ah, ok := p.Header.(av.AudioPacketHeader) + if ok { + if ah.SoundFormat() == av.SOUND_AAC && + ah.AACPacketType() == av.AAC_SEQHDR { + self.audioSeq.Write(p) + return + } else { + return + } + } + + } else { + vh, ok := p.Header.(av.VideoPacketHeader) + if ok { + if vh.IsSeq() { + self.videoSeq.Write(p) + return + } + } else { + return + } + + } + } + self.gop.Write(p) +} + +func (self *Cache) Send(w av.WriteCloser) error { + if err := self.metadata.Send(w); err != nil { + return err + } + + if err := self.videoSeq.Send(w); err != nil { + return err + } + + if err := self.audioSeq.Send(w); err != nil { + return err + } + + if err := self.gop.Send(w); err != nil { + return err + } + + return nil +} diff --git a/protocol/rtmp/cache/gop.go b/protocol/rtmp/cache/gop.go new file mode 100755 index 0000000..9526eaa --- /dev/null +++ b/protocol/rtmp/cache/gop.go @@ -0,0 +1,120 @@ +package cache + +import ( + "errors" + + "github.com/gwuhaolin/livego/av" +) + +var ( + maxGOPCap int = 1024 + ErrGopTooBig = errors.New("gop to big") +) + +type array struct { + index int + packets []av.Packet +} + +func newArray() *array { + ret := &array{ + index: 0, + packets: make([]av.Packet, 0, maxGOPCap), + } + return ret +} + +func (self *array) reset() { + self.index = 0 + self.packets = self.packets[:0] +} + +func (self *array) write(packet av.Packet) error { + if self.index >= maxGOPCap { + return ErrGopTooBig + } + self.packets = append(self.packets, packet) + self.index++ + return nil +} + +func (self *array) send(w av.WriteCloser) error { + var err error + for i := 0; i < self.index; i++ { + packet := self.packets[i] + if err = w.Write(packet); err != nil { + return err + } + } + return err +} + +type GopCache struct { + start bool + num int + count int + nextindex int + gops []*array +} + +func NewGopCache(num int) *GopCache { + return &GopCache{ + 
count: num, + gops: make([]*array, num), + } +} + +func (self *GopCache) writeToArray(chunk av.Packet, startNew bool) error { + var ginc *array + if startNew { + ginc = self.gops[self.nextindex] + if ginc == nil { + ginc = newArray() + self.num++ + self.gops[self.nextindex] = ginc + } else { + ginc.reset() + } + self.nextindex = (self.nextindex + 1) % self.count + } else { + ginc = self.gops[(self.nextindex+1)%self.count] + } + ginc.write(chunk) + + return nil +} + +func (self *GopCache) Write(p av.Packet) { + var ok bool + if p.IsVideo { + vh := p.Header.(av.VideoPacketHeader) + if vh.IsKeyFrame() && !vh.IsSeq() { + ok = true + } + } + if ok || self.start { + self.start = true + self.writeToArray(p, ok) + } +} + +func (self *GopCache) sendTo(w av.WriteCloser) error { + var err error + pos := (self.nextindex + 1) % self.count + for i := 0; i < self.num; i++ { + index := (pos - self.num + 1) + i + if index < 0 { + index += self.count + } + g := self.gops[index] + err = g.send(w) + if err != nil { + return err + } + } + return nil +} + +func (self *GopCache) Send(w av.WriteCloser) error { + return self.sendTo(w) +} diff --git a/protocol/rtmp/cache/special.go b/protocol/rtmp/cache/special.go new file mode 100755 index 0000000..13c1fa2 --- /dev/null +++ b/protocol/rtmp/cache/special.go @@ -0,0 +1,46 @@ +package cache + +import ( + "bytes" + "log" + + "github.com/gwuhaolin/livego/protocol/amf" + "github.com/gwuhaolin/livego/av" +) + +const ( + SetDataFrame string = "@setDataFrame" + OnMetaData string = "onMetaData" +) + +var setFrameFrame []byte + +func init() { + b := bytes.NewBuffer(nil) + encoder := &amf.Encoder{} + if _, err := encoder.Encode(b, SetDataFrame, amf.AMF0); err != nil { + log.Fatal(err) + } + setFrameFrame = b.Bytes() +} + +type SpecialCache struct { + full bool + p av.Packet +} + +func NewSpecialCache() *SpecialCache { + return &SpecialCache{} +} + +func (self *SpecialCache) Write(p av.Packet) { + self.p = p + self.full = true +} + +func (self *SpecialCache) Send(w av.WriteCloser) error { + if !self.full { + return nil + } + return w.Write(self.p) +} diff --git a/protocol/rtmp/core/chunk_stream.go b/protocol/rtmp/core/chunk_stream.go new file mode 100755 index 0000000..b135531 --- /dev/null +++ b/protocol/rtmp/core/chunk_stream.go @@ -0,0 +1,225 @@ +package core + +import ( + "encoding/binary" + "fmt" + + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/utils/pool" +) + +type ChunkStream struct { + Format uint32 + CSID uint32 + Timestamp uint32 + Length uint32 + TypeID uint32 + StreamID uint32 + timeDelta uint32 + exted bool + index uint32 + remain uint32 + got bool + tmpFromat uint32 + Data []byte +} + +func (self *ChunkStream) full() bool { + return self.got +} + +func (self *ChunkStream) new(pool *pool.Pool) { + self.got = false + self.index = 0 + self.remain = self.Length + self.Data = pool.Get(int(self.Length)) +} + +func (self *ChunkStream) writeHeader(w *ReadWriter) error { + //Chunk Basic Header + h := self.Format << 6 + switch { + case self.CSID < 64: + h |= self.CSID + w.WriteUintBE(h, 1) + case self.CSID-64 < 256: + h |= 0 + w.WriteUintBE(h, 1) + w.WriteUintLE(self.CSID-64, 1) + case self.CSID-64 < 65536: + h |= 1 + w.WriteUintBE(h, 1) + w.WriteUintLE(self.CSID-64, 2) + } + //Chunk Message Header + ts := self.Timestamp + if self.Format == 3 { + goto END + } + if self.Timestamp > 0xffffff { + ts = 0xffffff + } + w.WriteUintBE(ts, 3) + if self.Format == 2 { + goto END + } + if self.Length > 0xffffff { + return fmt.Errorf("length=%d", self.Length) + } + 
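writeHeader in chunk_stream.go starts every chunk with a basic header: the 2-bit chunk format in the top of the first byte and the chunk stream id packed into one, two, or three bytes. A sketch of just that packing (it leaves out the message header and extended timestamp that writeHeader emits afterwards):

```go
package main

import "fmt"

// basicHeader mirrors the first part of ChunkStream.writeHeader: 2-bit format
// in the top of byte 0, then the chunk stream id in 1, 2 or 3 bytes.
func basicHeader(format, csid uint32) []byte {
	h := byte(format << 6)
	switch {
	case csid < 64:
		return []byte{h | byte(csid)}
	case csid-64 < 256:
		return []byte{h, byte(csid - 64)} // low bits 0: one extra csid byte
	default: // csid-64 < 65536
		return []byte{h | 1, byte(csid - 64), byte((csid - 64) >> 8)}
	}
}

func main() {
	fmt.Printf("% x\n", basicHeader(0, 6))   // 06 : a full type-0 header follows
	fmt.Printf("% x\n", basicHeader(3, 6))   // c6 : continuation chunk, header only
	fmt.Printf("% x\n", basicHeader(0, 100)) // 00 24
}
```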
w.WriteUintBE(self.Length, 3) + w.WriteUintBE(self.TypeID, 1) + if self.Format == 1 { + goto END + } + w.WriteUintLE(self.StreamID, 4) +END: +//Extended Timestamp + if ts >= 0xffffff { + w.WriteUintBE(self.Timestamp, 4) + } + return w.WriteError() +} + +func (self *ChunkStream) writeChunk(w *ReadWriter, chunkSize int) error { + if self.TypeID == av.TAG_AUDIO { + self.CSID = 4 + } else if self.TypeID == av.TAG_VIDEO || + self.TypeID == av.TAG_SCRIPTDATAAMF0 || + self.TypeID == av.TAG_SCRIPTDATAAMF3 { + self.CSID = 6 + } + + totalLen := uint32(0) + numChunks := (self.Length / uint32(chunkSize)) + for i := uint32(0); i <= numChunks; i++ { + if totalLen == self.Length { + break + } + if i == 0 { + self.Format = uint32(0) + } else { + self.Format = uint32(3) + } + if err := self.writeHeader(w); err != nil { + return err + } + inc := uint32(chunkSize) + start := uint32(i) * uint32(chunkSize) + if uint32(len(self.Data))-start <= inc { + inc = uint32(len(self.Data)) - start + } + totalLen += inc + end := start + inc + buf := self.Data[start:end] + if _, err := w.Write(buf); err != nil { + return err + } + } + + return nil + +} + +func (self *ChunkStream) readChunk(r *ReadWriter, chunkSize uint32, pool *pool.Pool) error { + if self.remain != 0 && self.tmpFromat != 3 { + return fmt.Errorf("inlaid remin = %d", self.remain) + } + switch self.CSID { + case 0: + id, _ := r.ReadUintLE(1) + self.CSID = id + 64 + case 1: + id, _ := r.ReadUintLE(2) + self.CSID = id + 64 + } + + switch self.tmpFromat { + case 0: + self.Format = self.tmpFromat + self.Timestamp, _ = r.ReadUintBE(3) + self.Length, _ = r.ReadUintBE(3) + self.TypeID, _ = r.ReadUintBE(1) + self.StreamID, _ = r.ReadUintLE(4) + if self.Timestamp == 0xffffff { + self.Timestamp, _ = r.ReadUintBE(4) + self.exted = true + } else { + self.exted = false + } + self.new(pool) + case 1: + self.Format = self.tmpFromat + timeStamp, _ := r.ReadUintBE(3) + self.Length, _ = r.ReadUintBE(3) + self.TypeID, _ = r.ReadUintBE(1) + if timeStamp == 0xffffff { + timeStamp, _ = r.ReadUintBE(4) + self.exted = true + } else { + self.exted = false + } + self.timeDelta = timeStamp + self.Timestamp += timeStamp + self.new(pool) + case 2: + self.Format = self.tmpFromat + timeStamp, _ := r.ReadUintBE(3) + if timeStamp == 0xffffff { + timeStamp, _ = r.ReadUintBE(4) + self.exted = true + } else { + self.exted = false + } + self.timeDelta = timeStamp + self.Timestamp += timeStamp + self.new(pool) + case 3: + if self.remain == 0 { + switch self.Format { + case 0: + if self.exted { + timestamp, _ := r.ReadUintBE(4) + self.Timestamp = timestamp + } + case 1, 2: + var timedet uint32 + if self.exted { + timedet, _ = r.ReadUintBE(4) + } else { + timedet = self.timeDelta + } + self.Timestamp += timedet + } + self.new(pool) + } else { + if self.exted { + b, err := r.Peek(4) + if err != nil { + return err + } + tmpts := binary.BigEndian.Uint32(b) + if tmpts == self.Timestamp { + r.Discard(4) + } + } + } + default: + return fmt.Errorf("invalid format=%d", self.Format) + } + size := int(self.remain) + if size > int(chunkSize) { + size = int(chunkSize) + } + + buf := self.Data[self.index: self.index+uint32(size)] + if _, err := r.Read(buf); err != nil { + return err + } + self.index += uint32(size) + self.remain -= uint32(size) + if self.remain == 0 { + self.got = true + } + + return r.readError +} diff --git a/protocol/rtmp/core/chunk_stream_test.go b/protocol/rtmp/core/chunk_stream_test.go new file mode 100755 index 0000000..d351dac --- /dev/null +++ b/protocol/rtmp/core/chunk_stream_test.go 
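writeChunk splits a message whose payload exceeds the chunk size: the first chunk gets a type-0 header, and every continuation chunk gets only a one-byte type-3 basic header (0xC6 for chunk stream 6, which is what the test data below is built from). A sketch of the resulting wire size, assuming a csid below 64 and a timestamp below 0xFFFFFF so no extended fields appear:

```go
package main

import "fmt"

// wireSize sketches the accounting done by writeChunk: the first chunk carries
// a type-0 header (1-byte basic header + 11-byte message header), every
// following chunk only a 1-byte type-3 basic header, and the payload is cut
// into chunkSize pieces.
func wireSize(payload, chunkSize int) int {
	chunks := (payload + chunkSize - 1) / chunkSize
	if chunks == 0 {
		chunks = 1
	}
	return (1 + 11) + (chunks - 1) + payload
}

func main() {
	// 307 bytes at the default 128-byte chunk size: 12 + 2 + 307 = 321,
	// the value asserted by TestWriteChunk in the test file that follows.
	fmt.Println(wireSize(307, 128))
}
```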
@@ -0,0 +1,97 @@ +package core + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/gwuhaolin/livego/utils/pool" +) + +func TestChunkRead1(t *testing.T) { + at := assert.New(t) + data := []byte{ + 0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, + } + data1 := make([]byte, 128) + data2 := make([]byte, 51) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data2...) + + rw := NewReadWriter(bytes.NewBuffer(data), 1024) + chunkinc := &ChunkStream{} + + for { + h, _ := rw.ReadUintBE(1) + chunkinc.tmpFromat = h >> 6 + chunkinc.CSID = h & 0x3f + chunkinc.readChunk(rw, 128, pool.NewPool()) + if chunkinc.remain == 0 { + break + } + } + + at.Equal(int(chunkinc.Length), 307) + at.Equal(int(chunkinc.TypeID), 9) + at.Equal(int(chunkinc.StreamID), 1) + at.Equal(len(chunkinc.Data), 307) + at.Equal(int(chunkinc.remain), 0) + + data = []byte{ + 0x06, 0xff, 0xff, 0xff, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, + } + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, []byte{0x00, 0x00, 0x00, 0x05}...) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data2...) + + rw = NewReadWriter(bytes.NewBuffer(data), 1024) + chunkinc = &ChunkStream{} + + h, _ := rw.ReadUintBE(1) + chunkinc.tmpFromat = h >> 6 + chunkinc.CSID = h & 0x3f + chunkinc.readChunk(rw, 128, pool.NewPool()) + + h, _ = rw.ReadUintBE(1) + chunkinc.tmpFromat = h >> 6 + chunkinc.CSID = h & 0x3f + chunkinc.readChunk(rw, 128, pool.NewPool()) + + h, _ = rw.ReadUintBE(1) + chunkinc.tmpFromat = h >> 6 + chunkinc.CSID = h & 0x3f + chunkinc.readChunk(rw, 128, pool.NewPool()) + + at.Equal(int(chunkinc.Length), 307) + at.Equal(int(chunkinc.TypeID), 9) + at.Equal(int(chunkinc.StreamID), 1) + at.Equal(len(chunkinc.Data), 307) + at.Equal(chunkinc.exted, true) + at.Equal(int(chunkinc.Timestamp), 5) + at.Equal(int(chunkinc.remain), 0) + +} + +func TestWriteChunk(t *testing.T) { + at := assert.New(t) + chunkinc := &ChunkStream{} + + chunkinc.Length = 307 + chunkinc.TypeID = 9 + chunkinc.CSID = 4 + chunkinc.Timestamp = 40 + chunkinc.Data = make([]byte, 307) + + bf := bytes.NewBuffer(nil) + w := NewReadWriter(bf, 1024) + err := chunkinc.writeChunk(w, 128) + w.Flush() + at.Equal(err, nil) + at.Equal(len(bf.Bytes()), 321) +} diff --git a/protocol/rtmp/core/conn.go b/protocol/rtmp/core/conn.go new file mode 100755 index 0000000..9c43683 --- /dev/null +++ b/protocol/rtmp/core/conn.go @@ -0,0 +1,207 @@ +package core + +import ( + "encoding/binary" + "net" + "time" + + "github.com/gwuhaolin/livego/utils/pool" + "github.com/gwuhaolin/livego/utils/pio" +) + +const ( + _ = iota + idSetChunkSize + idAbortMessage + idAck + idUserControlMessages + idWindowAckSize + idSetPeerBandwidth +) + +type Conn struct { + net.Conn + chunkSize uint32 + remoteChunkSize uint32 + windowAckSize uint32 + remoteWindowAckSize uint32 + received uint32 + ackReceived uint32 + rw *ReadWriter + pool *pool.Pool + chunks map[uint32]ChunkStream +} + +func NewConn(c net.Conn, bufferSize int) *Conn { + return &Conn{ + Conn: c, + chunkSize: 128, + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + pool: pool.NewPool(), + rw: NewReadWriter(c, bufferSize), + chunks: make(map[uint32]ChunkStream), + } +} + +func (self *Conn) Read(c *ChunkStream) error { + for { + h, _ := self.rw.ReadUintBE(1) + // if err != nil { + // log.Println("read from conn 
error: ", err) + // return err + // } + format := h >> 6 + csid := h & 0x3f + cs, ok := self.chunks[csid] + if !ok { + cs = ChunkStream{} + self.chunks[csid] = cs + } + cs.tmpFromat = format + cs.CSID = csid + err := cs.readChunk(self.rw, self.remoteChunkSize, self.pool) + if err != nil { + return err + } + self.chunks[csid] = cs + if cs.full() { + *c = cs + break + } + } + + self.handleControlMsg(c) + + self.ack(c.Length) + + return nil +} + +func (self *Conn) Write(c *ChunkStream) error { + if c.TypeID == idSetChunkSize { + self.chunkSize = binary.BigEndian.Uint32(c.Data) + } + return c.writeChunk(self.rw, int(self.chunkSize)) +} + +func (self *Conn) Flush() error { + return self.rw.Flush() +} + +func (self *Conn) Close() error { + return self.Conn.Close() +} + +func (self *Conn) RemoteAddr() net.Addr { + return self.Conn.RemoteAddr() +} + +func (self *Conn) LocalAddr() net.Addr { + return self.Conn.LocalAddr() +} + +func (self *Conn) SetDeadline(t time.Time) error { + return self.Conn.SetDeadline(t) +} + +func (self *Conn) NewAck(size uint32) ChunkStream { + return initControlMsg(idAck, 4, size) +} + +func (self *Conn) NewSetChunkSize(size uint32) ChunkStream { + return initControlMsg(idSetChunkSize, 4, size) +} + +func (self *Conn) NewWindowAckSize(size uint32) ChunkStream { + return initControlMsg(idWindowAckSize, 4, size) +} + +func (self *Conn) NewSetPeerBandwidth(size uint32) ChunkStream { + ret := initControlMsg(idSetPeerBandwidth, 5, size) + ret.Data[4] = 2 + return ret +} + +func (self *Conn) handleControlMsg(c *ChunkStream) { + if c.TypeID == idSetChunkSize { + self.remoteChunkSize = binary.BigEndian.Uint32(c.Data) + } else if c.TypeID == idWindowAckSize { + self.remoteWindowAckSize = binary.BigEndian.Uint32(c.Data) + } +} + +func (self *Conn) ack(size uint32) { + self.received += uint32(size) + self.ackReceived += uint32(size) + if self.received >= 0xf0000000 { + self.received = 0 + } + if self.ackReceived >= self.remoteWindowAckSize { + cs := self.NewAck(self.ackReceived) + cs.writeChunk(self.rw, int(self.chunkSize)) + self.ackReceived = 0 + } +} + +func initControlMsg(id, size, value uint32) ChunkStream { + ret := ChunkStream{ + Format: 0, + CSID: 2, + TypeID: id, + StreamID: 0, + Length: size, + Data: make([]byte, size), + } + pio.PutU32BE(ret.Data[:size], value) + return ret +} + +const ( + streamBegin uint32 = 0 + streamEOF uint32 = 1 + streamDry uint32 = 2 + setBufferLen uint32 = 3 + streamIsRecorded uint32 = 4 + pingRequest uint32 = 6 + pingResponse uint32 = 7 +) + +/* + +------------------------------+------------------------- + | Event Type ( 2- bytes ) | Event Data + +------------------------------+------------------------- + Pay load for the ‘User Control Message’. 
+*/ +func (self *Conn) userControlMsg(eventType, buflen uint32) ChunkStream { + var ret ChunkStream + buflen += 2 + ret = ChunkStream{ + Format: 0, + CSID: 2, + TypeID: 4, + StreamID: 1, + Length: buflen, + Data: make([]byte, buflen), + } + ret.Data[0] = byte(eventType >> 8 & 0xff) + ret.Data[1] = byte(eventType & 0xff) + return ret +} + +func (self *Conn) SetBegin() { + ret := self.userControlMsg(streamBegin, 4) + for i := 0; i < 4; i++ { + ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff) + } + self.Write(&ret) +} + +func (self *Conn) SetRecorded() { + ret := self.userControlMsg(streamIsRecorded, 4) + for i := 0; i < 4; i++ { + ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff) + } + self.Write(&ret) +} diff --git a/protocol/rtmp/core/conn_client.go b/protocol/rtmp/core/conn_client.go new file mode 100755 index 0000000..d989517 --- /dev/null +++ b/protocol/rtmp/core/conn_client.go @@ -0,0 +1,287 @@ +package core + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + neturl "net/url" + "strings" + + "github.com/gwuhaolin/livego/protocol/amf" + "github.com/golang/glog" + "github.com/gwuhaolin/livego/av" + "log" +) + +var ( + respResult = "_result" + respError = "_error" + onStatus = "onStatus" + publishStart = "NetStream.Publish.Start" + playStart = "NetStream.Play.Start" + connectSuccess = "NetConnection.Connect.Success" +) + +var ( + ErrFail = errors.New("respone err") +) + +type ConnClient struct { + done bool + transID int + url string + tcurl string + app string + title string + query string + curcmdName string + streamid uint32 + conn *Conn + encoder *amf.Encoder + decoder *amf.Decoder + bytesw *bytes.Buffer +} + +func NewConnClient() *ConnClient { + return &ConnClient{ + transID: 1, + bytesw: bytes.NewBuffer(nil), + encoder: &amf.Encoder{}, + decoder: &amf.Decoder{}, + } +} + +func (self *ConnClient) readRespMsg() error { + var err error + var rc ChunkStream + for { + if err = self.conn.Read(&rc); err != nil { + return err + } + switch rc.TypeID { + case 20, 17: + r := bytes.NewReader(rc.Data) + vs, err := self.decoder.DecodeBatch(r, amf.AMF0) + if err != nil && err != io.EOF { + return err + } + for k, v := range vs { + switch v.(type) { + case string: + switch self.curcmdName { + case cmdConnect, cmdCreateStream: + if v.(string) != respResult { + return ErrFail + } + case cmdPublish: + if v.(string) != onStatus { + return ErrFail + } + } + case float64: + switch self.curcmdName { + case cmdConnect, cmdCreateStream: + id := int(v.(float64)) + if k == 1 { + if id != self.transID { + return ErrFail + } + } else if k == 3 { + self.streamid = uint32(id) + } + case cmdPublish: + if int(v.(float64)) != 0 { + return ErrFail + } + } + case amf.Object: + objmap := v.(amf.Object) + switch self.curcmdName { + case cmdConnect: + code, ok := objmap["code"] + if ok && code.(string) != connectSuccess { + return ErrFail + } + case cmdPublish: + code, ok := objmap["code"] + if ok && code.(string) != publishStart { + return ErrFail + } + } + } + } + return nil + } + } +} + +func (self *ConnClient) writeMsg(args ...interface{}) error { + self.bytesw.Reset() + for _, v := range args { + if _, err := self.encoder.Encode(self.bytesw, v, amf.AMF0); err != nil { + return err + } + } + msg := self.bytesw.Bytes() + c := ChunkStream{ + Format: 0, + CSID: 3, + Timestamp: 0, + TypeID: 20, + StreamID: self.streamid, + Length: uint32(len(msg)), + Data: msg, + } + self.conn.Write(&c) + return self.conn.Flush() +} + +func (self *ConnClient) writeConnectMsg() error { + event := make(amf.Object) + 
event["app"] = self.app + event["type"] = "nonprivate" + event["flashVer"] = "FMS.3.1" + event["tcUrl"] = self.tcurl + self.curcmdName = cmdConnect + + if err := self.writeMsg(cmdConnect, self.transID, event); err != nil { + return err + } + return self.readRespMsg() +} + +func (self *ConnClient) writeCreateStreamMsg() error { + self.transID++ + self.curcmdName = cmdCreateStream + if err := self.writeMsg(cmdCreateStream, self.transID, nil); err != nil { + return err + } + return self.readRespMsg() +} + +func (self *ConnClient) writePublishMsg() error { + self.transID++ + self.curcmdName = cmdPublish + if err := self.writeMsg(cmdPublish, self.transID, nil, self.title, publishLive); err != nil { + return err + } + return self.readRespMsg() +} + +func (self *ConnClient) writePlayMsg() error { + self.transID++ + self.curcmdName = cmdPlay + if err := self.writeMsg(cmdPlay, 0, nil, self.title); err != nil { + return err + } + return self.readRespMsg() +} + +func (self *ConnClient) Start(url string, method string) error { + u, err := neturl.Parse(url) + if err != nil { + return err + } + self.url = url + path := strings.TrimLeft(u.Path, "/") + ps := strings.SplitN(path, "/", 2) + if len(ps) != 2 { + return fmt.Errorf("u path err: %s", path) + } + self.app = ps[0] + self.title = ps[1] + self.query = u.RawQuery + self.tcurl = "rtmp://" + u.Host + "/" + self.app + port := ":1935" + host := u.Host + localIP := ":0" + var remoteIP string + if strings.Index(host, ":") != -1 { + host, port, err = net.SplitHostPort(host) + if err != nil { + return err + } + port = ":" + port + } + ips, err := net.LookupIP(host) + glog.Infof("ips: %v, host: %v", ips, host) + if err != nil { + log.Println(err) + return err + } + remoteIP = ips[rand.Intn(len(ips))].String() + if strings.Index(remoteIP, ":") == -1 { + remoteIP += port + } + + local, err := net.ResolveTCPAddr("tcp", localIP) + if err != nil { + log.Println(err) + return err + } + log.Println("remoteIP: ", remoteIP) + remote, err := net.ResolveTCPAddr("tcp", remoteIP) + if err != nil { + log.Println(err) + return err + } + conn, err := net.DialTCP("tcp", local, remote) + if err != nil { + log.Println(err) + return err + } + + log.Println("connection:", "local:", conn.LocalAddr(), "remote:", conn.RemoteAddr()) + + self.conn = NewConn(conn, 4*1024) + if err := self.conn.HandshakeClient(); err != nil { + return err + } + if err := self.writeConnectMsg(); err != nil { + return err + } + if err := self.writeCreateStreamMsg(); err != nil { + return err + } + if method == av.PUBLISH { + if err := self.writePublishMsg(); err != nil { + return err + } + } else if method == av.PLAY { + if err := self.writePlayMsg(); err != nil { + return err + } + } + + return nil +} + +func (self *ConnClient) Write(c ChunkStream) error { + if c.TypeID == av.TAG_SCRIPTDATAAMF0 || + c.TypeID == av.TAG_SCRIPTDATAAMF3 { + var err error + if c.Data, err = amf.MetaDataReform(c.Data, amf.ADD); err != nil { + return err + } + c.Length = uint32(len(c.Data)) + } + return self.conn.Write(&c) +} + +func (self *ConnClient) Read(c *ChunkStream) (err error) { + return self.conn.Read(c) +} + +func (self *ConnClient) GetInfo() (app string, name string, url string) { + app = self.app + name = self.title + url = self.url + return +} + +func (self *ConnClient) Close(err error) { + self.conn.Close() +} diff --git a/protocol/rtmp/core/conn_server.go b/protocol/rtmp/core/conn_server.go new file mode 100755 index 0000000..40a953c --- /dev/null +++ b/protocol/rtmp/core/conn_server.go @@ -0,0 +1,353 @@ +package 
core + +import ( + "bytes" + "errors" + "io" + + "github.com/gwuhaolin/livego/protocol/amf" + "github.com/gwuhaolin/livego/av" + "log" +) + +var ( + publishLive = "live" + publishRecord = "record" + publishAppend = "append" +) + +var ( + ErrReq = errors.New("req error") +) + +var ( + cmdConnect = "connect" + cmdFcpublish = "FCPublish" + cmdReleaseStream = "releaseStream" + cmdCreateStream = "createStream" + cmdPublish = "publish" + cmdFCUnpublish = "FCUnpublish" + cmdDeleteStream = "deleteStream" + cmdPlay = "play" +) + +type ConnectInfo struct { + App string `amf:"app" json:"app"` + Flashver string `amf:"flashVer" json:"flashVer"` + SwfUrl string `amf:"swfUrl" json:"swfUrl"` + TcUrl string `amf:"tcUrl" json:"tcUrl"` + Fpad bool `amf:"fpad" json:"fpad"` + AudioCodecs int `amf:"audioCodecs" json:"audioCodecs"` + VideoCodecs int `amf:"videoCodecs" json:"videoCodecs"` + VideoFunction int `amf:"videoFunction" json:"videoFunction"` + PageUrl string `amf:"pageUrl" json:"pageUrl"` + ObjectEncoding int `amf:"objectEncoding" json:"objectEncoding"` +} + +type ConnectResp struct { + FMSVer string `amf:"fmsVer"` + Capabilities int `amf:"capabilities"` +} + +type ConnectEvent struct { + Level string `amf:"level"` + Code string `amf:"code"` + Description string `amf:"description"` + ObjectEncoding int `amf:"objectEncoding"` +} + +type PublishInfo struct { + Name string + Type string +} + +type ConnServer struct { + done bool + streamID int + isPublisher bool + conn *Conn + transactionID int + ConnInfo ConnectInfo + PublishInfo PublishInfo + decoder *amf.Decoder + encoder *amf.Encoder + bytesw *bytes.Buffer +} + +func NewConnServer(conn *Conn) *ConnServer { + return &ConnServer{ + conn: conn, + streamID: 1, + bytesw: bytes.NewBuffer(nil), + decoder: &amf.Decoder{}, + encoder: &amf.Encoder{}, + } +} + +func (self *ConnServer) writeMsg(csid, streamID uint32, args ...interface{}) error { + self.bytesw.Reset() + for _, v := range args { + if _, err := self.encoder.Encode(self.bytesw, v, amf.AMF0); err != nil { + return err + } + } + msg := self.bytesw.Bytes() + c := ChunkStream{ + Format: 0, + CSID: csid, + Timestamp: 0, + TypeID: 20, + StreamID: streamID, + Length: uint32(len(msg)), + Data: msg, + } + self.conn.Write(&c) + return self.conn.Flush() +} + +func (self *ConnServer) connect(vs []interface{}) error { + for _, v := range vs { + switch v.(type) { + case string: + case float64: + id := int(v.(float64)) + if id != 1 { + return ErrReq + } + self.transactionID = id + case amf.Object: + obimap := v.(amf.Object) + if app, ok := obimap["app"]; ok { + self.ConnInfo.App = app.(string) + } + if flashVer, ok := obimap["flashVer"]; ok { + self.ConnInfo.Flashver = flashVer.(string) + } + if tcurl, ok := obimap["tcUrl"]; ok { + self.ConnInfo.TcUrl = tcurl.(string) + } + if encoding, ok := obimap["objectEncoding"]; ok { + self.ConnInfo.ObjectEncoding = int(encoding.(float64)) + } + } + } + return nil +} + +func (self *ConnServer) releaseStream(vs []interface{}) error { + return nil +} + +func (self *ConnServer) fcPublish(vs []interface{}) error { + return nil +} + +func (self *ConnServer) connectResp(cur *ChunkStream) error { + c := self.conn.NewWindowAckSize(2500000) + self.conn.Write(&c) + c = self.conn.NewSetPeerBandwidth(2500000) + self.conn.Write(&c) + c = self.conn.NewSetChunkSize(uint32(1024)) + self.conn.Write(&c) + + resp := make(amf.Object) + resp["fmsVer"] = "FMS/3,0,1,123" + resp["capabilities"] = 31 + + event := make(amf.Object) + event["level"] = "status" + event["code"] = 
"NetConnection.Connect.Success" + event["description"] = "Connection succeeded." + event["objectEncoding"] = self.ConnInfo.ObjectEncoding + return self.writeMsg(cur.CSID, cur.StreamID, "_result", self.transactionID, resp, event) +} + +func (self *ConnServer) createStream(vs []interface{}) error { + for _, v := range vs { + switch v.(type) { + case string: + case float64: + self.transactionID = int(v.(float64)) + case amf.Object: + } + } + return nil +} + +func (self *ConnServer) createStreamResp(cur *ChunkStream) error { + return self.writeMsg(cur.CSID, cur.StreamID, "_result", self.transactionID, nil, self.streamID) +} + +func (self *ConnServer) publishOrPlay(vs []interface{}) error { + for k, v := range vs { + switch v.(type) { + case string: + if k == 2 { + self.PublishInfo.Name = v.(string) + } else if k == 3 { + self.PublishInfo.Type = v.(string) + } + case float64: + id := int(v.(float64)) + self.transactionID = id + case amf.Object: + } + } + + return nil +} + +func (self *ConnServer) publishResp(cur *ChunkStream) error { + event := make(amf.Object) + event["level"] = "status" + event["code"] = "NetStream.Publish.Start" + event["description"] = "Start publising." + return self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event) +} + +func (self *ConnServer) playResp(cur *ChunkStream) error { + self.conn.SetRecorded() + self.conn.SetBegin() + + event := make(amf.Object) + event["level"] = "status" + event["code"] = "NetStream.Play.Reset" + event["description"] = "Playing and resetting stream." + if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil { + return err + } + + event["level"] = "status" + event["code"] = "NetStream.Play.Start" + event["description"] = "Started playing stream." + if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil { + return err + } + + event["level"] = "status" + event["code"] = "NetStream.Data.Start" + event["description"] = "Started playing stream." + if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil { + return err + } + + event["level"] = "status" + event["code"] = "NetStream.Play.PublishNotify" + event["description"] = "Started playing notify." 
+ if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil { + return err + } + return self.conn.Flush() +} + +func (self *ConnServer) handleCmdMsg(c *ChunkStream) error { + amfType := amf.AMF0 + if c.TypeID == 17 { + c.Data = c.Data[1:] + } + r := bytes.NewReader(c.Data) + vs, err := self.decoder.DecodeBatch(r, amf.Version(amfType)) + if err != nil && err != io.EOF { + return err + } + // glog.Infof("rtmp req: %#v", vs) + switch vs[0].(type) { + case string: + switch vs[0].(string) { + case cmdConnect: + if err = self.connect(vs[1:]); err != nil { + return err + } + if err = self.connectResp(c); err != nil { + return err + } + case cmdCreateStream: + if err = self.createStream(vs[1:]); err != nil { + return err + } + if err = self.createStreamResp(c); err != nil { + return err + } + case cmdPublish: + if err = self.publishOrPlay(vs[1:]); err != nil { + return err + } + if err = self.publishResp(c); err != nil { + return err + } + self.done = true + self.isPublisher = true + log.Println("handle publish req done") + case cmdPlay: + if err = self.publishOrPlay(vs[1:]); err != nil { + return err + } + if err = self.playResp(c); err != nil { + return err + } + self.done = true + self.isPublisher = false + log.Println("handle play req done") + case cmdFcpublish: + self.fcPublish(vs) + case cmdReleaseStream: + self.releaseStream(vs) + case cmdFCUnpublish: + case cmdDeleteStream: + default: + log.Println("no support command=", vs[0].(string)) + } + } + + return nil +} + +func (self *ConnServer) ReadMsg() error { + var c ChunkStream + for { + if err := self.conn.Read(&c); err != nil { + return err + } + switch c.TypeID { + case 20, 17: + if err := self.handleCmdMsg(&c); err != nil { + return err + } + } + if self.done { + break + } + } + return nil +} + +func (self *ConnServer) IsPublisher() bool { + return self.isPublisher +} + +func (self *ConnServer) Write(c ChunkStream) error { + if c.TypeID == av.TAG_SCRIPTDATAAMF0 || + c.TypeID == av.TAG_SCRIPTDATAAMF3 { + var err error + if c.Data, err = amf.MetaDataReform(c.Data, amf.DEL); err != nil { + return err + } + c.Length = uint32(len(c.Data)) + } + return self.conn.Write(&c) +} + +func (self *ConnServer) Read(c *ChunkStream) (err error) { + return self.conn.Read(c) +} + +func (self *ConnServer) GetInfo() (app string, name string, url string) { + app = self.ConnInfo.App + name = self.PublishInfo.Name + url = self.ConnInfo.TcUrl + "/" + self.PublishInfo.Name + return +} + +func (self *ConnServer) Close(err error) { + self.conn.Close() +} diff --git a/protocol/rtmp/core/conn_test.go b/protocol/rtmp/core/conn_test.go new file mode 100755 index 0000000..df12260 --- /dev/null +++ b/protocol/rtmp/core/conn_test.go @@ -0,0 +1,251 @@ +package core + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/gwuhaolin/livego/utils/pool" +) + +func TestConnReadNormal(t *testing.T) { + at := assert.New(t) + data := []byte{ + 0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, + } + data1 := make([]byte, 128) + data2 := make([]byte, 51) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data2...) 
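+ // The 12-byte prefix is a format-0 chunk header: fmt=0/csid=6 (0x06), timestamp 0,
+ // message length 0x000133 = 307, type ID 9 (video) and little-endian stream ID 1.
+ // Each 0xc6 byte is a format-3 continuation header for csid 6, so the 307-byte
+ // payload arrives as two full 128-byte chunks plus a 51-byte remainder.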
+ conn := &Conn{ + pool: pool.NewPool(), + rw: NewReadWriter(bytes.NewBuffer(data), 1024), + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + chunks: make(map[uint32]ChunkStream), + } + var c ChunkStream + err := conn.Read(&c) + at.Equal(err, nil) + at.Equal(int(c.CSID), 6) + at.Equal(int(c.Length), 307) + at.Equal(int(c.TypeID), 9) +} + +//交叉读音视频数据 +func TestConnCrossReading(t *testing.T) { + at := assert.New(t) + data1 := make([]byte, 128) + data2 := make([]byte, 51) + + videoData := []byte{ + 0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, + } + audioData := []byte{ + 0x04, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x08, 0x01, 0x00, 0x00, 0x00, + } + //video 1 + videoData = append(videoData, data1...) + //video 2 + videoData = append(videoData, 0xc6) + videoData = append(videoData, data1...) + //audio 1 + videoData = append(videoData, audioData...) + videoData = append(videoData, data1...) + //audio 2 + videoData = append(videoData, 0xc4) + videoData = append(videoData, data1...) + //video 3 + videoData = append(videoData, 0xc6) + videoData = append(videoData, data2...) + //audio 3 + videoData = append(videoData, 0xc4) + videoData = append(videoData, data2...) + + conn := &Conn{ + pool: pool.NewPool(), + rw: NewReadWriter(bytes.NewBuffer(videoData), 1024), + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + chunks: make(map[uint32]ChunkStream), + } + var c ChunkStream + //video 1 + err := conn.Read(&c) + at.Equal(err, nil) + at.Equal(int(c.TypeID), 9) + at.Equal(len(c.Data), 307) + + //audio2 + err = conn.Read(&c) + at.Equal(err, nil) + at.Equal(int(c.TypeID), 8) + at.Equal(len(c.Data), 307) + + err = conn.Read(&c) + at.Equal(err, io.EOF) +} + +func TestSetChunksizeForWrite(t *testing.T) { + at := assert.New(t) + chunk := ChunkStream{ + Format: 0, + CSID: 2, + Timestamp: 0, + Length: 4, + StreamID: 1, + TypeID: idSetChunkSize, + Data: []byte{0x00, 0x00, 0x00, 0x96}, + } + buf := bytes.NewBuffer(nil) + rw := NewReadWriter(buf, 1024) + conn := &Conn{ + pool: pool.NewPool(), + rw: rw, + chunkSize: 128, + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + chunks: make(map[uint32]ChunkStream), + } + + audio := ChunkStream{ + Format: 0, + CSID: 4, + Timestamp: 40, + Length: 133, + StreamID: 1, + TypeID: 0x8, + } + audio.Data = make([]byte, 133) + audio.Data = audio.Data[:133] + audio.Data[0] = 0xff + audio.Data[128] = 0xff + err := conn.Write(&audio) + at.Equal(err, nil) + conn.Flush() + at.Equal(len(buf.Bytes()), 146) + + buf.Reset() + err = conn.Write(&chunk) + at.Equal(err, nil) + conn.Flush() + + buf.Reset() + err = conn.Write(&audio) + at.Equal(err, nil) + conn.Flush() + at.Equal(len(buf.Bytes()), 145) +} + +func TestSetChunksize(t *testing.T) { + at := assert.New(t) + data := []byte{ + 0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, + } + data1 := make([]byte, 128) + data2 := make([]byte, 51) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data2...) 
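+ // Same layout as above: a single 307-byte video message chunked at the default
+ // size of 128. The test later feeds a Set Chunk Size control message (csid 2,
+ // type ID 1, payload 0x00000096 = 150) and expects the following 307-byte message
+ // to be reassembled from 150-byte chunks (150 + 150 + 7).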
+ rw := NewReadWriter(bytes.NewBuffer(data), 1024) + conn := &Conn{ + pool: pool.NewPool(), + rw: rw, + chunkSize: 128, + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + chunks: make(map[uint32]ChunkStream), + } + + var c ChunkStream + err := conn.Read(&c) + at.Equal(err, nil) + at.Equal(int(c.TypeID), 9) + at.Equal(int(c.CSID), 6) + at.Equal(int(c.StreamID), 1) + at.Equal(len(c.Data), 307) + + //设置chunksize + chunkBuf := []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x96} + conn.rw = NewReadWriter(bytes.NewBuffer(chunkBuf), 1024) + err = conn.Read(&c) + at.Equal(err, nil) + + data = data[:12] + data[7] = 0x8 + data1 = make([]byte, 150) + data2 = make([]byte, 7) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data1...) + data = append(data, 0xc6) + data = append(data, data2...) + + conn.rw = NewReadWriter(bytes.NewBuffer(data), 1024) + err = conn.Read(&c) + at.Equal(err, nil) + at.Equal(len(c.Data), 307) + + err = conn.Read(&c) + at.Equal(err, io.EOF) +} + +func TestConnWrite(t *testing.T) { + at := assert.New(t) + wr := bytes.NewBuffer(nil) + readWriter := NewReadWriter(wr, 128) + conn := &Conn{ + pool: pool.NewPool(), + rw: readWriter, + chunkSize: 128, + remoteChunkSize: 128, + windowAckSize: 2500000, + remoteWindowAckSize: 2500000, + chunks: make(map[uint32]ChunkStream), + } + + c1 := ChunkStream{ + Length: 3, + TypeID: 8, + CSID: 3, + Timestamp: 40, + Data: []byte{0x01, 0x02, 0x03}, + } + err := conn.Write(&c1) + at.Equal(err, nil) + conn.Flush() + at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0x28, 0x0, 0x0, 0x3, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3}) + + //for type 1 + wr.Reset() + c1 = ChunkStream{ + Length: 4, + TypeID: 8, + CSID: 3, + Timestamp: 80, + Data: []byte{0x01, 0x02, 0x03, 0x4}, + } + err = conn.Write(&c1) + at.Equal(err, nil) + conn.Flush() + at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0x50, 0x0, 0x0, 0x4, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3, 0x4}) + + //for type 2 + wr.Reset() + c1.Timestamp = 160 + err = conn.Write(&c1) + at.Equal(err, nil) + conn.Flush() + at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0xa0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3, 0x4}) +} diff --git a/protocol/rtmp/core/handshake.go b/protocol/rtmp/core/handshake.go new file mode 100755 index 0000000..fa90b6b --- /dev/null +++ b/protocol/rtmp/core/handshake.go @@ -0,0 +1,207 @@ +package core + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "time" + + "github.com/gwuhaolin/livego/utils/pio" +) + +var ( + timeout = 5 * time.Second +) + +var ( + hsClientFullKey = []byte{ + 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', + 'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ', + '0', '0', '1', + 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, + 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, + 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, + } + hsServerFullKey = []byte{ + 'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ', + 'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ', + 'S', 'e', 'r', 'v', 'e', 'r', ' ', + '0', '0', '1', + 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1, + 0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, + 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE, + } + hsClientPartialKey = hsClientFullKey[:30] + hsServerPartialKey = hsServerFullKey[:36] +) + +func 
hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) { + h := hmac.New(sha256.New, key) + if gap <= 0 { + h.Write(src) + } else { + h.Write(src[:gap]) + h.Write(src[gap+32:]) + } + return h.Sum(nil) +} + +func hsCalcDigestPos(p []byte, base int) (pos int) { + for i := 0; i < 4; i++ { + pos += int(p[base+i]) + } + pos = (pos % 728) + base + 4 + return +} + +func hsFindDigest(p []byte, key []byte, base int) int { + gap := hsCalcDigestPos(p, base) + digest := hsMakeDigest(key, p, gap) + if bytes.Compare(p[gap:gap+32], digest) != 0 { + return -1 + } + return gap +} + +func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) { + var pos int + if pos = hsFindDigest(p, peerkey, 772); pos == -1 { + if pos = hsFindDigest(p, peerkey, 8); pos == -1 { + return + } + } + ok = true + digest = hsMakeDigest(key, p[pos:pos+32], -1) + return +} + +func hsCreate01(p []byte, time uint32, ver uint32, key []byte) { + p[0] = 3 + p1 := p[1:] + rand.Read(p1[8:]) + pio.PutU32BE(p1[0:4], time) + pio.PutU32BE(p1[4:8], ver) + gap := hsCalcDigestPos(p1, 8) + digest := hsMakeDigest(key, p1, gap) + copy(p1[gap:], digest) +} + +func hsCreate2(p []byte, key []byte) { + rand.Read(p) + gap := len(p) - 32 + digest := hsMakeDigest(key, p, gap) + copy(p[gap:], digest) +} + +func (self *Conn) HandshakeClient() (err error) { + var random [(1 + 1536*2) * 2]byte + + C0C1C2 := random[:1536*2+1] + C0 := C0C1C2[:1] + C0C1 := C0C1C2[:1536+1] + C2 := C0C1C2[1536+1:] + + S0S1S2 := random[1536*2+1:] + + C0[0] = 3 + // > C0C1 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = self.rw.Write(C0C1); err != nil { + return + } + self.Conn.SetDeadline(time.Now().Add(timeout)) + if err = self.rw.Flush(); err != nil { + return + } + + // < S0S1S2 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = io.ReadFull(self.rw, S0S1S2); err != nil { + return + } + + S1 := S0S1S2[1: 1536+1] + if ver := pio.U32BE(S1[4:8]); ver != 0 { + C2 = S1 + } else { + C2 = S1 + } + + // > C2 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = self.rw.Write(C2); err != nil { + return + } + self.Conn.SetDeadline(time.Time{}) + return +} + +func (self *Conn) HandshakeServer() (err error) { + var random [(1 + 1536*2) * 2]byte + + C0C1C2 := random[:1536*2+1] + C0 := C0C1C2[:1] + C1 := C0C1C2[1: 1536+1] + C0C1 := C0C1C2[:1536+1] + C2 := C0C1C2[1536+1:] + + S0S1S2 := random[1536*2+1:] + S0 := S0S1S2[:1] + S1 := S0S1S2[1: 1536+1] + S0S1 := S0S1S2[:1536+1] + S2 := S0S1S2[1536+1:] + + // < C0C1 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = io.ReadFull(self.rw, C0C1); err != nil { + return + } + self.Conn.SetDeadline(time.Now().Add(timeout)) + if C0[0] != 3 { + err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0]) + return + } + + S0[0] = 3 + + clitime := pio.U32BE(C1[0:4]) + srvtime := clitime + srvver := uint32(0x0d0e0a0d) + cliver := pio.U32BE(C1[4:8]) + + if cliver != 0 { + var ok bool + var digest []byte + if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok { + err = fmt.Errorf("rtmp: handshake server: C1 invalid") + return + } + hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey) + hsCreate2(S2, digest) + } else { + copy(S1, C2) + copy(S2, C1) + } + + // > S0S1S2 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = self.rw.Write(S0S1S2); err != nil { + return + } + self.Conn.SetDeadline(time.Now().Add(timeout)) + if err = self.rw.Flush(); err != nil { + return + } + + // < C2 + self.Conn.SetDeadline(time.Now().Add(timeout)) + if _, err = io.ReadFull(self.rw, 
C2); err != nil { + return + } + self.Conn.SetDeadline(time.Time{}) + return +} diff --git a/protocol/rtmp/core/read_writer.go b/protocol/rtmp/core/read_writer.go new file mode 100755 index 0000000..669ddf1 --- /dev/null +++ b/protocol/rtmp/core/read_writer.go @@ -0,0 +1,114 @@ +package core + +import ( + "bufio" + "io" +) + +type ReadWriter struct { + *bufio.ReadWriter + readError error + writeError error +} + +func NewReadWriter(rw io.ReadWriter, bufSize int) *ReadWriter { + return &ReadWriter{ + ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(rw, bufSize), bufio.NewWriterSize(rw, bufSize)), + } +} + +func (rw *ReadWriter) Read(p []byte) (int, error) { + if rw.readError != nil { + return 0, rw.readError + } + n, err := io.ReadAtLeast(rw.ReadWriter, p, len(p)) + rw.readError = err + return n, err +} + +func (rw *ReadWriter) ReadError() error { + return rw.readError +} + +func (rw *ReadWriter) ReadUintBE(n int) (uint32, error) { + if rw.readError != nil { + return 0, rw.readError + } + ret := uint32(0) + for i := 0; i < n; i++ { + b, err := rw.ReadByte() + if err != nil { + rw.readError = err + return 0, err + } + ret = ret<<8 + uint32(b) + } + return ret, nil +} + +func (rw *ReadWriter) ReadUintLE(n int) (uint32, error) { + if rw.readError != nil { + return 0, rw.readError + } + ret := uint32(0) + for i := 0; i < n; i++ { + b, err := rw.ReadByte() + if err != nil { + rw.readError = err + return 0, err + } + ret += uint32(b) << uint32(i*8) + } + return ret, nil +} + +func (rw *ReadWriter) Flush() error { + if rw.writeError != nil { + return rw.writeError + } + + if rw.ReadWriter.Writer.Buffered() == 0 { + return nil + } + return rw.ReadWriter.Flush() +} + +func (rw *ReadWriter) Write(p []byte) (int, error) { + if rw.writeError != nil { + return 0, rw.writeError + } + return rw.ReadWriter.Write(p) +} + +func (rw *ReadWriter) WriteError() error { + return rw.writeError +} + +func (rw *ReadWriter) WriteUintBE(v uint32, n int) error { + if rw.writeError != nil { + return rw.writeError + } + for i := 0; i < n; i++ { + b := byte(v>>uint32((n-i-1)<<3)) & 0xff + if err := rw.WriteByte(b); err != nil { + rw.writeError = err + return err + } + } + return nil +} + +func (rw *ReadWriter) WriteUintLE(v uint32, n int) error { + if rw.writeError != nil { + return rw.writeError + } + for i := 0; i < n; i++ { + b := byte(v) & 0xff + if err := rw.WriteByte(b); err != nil { + rw.writeError = err + return err + } + v = v >> 8 + } + return nil +} diff --git a/protocol/rtmp/core/read_writer_test.go b/protocol/rtmp/core/read_writer_test.go new file mode 100755 index 0000000..97555ce --- /dev/null +++ b/protocol/rtmp/core/read_writer_test.go @@ -0,0 +1,136 @@ +package core + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReader(t *testing.T) { + at := assert.New(t) + buf := bytes.NewBufferString("abc") + r := NewReadWriter(buf, 1024) + b := make([]byte, 3) + n, err := r.Read(b) + at.Equal(err, nil) + at.Equal(r.ReadError(), nil) + at.Equal(n, 3) + n, err = r.Read(b) + at.Equal(err, io.EOF) + at.Equal(r.ReadError(), io.EOF) + buf.WriteString("123") + n, err = r.Read(b) + at.Equal(err, io.EOF) + at.Equal(r.ReadError(), io.EOF) + at.Equal(n, 0) +} + +func TestReaderUintBE(t *testing.T) { + at := assert.New(t) + type Test struct { + i int + value uint32 + bytes []byte + } + tests := []Test{ + {1, 0x01, []byte{0x01}}, + {2, 0x0102, []byte{0x01, 0x02}}, + {3, 0x010203, []byte{0x01, 0x02, 0x03}}, + {4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}}, + } + for _, test := 
range tests { + buf := bytes.NewBuffer(test.bytes) + r := NewReadWriter(buf, 1024) + n, err := r.ReadUintBE(test.i) + at.Equal(err, nil, "test %d", test.i) + at.Equal(n, test.value, "test %d", test.i) + } +} + +func TestReaderUintLE(t *testing.T) { + at := assert.New(t) + type Test struct { + i int + value uint32 + bytes []byte + } + tests := []Test{ + {1, 0x01, []byte{0x01}}, + {2, 0x0102, []byte{0x02, 0x01}}, + {3, 0x010203, []byte{0x03, 0x02, 0x01}}, + {4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}}, + } + for _, test := range tests { + buf := bytes.NewBuffer(test.bytes) + r := NewReadWriter(buf, 1024) + n, err := r.ReadUintLE(test.i) + at.Equal(err, nil, "test %d", test.i) + at.Equal(n, test.value, "test %d", test.i) + } +} + +func TestWriter(t *testing.T) { + at := assert.New(t) + buf := bytes.NewBuffer(nil) + w := NewReadWriter(buf, 1024) + b := []byte{1, 2, 3} + n, err := w.Write(b) + at.Equal(err, nil) + at.Equal(w.WriteError(), nil) + at.Equal(n, 3) + w.writeError = io.EOF + n, err = w.Write(b) + at.Equal(err, io.EOF) + at.Equal(w.WriteError(), io.EOF) + at.Equal(n, 0) +} + +func TestWriteUintBE(t *testing.T) { + at := assert.New(t) + type Test struct { + i int + value uint32 + bytes []byte + } + tests := []Test{ + {1, 0x01, []byte{0x01}}, + {2, 0x0102, []byte{0x01, 0x02}}, + {3, 0x010203, []byte{0x01, 0x02, 0x03}}, + {4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}}, + } + for _, test := range tests { + buf := bytes.NewBuffer(nil) + r := NewReadWriter(buf, 1024) + err := r.WriteUintBE(test.value, test.i) + at.Equal(err, nil, "test %d", test.i) + err = r.Flush() + at.Equal(err, nil, "test %d", test.i) + at.Equal(buf.Bytes(), test.bytes, "test %d", test.i) + } +} + +func TestWriteUintLE(t *testing.T) { + at := assert.New(t) + type Test struct { + i int + value uint32 + bytes []byte + } + tests := []Test{ + {1, 0x01, []byte{0x01}}, + {2, 0x0102, []byte{0x02, 0x01}}, + {3, 0x010203, []byte{0x03, 0x02, 0x01}}, + {4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}}, + } + for _, test := range tests { + buf := bytes.NewBuffer(nil) + r := NewReadWriter(buf, 1024) + err := r.WriteUintLE(test.value, test.i) + at.Equal(err, nil, "test %d", test.i) + err = r.Flush() + at.Equal(err, nil, "test %d", test.i) + at.Equal(buf.Bytes(), test.bytes, "test %d", test.i) + } +} diff --git a/protocol/rtmp/rtmp.go b/protocol/rtmp/rtmp.go new file mode 100755 index 0000000..7cdd677 --- /dev/null +++ b/protocol/rtmp/rtmp.go @@ -0,0 +1,341 @@ +package rtmp + +import ( + "net" + "time" + + "net/url" + + "strings" + + "errors" + + "flag" + + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/utils/uid" + "github.com/gwuhaolin/livego/container/flv" + "github.com/golang/glog" + "github.com/gwuhaolin/livego/protocol/rtmp/core" + "log" +) + +const ( + maxQueueNum = 1024 +) + +var ( + readTimeout = flag.Int("readTimeout", 10, "read time out") + writeTimeout = flag.Int("writeTimeout", 10, "write time out") +) + +type Client struct { + handler av.Handler + getter av.GetWriter +} + +func NewRtmpClient(h av.Handler, getter av.GetWriter) *Client { + return &Client{ + handler: h, + getter: getter, + } +} + +func (self *Client) Dial(url string, method string) error { + connClient := core.NewConnClient() + if err := connClient.Start(url, method); err != nil { + return err + } + if method == av.PUBLISH { + writer := NewVirWriter(connClient) + self.handler.HandleWriter(writer) + } else if method == av.PLAY { + reader := NewVirReader(connClient) + self.handler.HandleReader(reader) + if self.getter != nil { + writer := 
self.getter.GetWriter(reader.Info()) + self.handler.HandleWriter(writer) + } + } + return nil +} + +func (self *Client) GetHandle() av.Handler { + return self.handler +} + +type Server struct { + handler av.Handler + getter av.GetWriter +} + +func NewRtmpServer(h av.Handler, getter av.GetWriter) *Server { + return &Server{ + handler: h, + getter: getter, + } +} + +func (self *Server) Serve(listener net.Listener) (err error) { + defer func() { + if r := recover(); r != nil { + log.Println("rtmp serve panic: ", r) + } + }() + + for { + var netconn net.Conn + netconn, err = listener.Accept() + if err != nil { + return + } + conn := core.NewConn(netconn, 4*1024) + log.Println("new client, connect remote:", conn.RemoteAddr().String(), + "local:", conn.LocalAddr().String()) + go self.handleConn(conn) + } +} + +func (self *Server) handleConn(conn *core.Conn) error { + if err := conn.HandshakeServer(); err != nil { + conn.Close() + log.Println("handleConn HandshakeServer err:", err) + return err + } + connServer := core.NewConnServer(conn) + + if err := connServer.ReadMsg(); err != nil { + conn.Close() + log.Println("handleConn read msg err:", err) + return err + } + if connServer.IsPublisher() { + reader := NewVirReader(connServer) + self.handler.HandleReader(reader) + glog.Infof("new publisher: %+v", reader.Info()) + + if self.getter != nil { + writer := self.getter.GetWriter(reader.Info()) + self.handler.HandleWriter(writer) + } + } else { + writer := NewVirWriter(connServer) + glog.Infof("new player: %+v", writer.Info()) + self.handler.HandleWriter(writer) + } + + return nil +} + +type GetInFo interface { + GetInfo() (string, string, string) +} + +type StreamReadWriteCloser interface { + GetInFo + Close(error) + Write(core.ChunkStream) error + Read(c *core.ChunkStream) error +} + +type VirWriter struct { + Uid string + closed bool + av.RWBaser + conn StreamReadWriteCloser + packetQueue chan av.Packet +} + +func NewVirWriter(conn StreamReadWriteCloser) *VirWriter { + ret := &VirWriter{ + Uid: uid.NEWID(), + conn: conn, + RWBaser: av.NewRWBaser(time.Second * time.Duration(*writeTimeout)), + packetQueue: make(chan av.Packet, maxQueueNum), + } + go ret.Check() + go func() { + err := ret.SendPacket() + if err != nil { + log.Println(err) + } + }() + return ret +} + +func (self *VirWriter) Check() { + var c core.ChunkStream + for { + if err := self.conn.Read(&c); err != nil { + self.Close(err) + return + } + } +} + +func (self *VirWriter) DropPacket(pktQue chan av.Packet, info av.Info) { + glog.Errorf("[%v] packet queue max!!!", info) + for i := 0; i < maxQueueNum-84; i++ { + tmpPkt, ok := <-pktQue + // try to don't drop audio + if ok && tmpPkt.IsAudio { + if len(pktQue) > maxQueueNum-2 { + log.Println("drop audio pkt") + <-pktQue + } else { + pktQue <- tmpPkt + } + + } + + if ok && tmpPkt.IsVideo { + videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader) + // dont't drop sps config and dont't drop key frame + if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) { + pktQue <- tmpPkt + } + if len(pktQue) > maxQueueNum-10 { + log.Println("drop video pkt") + <-pktQue + } + } + + } + log.Println("packet queue len: ", len(pktQue)) +} + +// +func (self *VirWriter) Write(p av.Packet) error { + if !self.closed { + if len(self.packetQueue) >= maxQueueNum-24 { + self.DropPacket(self.packetQueue, self.Info()) + } else { + self.packetQueue <- p + } + return nil + } else { + return errors.New("closed") + } +} + +func (self *VirWriter) SendPacket() error { + var cs core.ChunkStream + for { + p, ok := 
<-self.packetQueue + if ok { + cs.Data = p.Data + cs.Length = uint32(len(p.Data)) + cs.StreamID = 1 + cs.Timestamp = p.TimeStamp + cs.Timestamp += self.BaseTimeStamp() + + if p.IsVideo { + cs.TypeID = av.TAG_VIDEO + } else { + if p.IsMetadata { + cs.TypeID = av.TAG_SCRIPTDATAAMF0 + } else { + cs.TypeID = av.TAG_AUDIO + } + } + + self.SetPreTime() + self.RecTimeStamp(cs.Timestamp, cs.TypeID) + err := self.conn.Write(cs) + if err != nil { + self.closed = true + return err + } + } else { + return errors.New("closed") + } + + } + return nil +} + +func (self *VirWriter) Info() (ret av.Info) { + ret.UID = self.Uid + _, _, URL := self.conn.GetInfo() + ret.URL = URL + _url, err := url.Parse(URL) + if err != nil { + log.Println(err) + } + ret.Key = strings.TrimLeft(_url.Path, "/") + ret.Inter = true + return +} + +func (self *VirWriter) Close(err error) { + log.Println("player ", self.Info(), "closed: "+err.Error()) + if !self.closed { + close(self.packetQueue) + } + self.closed = true + self.conn.Close(err) +} + +type VirReader struct { + Uid string + av.RWBaser + demuxer *flv.Demuxer + conn StreamReadWriteCloser +} + +func NewVirReader(conn StreamReadWriteCloser) *VirReader { + return &VirReader{ + Uid: uid.NEWID(), + conn: conn, + RWBaser: av.NewRWBaser(time.Second * time.Duration(*writeTimeout)), + demuxer: flv.NewDemuxer(), + } +} + +func (self *VirReader) Read(p *av.Packet) (err error) { + defer func() { + if r := recover(); r != nil { + log.Println("rtmp read packet panic: ", r) + } + }() + + self.SetPreTime() + var cs core.ChunkStream + for { + err = self.conn.Read(&cs) + if err != nil { + return err + } + if cs.TypeID == av.TAG_AUDIO || + cs.TypeID == av.TAG_VIDEO || + cs.TypeID == av.TAG_SCRIPTDATAAMF0 || + cs.TypeID == av.TAG_SCRIPTDATAAMF3 { + break + } + } + + p.IsAudio = cs.TypeID == av.TAG_AUDIO + p.IsVideo = cs.TypeID == av.TAG_VIDEO + p.IsMetadata = (cs.TypeID == av.TAG_SCRIPTDATAAMF0 || cs.TypeID == av.TAG_SCRIPTDATAAMF3) + p.Data = cs.Data + p.TimeStamp = cs.Timestamp + self.demuxer.DemuxH(p) + return err +} + +func (self *VirReader) Info() (ret av.Info) { + ret.UID = self.Uid + _, _, URL := self.conn.GetInfo() + ret.URL = URL + _url, err := url.Parse(URL) + if err != nil { + log.Println(err) + } + ret.Key = strings.TrimLeft(_url.Path, "/") + return +} + +func (self *VirReader) Close(err error) { + log.Println("publisher ", self.Info(), "closed: "+err.Error()) + self.conn.Close(err) +} diff --git a/protocol/rtmp/stream.go b/protocol/rtmp/stream.go new file mode 100755 index 0000000..dce71dd --- /dev/null +++ b/protocol/rtmp/stream.go @@ -0,0 +1,228 @@ +package rtmp + +import ( + "errors" + "time" + + "github.com/gwuhaolin/livego/utils/cmap" + "github.com/golang/glog" + "github.com/gwuhaolin/livego/av" + "github.com/gwuhaolin/livego/protocol/rtmp/cache" +) + +var ( + EmptyID = "" +) + +type RtmpStream struct { + streams cmap.ConcurrentMap +} + +func NewRtmpStream() *RtmpStream { + ret := &RtmpStream{ + streams: cmap.New(), + } + go ret.CheckAlive() + return ret +} + +func (rs *RtmpStream) HandleReader(r av.ReadCloser) { + info := r.Info() + var s *Stream + ok := rs.streams.Has(info.Key) + if !ok { + s = NewStream() + rs.streams.Set(info.Key, s) + } else { + s.TransStop() + id := s.ID() + if id != EmptyID && id != info.UID { + ns := NewStream() + s.Copy(ns) + s = ns + rs.streams.Set(info.Key, ns) + } + } + s.AddReader(r) +} + +func (rs *RtmpStream) HandleWriter(w av.WriteCloser) { + info := w.Info() + var s *Stream + ok := rs.streams.Has(info.Key) + if !ok { + s = NewStream() + 
rs.streams.Set(info.Key, s) + } else { + item, ok := rs.streams.Get(info.Key) + if ok { + s = item.(*Stream) + s.AddWriter(w) + } + } + +} + +func (rs *RtmpStream) GetStreams() cmap.ConcurrentMap { + return rs.streams +} + +func (rs *RtmpStream) CheckAlive() { + for { + <-time.After(5 * time.Second) + for item := range rs.streams.IterBuffered() { + v := item.Val.(*Stream) + if v.CheckAlive() == 0 { + rs.streams.Remove(item.Key) + } + } + } +} + +type Stream struct { + isStart bool + cache *cache.Cache + r av.ReadCloser + ws cmap.ConcurrentMap +} + +type PackWriterCloser struct { + init bool + w av.WriteCloser +} + +func (p *PackWriterCloser) GetWriter() av.WriteCloser { + return p.w +} + +func NewStream() *Stream { + return &Stream{ + cache: cache.NewCache(), + ws: cmap.New(), + } +} + +func (s *Stream) ID() string { + if s.r != nil { + return s.r.Info().UID + } + return EmptyID +} + +func (s *Stream) GetReader() av.ReadCloser { + return s.r +} + +func (s *Stream) GetWs() cmap.ConcurrentMap { + return s.ws +} + +func (s *Stream) Copy(dst *Stream) { + for item := range s.ws.IterBuffered() { + v := item.Val.(*PackWriterCloser) + s.ws.Remove(item.Key) + v.w.CalcBaseTimestamp() + dst.AddWriter(v.w) + } +} + +func (s *Stream) AddReader(r av.ReadCloser) { + s.r = r + go s.TransStart() +} + +func (s *Stream) AddWriter(w av.WriteCloser) { + info := w.Info() + pw := &PackWriterCloser{w: w} + s.ws.Set(info.UID, pw) +} + +func (s *Stream) TransStart() { + // debug mode don't use it + // defer func() { + // if r := recover(); r != nil { + // log.Println("rtmp TransStart panic: ", r) + // } + // }() + + s.isStart = true + var p av.Packet + for { + if !s.isStart { + s.closeInter() + return + } + err := s.r.Read(&p) + if err != nil { + s.closeInter() + s.isStart = false + return + } + s.cache.Write(p) + + for item := range s.ws.IterBuffered() { + v := item.Val.(*PackWriterCloser) + if !v.init { + if err = s.cache.Send(v.w); err != nil { + glog.Errorf("[%s] send cache packet error: %v, remove", v.w.Info(), err) + s.ws.Remove(item.Key) + continue + } + v.init = true + } else { + if err = v.w.Write(p); err != nil { + glog.Errorf("[%s] write packet error: %v, remove", v.w.Info(), err) + s.ws.Remove(item.Key) + } + } + } + } +} + +func (s *Stream) TransStop() { + if s.isStart && s.r != nil { + s.r.Close(errors.New("stop old")) + } + s.isStart = false +} + +func (s *Stream) CheckAlive() (n int) { + if s.r != nil && s.isStart { + if s.r.Alive() { + n++ + } else { + s.r.Close(errors.New("read timeout")) + } + } + for item := range s.ws.IterBuffered() { + v := item.Val.(*PackWriterCloser) + if v.w != nil { + if !v.w.Alive() && s.isStart { + s.ws.Remove(item.Key) + v.w.Close(errors.New("write timeout")) + continue + } + n++ + } + + } + return +} + +func (s *Stream) closeInter() { + if s.r != nil { + glog.Infof("[%v] publisher closed", s.r.Info()) + } + + for item := range s.ws.IterBuffered() { + v := item.Val.(*PackWriterCloser) + if v.w != nil { + if v.w.Info().IsInterval() { + v.w.Close(errors.New("closed")) + s.ws.Remove(item.Key) + glog.Infof("[%v] player closed and remove\n", v.w.Info()) + } + } + + } +} diff --git a/protocol/rtp/rtp.go b/protocol/rtp/rtp.go new file mode 100755 index 0000000..4b24952 --- /dev/null +++ b/protocol/rtp/rtp.go @@ -0,0 +1 @@ +package rtp diff --git a/protocol/rtsp/protocol.go b/protocol/rtsp/protocol.go new file mode 100755 index 0000000..185db25 --- /dev/null +++ b/protocol/rtsp/protocol.go @@ -0,0 +1 @@ +package rtsp diff --git a/protocol/rtsp/rtsp.go b/protocol/rtsp/rtsp.go 
new file mode 100755 index 0000000..185db25 --- /dev/null +++ b/protocol/rtsp/rtsp.go @@ -0,0 +1 @@ +package rtsp diff --git a/protocol/webrtc/webrtc.go b/protocol/webrtc/webrtc.go new file mode 100755 index 0000000..efb4a82 --- /dev/null +++ b/protocol/webrtc/webrtc.go @@ -0,0 +1 @@ +package webrtc diff --git a/utils/cmap/cmap.go b/utils/cmap/cmap.go new file mode 100755 index 0000000..764a1df --- /dev/null +++ b/utils/cmap/cmap.go @@ -0,0 +1,301 @@ +package cmap + +import ( + "encoding/json" + "sync" +) + +var SHARD_COUNT = 32 + +// A "thread" safe map of type string:Anything. +// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. +type ConcurrentMap []*ConcurrentMapShared + +// A "thread" safe string to anything map. +type ConcurrentMapShared struct { + items map[string]interface{} + sync.RWMutex // Read Write mutex, guards access to internal map. +} + +// Creates a new concurrent map. +func New() ConcurrentMap { + m := make(ConcurrentMap, SHARD_COUNT) + for i := 0; i < SHARD_COUNT; i++ { + m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} + } + return m +} + +// Returns shard under given key +func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { + return m[uint(fnv32(key))%uint(SHARD_COUNT)] +} + +func (m ConcurrentMap) MSet(data map[string]interface{}) { + for key, value := range data { + shard := m.GetShard(key) + shard.Lock() + shard.items[key] = value + shard.Unlock() + } +} + +// Sets the given value under the specified key. +func (m *ConcurrentMap) Set(key string, value interface{}) { + // Get map shard. + shard := m.GetShard(key) + shard.Lock() + shard.items[key] = value + shard.Unlock() +} + +// Callback to return new element to be inserted into the map +// It is called while lock is held, therefore it MUST NOT +// try to access other keys in same map, as it can lead to deadlock since +// Go sync.RWLock is not reentrant +type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{} + +// Insert or Update - updates existing element or inserts a new one using UpsertCb +func (m *ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) { + shard := m.GetShard(key) + shard.Lock() + v, ok := shard.items[key] + res = cb(ok, v, value) + shard.items[key] = res + shard.Unlock() + return res +} + +// Sets the given value under the specified key if no value was associated with it. +func (m *ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { + // Get map shard. + shard := m.GetShard(key) + shard.Lock() + _, ok := shard.items[key] + if !ok { + shard.items[key] = value + } + shard.Unlock() + return !ok +} + +// Retrieves an element from map under given key. +func (m ConcurrentMap) Get(key string) (interface{}, bool) { + // Get shard + shard := m.GetShard(key) + shard.RLock() + // Get item from shard. + val, ok := shard.items[key] + shard.RUnlock() + return val, ok +} + +// Returns the number of elements within the map. +func (m ConcurrentMap) Count() int { + count := 0 + for i := 0; i < SHARD_COUNT; i++ { + shard := m[i] + shard.RLock() + count += len(shard.items) + shard.RUnlock() + } + return count +} + +// Looks up an item under specified key +func (m *ConcurrentMap) Has(key string) bool { + // Get shard + shard := m.GetShard(key) + shard.RLock() + // See if element is within shard. + _, ok := shard.items[key] + shard.RUnlock() + return ok +} + +// Removes an element from the map. +func (m *ConcurrentMap) Remove(key string) { + // Try to get shard. 
+ shard := m.GetShard(key) + shard.Lock() + delete(shard.items, key) + shard.Unlock() +} + +// Removes an element from the map and returns it +func (m *ConcurrentMap) Pop(key string) (v interface{}, exists bool) { + // Try to get shard. + shard := m.GetShard(key) + shard.Lock() + v, exists = shard.items[key] + delete(shard.items, key) + shard.Unlock() + return v, exists +} + +// Checks if map is empty. +func (m *ConcurrentMap) IsEmpty() bool { + return m.Count() == 0 +} + +// Used by the Iter & IterBuffered functions to wrap two variables together over a channel, +type Tuple struct { + Key string + Val interface{} +} + +// Returns an iterator which could be used in a for range loop. +// +// Deprecated: using IterBuffered() will get a better performence +func (m ConcurrentMap) Iter() <-chan Tuple { + ch := make(chan Tuple) + go func() { + wg := sync.WaitGroup{} + wg.Add(SHARD_COUNT) + // Foreach shard. + for _, shard := range m { + go func(shard *ConcurrentMapShared) { + // Foreach key, value pair. + shard.RLock() + for key, val := range shard.items { + ch <- Tuple{key, val} + } + shard.RUnlock() + wg.Done() + }(shard) + } + wg.Wait() + close(ch) + }() + return ch +} + +// Returns a buffered iterator which could be used in a for range loop. +func (m ConcurrentMap) IterBuffered() <-chan Tuple { + ch := make(chan Tuple, m.Count()) + go func() { + wg := sync.WaitGroup{} + wg.Add(SHARD_COUNT) + // Foreach shard. + for _, shard := range m { + go func(shard *ConcurrentMapShared) { + // Foreach key, value pair. + shard.RLock() + for key, val := range shard.items { + ch <- Tuple{key, val} + } + shard.RUnlock() + wg.Done() + }(shard) + } + wg.Wait() + close(ch) + }() + return ch +} + +// Returns all items as map[string]interface{} +func (m ConcurrentMap) Items() map[string]interface{} { + tmp := make(map[string]interface{}) + + // Insert items to temporary map. + for item := range m.IterBuffered() { + tmp[item.Key] = item.Val + } + + return tmp +} + +// Iterator callback,called for every key,value found in +// maps. RLock is held for all calls for a given shard +// therefore callback sess consistent view of a shard, +// but not across the shards +type IterCb func(key string, v interface{}) + +// Callback based iterator, cheapest way to read +// all elements in a map. +func (m *ConcurrentMap) IterCb(fn IterCb) { + for idx := range *m { + shard := (*m)[idx] + shard.RLock() + for key, value := range shard.items { + fn(key, value) + } + shard.RUnlock() + } +} + +// Return all keys as []string +func (m ConcurrentMap) Keys() []string { + count := m.Count() + ch := make(chan string, count) + go func() { + // Foreach shard. + wg := sync.WaitGroup{} + wg.Add(SHARD_COUNT) + for _, shard := range m { + go func(shard *ConcurrentMapShared) { + // Foreach key, value pair. + shard.RLock() + for key := range shard.items { + ch <- key + } + shard.RUnlock() + wg.Done() + }(shard) + } + wg.Wait() + close(ch) + }() + + // Generate keys + keys := make([]string, 0, count) + for k := range ch { + keys = append(keys, k) + } + return keys +} + +//Reviles ConcurrentMap "private" variables to json marshal. +func (m ConcurrentMap) MarshalJSON() ([]byte, error) { + // Create a temporary map, which will hold all item spread across shards. + tmp := make(map[string]interface{}) + + // Insert items to temporary map. 
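+ // IterBuffered drains every shard through a channel buffered to the current Count,
+ // so each shard's read lock is held only while that shard is being copied out.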
+ for item := range m.IterBuffered() { + tmp[item.Key] = item.Val + } + return json.Marshal(tmp) +} + +func fnv32(key string) uint32 { + hash := uint32(2166136261) + const prime32 = uint32(16777619) + for i := 0; i < len(key); i++ { + hash *= prime32 + hash ^= uint32(key[i]) + } + return hash +} + +// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal +// will probably won't know which to type to unmarshal into, in such case +// we'll end up with a value of type map[string]interface{}, In most cases this isn't +// out value type, this is why we've decided to remove this functionality. + +// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) { +// // Reverse process of Marshal. + +// tmp := make(map[string]interface{}) + +// // Unmarshal into a single map. +// if err := json.Unmarshal(b, &tmp); err != nil { +// return nil +// } + +// // foreach key,value pair in temporary map insert into our concurrent map. +// for key, val := range tmp { +// m.Set(key, val) +// } +// return nil +// } diff --git a/utils/pio/pio.go b/utils/pio/pio.go new file mode 100755 index 0000000..4a73aa9 --- /dev/null +++ b/utils/pio/pio.go @@ -0,0 +1,3 @@ +package pio + +var RecommendBufioSize = 1024 * 64 diff --git a/utils/pio/reader.go b/utils/pio/reader.go new file mode 100755 index 0000000..c28a119 --- /dev/null +++ b/utils/pio/reader.go @@ -0,0 +1,121 @@ +package pio + +func U8(b []byte) (i uint8) { + return b[0] +} + +func U16BE(b []byte) (i uint16) { + i = uint16(b[0]) + i <<= 8 + i |= uint16(b[1]) + return +} + +func I16BE(b []byte) (i int16) { + i = int16(b[0]) + i <<= 8 + i |= int16(b[1]) + return +} + +func I24BE(b []byte) (i int32) { + i = int32(int8(b[0])) + i <<= 8 + i |= int32(b[1]) + i <<= 8 + i |= int32(b[2]) + return +} + +func U24BE(b []byte) (i uint32) { + i = uint32(b[0]) + i <<= 8 + i |= uint32(b[1]) + i <<= 8 + i |= uint32(b[2]) + return +} + +func I32BE(b []byte) (i int32) { + i = int32(int8(b[0])) + i <<= 8 + i |= int32(b[1]) + i <<= 8 + i |= int32(b[2]) + i <<= 8 + i |= int32(b[3]) + return +} + +func U32LE(b []byte) (i uint32) { + i = uint32(b[3]) + i <<= 8 + i |= uint32(b[2]) + i <<= 8 + i |= uint32(b[1]) + i <<= 8 + i |= uint32(b[0]) + return +} + +func U32BE(b []byte) (i uint32) { + i = uint32(b[0]) + i <<= 8 + i |= uint32(b[1]) + i <<= 8 + i |= uint32(b[2]) + i <<= 8 + i |= uint32(b[3]) + return +} + +func U40BE(b []byte) (i uint64) { + i = uint64(b[0]) + i <<= 8 + i |= uint64(b[1]) + i <<= 8 + i |= uint64(b[2]) + i <<= 8 + i |= uint64(b[3]) + i <<= 8 + i |= uint64(b[4]) + return +} + +func U64BE(b []byte) (i uint64) { + i = uint64(b[0]) + i <<= 8 + i |= uint64(b[1]) + i <<= 8 + i |= uint64(b[2]) + i <<= 8 + i |= uint64(b[3]) + i <<= 8 + i |= uint64(b[4]) + i <<= 8 + i |= uint64(b[5]) + i <<= 8 + i |= uint64(b[6]) + i <<= 8 + i |= uint64(b[7]) + return +} + +func I64BE(b []byte) (i int64) { + i = int64(int8(b[0])) + i <<= 8 + i |= int64(b[1]) + i <<= 8 + i |= int64(b[2]) + i <<= 8 + i |= int64(b[3]) + i <<= 8 + i |= int64(b[4]) + i <<= 8 + i |= int64(b[5]) + i <<= 8 + i |= int64(b[6]) + i <<= 8 + i |= int64(b[7]) + return +} diff --git a/utils/pio/writer.go b/utils/pio/writer.go new file mode 100755 index 0000000..fdbb1b6 --- /dev/null +++ b/utils/pio/writer.go @@ -0,0 +1,87 @@ +package pio + +func PutU8(b []byte, v uint8) { + b[0] = v +} + +func PutI16BE(b []byte, v int16) { + b[0] = byte(v >> 8) + b[1] = byte(v) +} + +func PutU16BE(b []byte, v uint16) { + b[0] = byte(v >> 8) + b[1] = byte(v) +} + +func PutI24BE(b []byte, v int32) { + b[0] = byte(v >> 16) + 
b[1] = byte(v >> 8) + b[2] = byte(v) +} + +func PutU24BE(b []byte, v uint32) { + b[0] = byte(v >> 16) + b[1] = byte(v >> 8) + b[2] = byte(v) +} + +func PutI32BE(b []byte, v int32) { + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} + +func PutU32BE(b []byte, v uint32) { + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} + +func PutU32LE(b []byte, v uint32) { + b[3] = byte(v >> 24) + b[2] = byte(v >> 16) + b[1] = byte(v >> 8) + b[0] = byte(v) +} + +func PutU40BE(b []byte, v uint64) { + b[0] = byte(v >> 32) + b[1] = byte(v >> 24) + b[2] = byte(v >> 16) + b[3] = byte(v >> 8) + b[4] = byte(v) +} + +func PutU48BE(b []byte, v uint64) { + b[0] = byte(v >> 40) + b[1] = byte(v >> 32) + b[2] = byte(v >> 24) + b[3] = byte(v >> 16) + b[4] = byte(v >> 8) + b[5] = byte(v) +} + +func PutU64BE(b []byte, v uint64) { + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func PutI64BE(b []byte, v int64) { + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} diff --git a/utils/pool/pool.go b/utils/pool/pool.go new file mode 100755 index 0000000..cd5c16d --- /dev/null +++ b/utils/pool/pool.go @@ -0,0 +1,24 @@ +package pool + +type Pool struct { + pos int + buf []byte +} + +const maxpoolsize = 500 * 1024 + +func (self *Pool) Get(size int) []byte { + if maxpoolsize-self.pos < size { + self.pos = 0 + self.buf = make([]byte, maxpoolsize) + } + b := self.buf[self.pos: self.pos+size] + self.pos += size + return b +} + +func NewPool() *Pool { + return &Pool{ + buf: make([]byte, maxpoolsize), + } +} diff --git a/utils/queue/queue.go b/utils/queue/queue.go new file mode 100755 index 0000000..fc82ff9 --- /dev/null +++ b/utils/queue/queue.go @@ -0,0 +1,72 @@ +package queue + +import ( + "sync" + + "github.com/gwuhaolin/livego/av" +) + +// Queue is a basic FIFO queue for Messages. +type Queue struct { + maxSize int + + list []*av.Packet + mutex sync.Mutex +} + +// NewQueue returns a new Queue. If maxSize is greater than zero the queue will +// not grow more than the defined size. +func NewQueue(maxSize int) *Queue { + return &Queue{ + maxSize: maxSize, + } +} + +// Push adds a message to the queue. +func (q *Queue) Push(msg *av.Packet) { + q.mutex.Lock() + defer q.mutex.Unlock() + + if len(q.list) == q.maxSize { + q.pop() + } + + q.list = append(q.list, msg) +} + +// Pop removes and returns a message from the queue in first to last order. +func (q *Queue) Pop() *av.Packet { + q.mutex.Lock() + defer q.mutex.Unlock() + + if len(q.list) == 0 { + return nil + } + + return q.pop() +} + +func (q *Queue) pop() *av.Packet { + x := len(q.list) - 1 + msg := q.list[x] + q.list = q.list[:x] + return msg +} + +// Len returns the length of the queue. +func (q *Queue) Len() int { + q.mutex.Lock() + defer q.mutex.Unlock() + + return len(q.list) +} + +// All returns and removes all messages from the queue. 
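+// The returned slice is the queue's internal backing slice; the queue itself is reset to nil.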
+func (q *Queue) All() []*av.Packet { + q.mutex.Lock() + defer q.mutex.Unlock() + + cache := q.list + q.list = nil + return cache +} diff --git a/utils/uid/uuid.go b/utils/uid/uuid.go new file mode 100755 index 0000000..af9fd57 --- /dev/null +++ b/utils/uid/uuid.go @@ -0,0 +1,450 @@ +package uid + +// Copyright (C) 2013-2015 by Maxim Bublis +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Package uuid provides implementation of Universally Unique Identifier (UUID). +// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and +// version 2 (as specified in DCE 1.1). + +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "database/sql/driver" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "fmt" + "hash" + "net" + "os" + "sync" + "time" +) + +func NEWID() string { + id := NewV4() + b64 := base64.URLEncoding.EncodeToString(id.Bytes()[:12]) + return b64 +} + +// UUID layout variants. +const ( + VariantNCS = iota + VariantRFC4122 + VariantMicrosoft + VariantFuture +) + +// UUID DCE domains. +const ( + DomainPerson = iota + DomainGroup + DomainOrg +) + +// Difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970). +const epochStart = 122192928000000000 + +// Used in string method conversion +const dash byte = '-' + +// UUID v1/v2 storage. +var ( + storageMutex sync.Mutex + storageOnce sync.Once + epochFunc = unixTimeFunc + clockSequence uint16 + lastTime uint64 + hardwareAddr [6]byte + posixUID = uint32(os.Getuid()) + posixGID = uint32(os.Getgid()) +) + +// String parse helpers. +var ( + urnPrefix = []byte("urn:uuid:") + byteGroups = []int{8, 4, 4, 4, 12} +) + +func initClockSequence() { + buf := make([]byte, 2) + safeRandom(buf) + clockSequence = binary.BigEndian.Uint16(buf) +} + +func initHardwareAddr() { + interfaces, err := net.Interfaces() + if err == nil { + for _, iface := range interfaces { + if len(iface.HardwareAddr) >= 6 { + copy(hardwareAddr[:], iface.HardwareAddr) + return + } + } + } + + // Initialize hardwareAddr randomly in case + // of real network interfaces absence + safeRandom(hardwareAddr[:]) + + // Set multicast bit as recommended in RFC 4122 + hardwareAddr[0] |= 0x01 +} + +func initStorage() { + initClockSequence() + initHardwareAddr() +} + +func safeRandom(dest []byte) { + if _, err := rand.Read(dest); err != nil { + panic(err) + } +} + +// Returns difference in 100-nanosecond intervals between +// UUID epoch (October 15, 1582) and current time. 
+// This is default epoch calculation function. +func unixTimeFunc() uint64 { + return epochStart + uint64(time.Now().UnixNano()/100) +} + +// UUID representation compliant with specification +// described in RFC 4122. +type UUID [16]byte + +// The nil UUID is special form of UUID that is specified to have all +// 128 bits set to zero. +var Nil = UUID{} + +// Predefined namespace UUIDs. +var ( + NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8") + NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8") + NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8") +) + +// And returns result of binary AND of two UUIDs. +func And(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] & u2[i] + } + return u +} + +// Or returns result of binary OR of two UUIDs. +func Or(u1 UUID, u2 UUID) UUID { + u := UUID{} + for i := 0; i < 16; i++ { + u[i] = u1[i] | u2[i] + } + return u +} + +// Equal returns true if u1 and u2 equals, otherwise returns false. +func Equal(u1 UUID, u2 UUID) bool { + return bytes.Equal(u1[:], u2[:]) +} + +// Version returns algorithm version used to generate UUID. +func (u UUID) Version() uint { + return uint(u[6] >> 4) +} + +// Variant returns UUID layout variant. +func (u UUID) Variant() uint { + switch { + case (u[8] & 0x80) == 0x00: + return VariantNCS + case (u[8]&0xc0)|0x80 == 0x80: + return VariantRFC4122 + case (u[8]&0xe0)|0xc0 == 0xc0: + return VariantMicrosoft + } + return VariantFuture +} + +// Bytes returns bytes slice representation of UUID. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Returns canonical string representation of UUID: +// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + buf := make([]byte, 36) + + hex.Encode(buf[0:8], u[0:4]) + buf[8] = dash + hex.Encode(buf[9:13], u[4:6]) + buf[13] = dash + hex.Encode(buf[14:18], u[6:8]) + buf[18] = dash + hex.Encode(buf[19:23], u[8:10]) + buf[23] = dash + hex.Encode(buf[24:], u[10:]) + + return string(buf) +} + +// SetVersion sets version bits. +func (u *UUID) SetVersion(v byte) { + u[6] = (u[6] & 0x0f) | (v << 4) +} + +// SetVariant sets variant bits as described in RFC 4122. +func (u *UUID) SetVariant() { + u[8] = (u[8] & 0xbf) | 0x80 +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String. +func (u UUID) MarshalText() (text []byte, err error) { + text = []byte(u.String()) + return +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// Following formats are supported: +// "6ba7b810-9dad-11d1-80b4-00c04fd430c8", +// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}", +// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8" +func (u *UUID) UnmarshalText(text []byte) (err error) { + if len(text) < 32 { + err = fmt.Errorf("uuid: UUID string too short: %s", text) + return + } + + t := text[:] + + if bytes.Equal(t[:9], urnPrefix) { + t = t[9:] + } else if t[0] == '{' { + t = t[1:] + } + + b := u[:] + + for _, byteGroup := range byteGroups { + if t[0] == '-' { + t = t[1:] + } + + if len(t) < byteGroup { + err = fmt.Errorf("uuid: UUID string too short: %s", text) + return + } + + _, err = hex.Decode(b[:byteGroup/2], t[:byteGroup]) + + if err != nil { + return + } + + t = t[byteGroup:] + b = b[byteGroup/2:] + } + + return +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. 
+func (u UUID) MarshalBinary() (data []byte, err error) { + data = u.Bytes() + return +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It will return error if the slice isn't 16 bytes long. +func (u *UUID) UnmarshalBinary(data []byte) (err error) { + if len(data) != 16 { + err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data)) + return + } + copy(u[:], data) + + return +} + +// Value implements the driver.Valuer interface. +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +// Scan implements the sql.Scanner interface. +// A 16-byte slice is handled by UnmarshalBinary, while +// a longer byte slice or a string is handled by UnmarshalText. +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + if len(src) == 16 { + return u.UnmarshalBinary(src) + } + return u.UnmarshalText(src) + + case string: + return u.UnmarshalText([]byte(src)) + } + + return fmt.Errorf("uuid: cannot convert %T to UUID", src) +} + +// FromBytes returns UUID converted from raw byte slice input. +// It will return error if the slice isn't 16 bytes long. +func FromBytes(input []byte) (u UUID, err error) { + err = u.UnmarshalBinary(input) + return +} + +// FromBytesOrNil returns UUID converted from raw byte slice input. +// Same behavior as FromBytes, but returns a Nil UUID on error. +func FromBytesOrNil(input []byte) UUID { + uuid, err := FromBytes(input) + if err != nil { + return Nil + } + return uuid +} + +// FromString returns UUID parsed from string input. +// Input is expected in a form accepted by UnmarshalText. +func FromString(input string) (u UUID, err error) { + err = u.UnmarshalText([]byte(input)) + return +} + +// FromStringOrNil returns UUID parsed from string input. +// Same behavior as FromString, but returns a Nil UUID on error. +func FromStringOrNil(input string) UUID { + uuid, err := FromString(input) + if err != nil { + return Nil + } + return uuid +} + +// Returns UUID v1/v2 storage state. +// Returns epoch timestamp, clock sequence, and hardware address. +func getStorage() (uint64, uint16, []byte) { + storageOnce.Do(initStorage) + + storageMutex.Lock() + defer storageMutex.Unlock() + + timeNow := epochFunc() + // Clock changed backwards since last UUID generation. + // Should increase clock sequence. + if timeNow <= lastTime { + clockSequence++ + } + lastTime = timeNow + + return timeNow, clockSequence, hardwareAddr[:] +} + +// NewV1 returns UUID based on current timestamp and MAC address. +func NewV1() UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + binary.BigEndian.PutUint32(u[0:], uint32(timeNow)) + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + + copy(u[10:], hardwareAddr) + + u.SetVersion(1) + u.SetVariant() + + return u +} + +// NewV2 returns DCE Security UUID based on POSIX UID/GID. 
+func NewV2(domain byte) UUID { + u := UUID{} + + timeNow, clockSeq, hardwareAddr := getStorage() + + switch domain { + case DomainPerson: + binary.BigEndian.PutUint32(u[0:], posixUID) + case DomainGroup: + binary.BigEndian.PutUint32(u[0:], posixGID) + } + + binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32)) + binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48)) + binary.BigEndian.PutUint16(u[8:], clockSeq) + u[9] = domain + + copy(u[10:], hardwareAddr) + + u.SetVersion(2) + u.SetVariant() + + return u +} + +// NewV3 returns UUID based on MD5 hash of namespace UUID and name. +func NewV3(ns UUID, name string) UUID { + u := newFromHash(md5.New(), ns, name) + u.SetVersion(3) + u.SetVariant() + + return u +} + +// NewV4 returns random generated UUID. +func NewV4() UUID { + u := UUID{} + safeRandom(u[:]) + u.SetVersion(4) + u.SetVariant() + + return u +} + +// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name. +func NewV5(ns UUID, name string) UUID { + u := newFromHash(sha1.New(), ns, name) + u.SetVersion(5) + u.SetVariant() + + return u +} + +// Returns UUID based on hashing of namespace UUID and name. +func newFromHash(h hash.Hash, ns UUID, name string) UUID { + u := UUID{} + h.Write(ns[:]) + h.Write([]byte(name)) + copy(u[:], h.Sum(nil)) + + return u +}
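
Editor's note: the sketch below is not part of the commit; it is a minimal usage example, assuming the import path github.com/gwuhaolin/livego/utils/uid from this patch, showing how the uid package above is exercised (as FLVWriter does when it assigns its Uid via uid.NEWID()). The demo main function and the "livego.example" name are illustrative assumptions only.

    // Demo (assumption, not in the patch): exercising the uid package's API.
    package main

    import (
        "fmt"

        "github.com/gwuhaolin/livego/utils/uid"
    )

    func main() {
        // NEWID: base64url of the first 12 bytes of a random v4 UUID,
        // used in livego as a short writer/stream identifier.
        fmt.Println("short id:", uid.NEWID())

        // Full RFC 4122 version-4 UUID.
        u := uid.NewV4()
        fmt.Println("uuid v4:", u.String(), "version:", u.Version())

        // Name-based v5 UUID: deterministic for a given namespace and name.
        v5 := uid.NewV5(uid.NamespaceDNS, "livego.example")
        fmt.Println("uuid v5:", v5)

        // Round-trip through the canonical string form.
        parsed, err := uid.FromString(u.String())
        if err != nil || !uid.Equal(u, parsed) {
            panic("uuid round-trip failed")
        }
    }

Design note: NEWID truncates to 12 bytes before base64url-encoding, which yields a compact 16-character identifier with no padding; the full 36-character canonical form is only needed where RFC 4122 interoperability matters.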