Browse Source

Initial commit

pull/6/head
halwu(吴浩麟) 8 years ago
commit
e9952937dc
  1. 15
      README.md
  2. 152
      av/av.go
  3. 53
      av/rwbase.go
  4. 44
      container/flv/demuxer.go
  5. 141
      container/flv/muxer.go
  6. 180
      container/flv/tag.go
  7. 1
      container/mp4/muxer.go
  8. 78
      container/ts/crc32.go
  9. 365
      container/ts/muxer.go
  10. 51
      container/ts/muxer_test.go
  11. 172
      main.go
  12. 113
      parser/aac/parser.go
  13. 232
      parser/h264/parser.go
  14. 91
      parser/h264/parser_test.go
  15. 41
      parser/mp3/parser.go
  16. 68
      parser/parser.go
  17. 50
      protocol/amf/amf.go
  18. 206
      protocol/amf/amf_test.go
  19. 105
      protocol/amf/const.go
  20. 335
      protocol/amf/decoder_amf0.go
  21. 588
      protocol/amf/decoder_amf0_test.go
  22. 496
      protocol/amf/decoder_amf3.go
  23. 127
      protocol/amf/decoder_amf3_external.go
  24. 220
      protocol/amf/decoder_amf3_test.go
  25. 308
      protocol/amf/encoder_amf0.go
  26. 212
      protocol/amf/encoder_amf0_test.go
  27. 431
      protocol/amf/encoder_amf3.go
  28. 199
      protocol/amf/encoder_amf3_test.go
  29. 70
      protocol/amf/metadata.go
  30. 92
      protocol/amf/util.go
  31. 1
      protocol/dash/dash.go
  32. 29
      protocol/hls/align.go
  33. 44
      protocol/hls/audio_cache.go
  34. 413
      protocol/hls/hls.go
  35. 43
      protocol/hls/status.go
  36. 127
      protocol/hls/ts_cache.go
  37. 274
      protocol/httpflv/http_flv.go
  38. 232
      protocol/httpopera/http_opera.go
  39. 1
      protocol/kcpts/kcp_ts.go
  40. 1
      protocol/private/protocol.go
  41. 79
      protocol/rtmp/cache/cache.go
  42. 120
      protocol/rtmp/cache/gop.go
  43. 46
      protocol/rtmp/cache/special.go
  44. 225
      protocol/rtmp/core/chunk_stream.go
  45. 97
      protocol/rtmp/core/chunk_stream_test.go
  46. 207
      protocol/rtmp/core/conn.go
  47. 287
      protocol/rtmp/core/conn_client.go
  48. 353
      protocol/rtmp/core/conn_server.go
  49. 251
      protocol/rtmp/core/conn_test.go
  50. 207
      protocol/rtmp/core/handshake.go
  51. 114
      protocol/rtmp/core/read_writer.go
  52. 136
      protocol/rtmp/core/read_writer_test.go
  53. 341
      protocol/rtmp/rtmp.go
  54. 228
      protocol/rtmp/stream.go
  55. 1
      protocol/rtp/rtp.go
  56. 1
      protocol/rtsp/protocol.go
  57. 1
      protocol/rtsp/rtsp.go
  58. 1
      protocol/webrtc/webrtc.go
  59. 301
      utils/cmap/cmap.go
  60. 3
      utils/pio/pio.go
  61. 121
      utils/pio/reader.go
  62. 87
      utils/pio/writer.go
  63. 24
      utils/pool/pool.go
  64. 72
      utils/queue/queue.go
  65. 450
      utils/uid/uuid.go

15
README.md

@ -0,0 +1,15 @@ @@ -0,0 +1,15 @@
# AV streaming server
## Feature
- write in pure golang, can run in any platform
- for live streaming
- support `RTMP` and `FLV` `HLS` over HTTP
## Use
1. run `git clone `
2. run `go run main.go` to start livego server
3. push `RTMP` stream to `rtmp://localhost/live/movie`, eg use `ffmpeg -re -i demo.flv -c copy -f flv rtmp://localhost/live/movie`
4. play stream use [VLC](http://www.videolan.org/vlc/index.html) or other players
- play `RTMP` from `rtmp://localhost/live/movie`
- play `FLV` from `http://127.0.0.1:8081/live/movie.flv`
- play `HLS` from `http://127.0.0.1:8080/live/movie.m3u8`

152
av/av.go

@ -0,0 +1,152 @@ @@ -0,0 +1,152 @@
package av
import "io"
import "fmt"
const (
TAG_AUDIO = 8
TAG_VIDEO = 9
TAG_SCRIPTDATAAMF0 = 18
TAG_SCRIPTDATAAMF3 = 0xf
)
const (
MetadatAMF0 = 0x12
MetadataAMF3 = 0xf
)
const (
SOUND_MP3 = 2
SOUND_NELLYMOSER_16KHZ_MONO = 4
SOUND_NELLYMOSER_8KHZ_MONO = 5
SOUND_NELLYMOSER = 6
SOUND_ALAW = 7
SOUND_MULAW = 8
SOUND_AAC = 10
SOUND_SPEEX = 11
SOUND_5_5Khz = 0
SOUND_11Khz = 1
SOUND_22Khz = 2
SOUND_44Khz = 3
SOUND_8BIT = 0
SOUND_16BIT = 1
SOUND_MONO = 0
SOUND_STEREO = 1
AAC_SEQHDR = 0
AAC_RAW = 1
)
const (
AVC_SEQHDR = 0
AVC_NALU = 1
AVC_EOS = 2
FRAME_KEY = 1
FRAME_INTER = 2
VIDEO_H264 = 7
)
var (
PUBLISH = "publish"
PLAY = "play"
)
// Header can be converted to AudioHeaderInfo or VideoHeaderInfo
type Packet struct {
IsAudio bool
IsVideo bool
IsMetadata bool
TimeStamp uint32 // dts
Header PacketHeader
Data []byte
}
type PacketHeader interface {
}
type AudioPacketHeader interface {
PacketHeader
SoundFormat() uint8
AACPacketType() uint8
}
type VideoPacketHeader interface {
PacketHeader
IsKeyFrame() bool
IsSeq() bool
CodecID() uint8
CompositionTime() int32
}
type Demuxer interface {
Demux(*Packet) (ret *Packet, err error)
}
type Muxer interface {
Mux(*Packet, io.Writer) error
}
type SampleRater interface {
SampleRate() (int, error)
}
type CodecParser interface {
SampleRater
Parse(*Packet, io.Writer) error
}
type GetWriter interface {
GetWriter(Info) WriteCloser
}
type Handler interface {
HandleReader(ReadCloser)
HandleWriter(WriteCloser)
}
type Alive interface {
Alive() bool
}
type Closer interface {
Info() Info
Close(error)
}
type CalcTime interface {
CalcBaseTimestamp()
}
type Info struct {
Key string
URL string
UID string
Inter bool
}
func (self Info) IsInterval() bool {
return self.Inter
}
func (i Info) String() string {
return fmt.Sprintf("<key: %s, URL: %s, UID: %s, Inter: %v>",
i.Key, i.URL, i.UID, i.Inter)
}
type ReadCloser interface {
Closer
Alive
Read(*Packet) error
}
type WriteCloser interface {
Closer
Alive
CalcTime
Write(Packet) error
}

53
av/rwbase.go

@ -0,0 +1,53 @@ @@ -0,0 +1,53 @@
package av
import "time"
import "sync"
type RWBaser struct {
lock sync.Mutex
timeout time.Duration
PreTime time.Time
BaseTimestamp uint32
LastVideoTimestamp uint32
LastAudioTimestamp uint32
}
func NewRWBaser(duration time.Duration) RWBaser {
return RWBaser{
timeout: duration,
PreTime: time.Now(),
}
}
func (self *RWBaser) BaseTimeStamp() uint32 {
return self.BaseTimestamp
}
func (self *RWBaser) CalcBaseTimestamp() {
if self.LastAudioTimestamp > self.LastVideoTimestamp {
self.BaseTimestamp = self.LastAudioTimestamp
} else {
self.BaseTimestamp = self.LastVideoTimestamp
}
}
func (self *RWBaser) RecTimeStamp(timestamp, typeID uint32) {
if typeID == TAG_VIDEO {
self.LastVideoTimestamp = timestamp
} else if typeID == TAG_AUDIO {
self.LastAudioTimestamp = timestamp
}
}
func (self *RWBaser) SetPreTime() {
self.lock.Lock()
self.PreTime = time.Now()
self.lock.Unlock()
}
func (self *RWBaser) Alive() bool {
self.lock.Lock()
b := !(time.Now().Sub(self.PreTime) >= self.timeout)
self.lock.Unlock()
return b
}

44
container/flv/demuxer.go

@ -0,0 +1,44 @@ @@ -0,0 +1,44 @@
package flv
import (
"errors"
"github.com/gwuhaolin/livego/av"
)
var (
ErrAvcEndSEQ = errors.New("avc end sequence")
)
type Demuxer struct {
}
func NewDemuxer() *Demuxer {
return &Demuxer{}
}
func (self *Demuxer) DemuxH(p *av.Packet) error {
var tag Tag
_, err := tag.ParseMeidaTagHeader(p.Data, p.IsVideo)
if err != nil {
return err
}
p.Header = &tag
return nil
}
func (self *Demuxer) Demux(p *av.Packet) error {
var tag Tag
n, err := tag.ParseMeidaTagHeader(p.Data, p.IsVideo)
if err != nil {
return err
}
if tag.CodecID() == av.VIDEO_H264 &&
p.Data[0] == 0x17 && p.Data[1] == 0x02 {
return ErrAvcEndSEQ
}
p.Header = &tag
p.Data = p.Data[n:]
return nil
}

141
container/flv/muxer.go

@ -0,0 +1,141 @@ @@ -0,0 +1,141 @@
package flv
import (
"strings"
"time"
"flag"
"os"
"log"
"github.com/gwuhaolin/livego/utils/uid"
"github.com/gwuhaolin/livego/protocol/amf"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/utils/pio"
)
var (
flvHeader = []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00, 0x09}
)
var (
flvFile = flag.String("filFile", "./out.flv", "output flv file name")
)
func NewFlv(handler av.Handler, info av.Info) {
patths := strings.SplitN(info.Key, "/", 2)
if len(patths) != 2 {
log.Println("invalid info")
return
}
w, err := os.OpenFile(*flvFile, os.O_CREATE|os.O_RDWR, 0755)
if err != nil {
log.Println("open file error: ", err)
}
writer := NewFLVWriter(patths[0], patths[1], info.URL, w)
handler.HandleWriter(writer)
writer.Wait()
// close flv file
log.Println("close flv file")
writer.ctx.Close()
}
const (
headerLen = 11
)
type FLVWriter struct {
Uid string
av.RWBaser
app, title, url string
buf []byte
closed chan struct{}
ctx *os.File
}
func NewFLVWriter(app, title, url string, ctx *os.File) *FLVWriter {
ret := &FLVWriter{
Uid: uid.NEWID(),
app: app,
title: title,
url: url,
ctx: ctx,
RWBaser: av.NewRWBaser(time.Second * 10),
closed: make(chan struct{}),
buf: make([]byte, headerLen),
}
ret.ctx.Write(flvHeader)
pio.PutI32BE(ret.buf[:4], 0)
ret.ctx.Write(ret.buf[:4])
return ret
}
func (self *FLVWriter) Write(p av.Packet) error {
self.RWBaser.SetPreTime()
h := self.buf[:headerLen]
typeID := av.TAG_VIDEO
if !p.IsVideo {
if p.IsMetadata {
var err error
typeID = av.TAG_SCRIPTDATAAMF0
p.Data, err = amf.MetaDataReform(p.Data, amf.DEL)
if err != nil {
return err
}
} else {
typeID = av.TAG_AUDIO
}
}
dataLen := len(p.Data)
timestamp := p.TimeStamp
timestamp += self.BaseTimeStamp()
self.RWBaser.RecTimeStamp(timestamp, uint32(typeID))
preDataLen := dataLen + headerLen
timestampbase := timestamp & 0xffffff
timestampExt := timestamp >> 24 & 0xff
pio.PutU8(h[0:1], uint8(typeID))
pio.PutI24BE(h[1:4], int32(dataLen))
pio.PutI24BE(h[4:7], int32(timestampbase))
pio.PutU8(h[7:8], uint8(timestampExt))
if _, err := self.ctx.Write(h); err != nil {
return err
}
if _, err := self.ctx.Write(p.Data); err != nil {
return err
}
pio.PutI32BE(h[:4], int32(preDataLen))
if _, err := self.ctx.Write(h[:4]); err != nil {
return err
}
return nil
}
func (self *FLVWriter) Wait() {
select {
case <-self.closed:
return
}
}
func (self *FLVWriter) Close(error) {
self.ctx.Close()
close(self.closed)
}
func (self *FLVWriter) Info() (ret av.Info) {
ret.UID = self.Uid
ret.URL = self.url
ret.Key = self.app + "/" + self.title
return
}

180
container/flv/tag.go

@ -0,0 +1,180 @@ @@ -0,0 +1,180 @@
package flv
import (
"fmt"
"github.com/gwuhaolin/livego/av"
)
type flvTag struct {
fType uint8
dataSize uint32
timeStamp uint32
streamID uint32 // always 0
}
type mediaTag struct {
/*
SoundFormat: UB[4]
0 = Linear PCM, platform endian
1 = ADPCM
2 = MP3
3 = Linear PCM, little endian
4 = Nellymoser 16-kHz mono
5 = Nellymoser 8-kHz mono
6 = Nellymoser
7 = G.711 A-law logarithmic PCM
8 = G.711 mu-law logarithmic PCM
9 = reserved
10 = AAC
11 = Speex
14 = MP3 8-Khz
15 = Device-specific sound
Formats 7, 8, 14, and 15 are reserved for internal use
AAC is supported in Flash Player 9,0,115,0 and higher.
Speex is supported in Flash Player 10 and higher.
*/
soundFormat uint8
/*
SoundRate: UB[2]
Sampling rate
0 = 5.5-kHz For AAC: always 3
1 = 11-kHz
2 = 22-kHz
3 = 44-kHz
*/
soundRate uint8
/*
SoundSize: UB[1]
0 = snd8Bit
1 = snd16Bit
Size of each sample.
This parameter only pertains to uncompressed formats.
Compressed formats always decode to 16 bits internally
*/
soundSize uint8
/*
SoundType: UB[1]
0 = sndMono
1 = sndStereo
Mono or stereo sound For Nellymoser: always 0
For AAC: always 1
*/
soundType uint8
/*
0: AAC sequence header
1: AAC raw
*/
aacPacketType uint8
/*
1: keyframe (for AVC, a seekable frame)
2: inter frame (for AVC, a non- seekable frame)
3: disposable inter frame (H.263 only)
4: generated keyframe (reserved for server use only)
5: video info/command frame
*/
frameType uint8
/*
1: JPEG (currently unused)
2: Sorenson H.263
3: Screen video
4: On2 VP6
5: On2 VP6 with alpha channel
6: Screen video version 2
7: AVC
*/
codecID uint8
/*
0: AVC sequence header
1: AVC NALU
2: AVC end of sequence (lower level NALU sequence ender is not required or supported)
*/
avcPacketType uint8
compositionTime int32
}
type Tag struct {
flvt flvTag
mediat mediaTag
}
func (self *Tag) SoundFormat() uint8 {
return self.mediat.soundFormat
}
func (self *Tag) AACPacketType() uint8 {
return self.mediat.aacPacketType
}
func (self *Tag) IsKeyFrame() bool {
return self.mediat.frameType == av.FRAME_KEY
}
func (self *Tag) IsSeq() bool {
return self.mediat.frameType == av.FRAME_KEY &&
self.mediat.avcPacketType == av.AVC_SEQHDR
}
func (self *Tag) CodecID() uint8 {
return self.mediat.codecID
}
func (self *Tag) CompositionTime() int32 {
return self.mediat.compositionTime
}
// ParseMeidaTagHeader, parse video, audio, tag header
func (self *Tag) ParseMeidaTagHeader(b []byte, isVideo bool) (n int, err error) {
switch isVideo {
case false:
n, err = self.parseAudioHeader(b)
case true:
n, err = self.parseVideoHeader(b)
}
return
}
func (self *Tag) parseAudioHeader(b []byte) (n int, err error) {
if len(b) < n+1 {
err = fmt.Errorf("invalid audiodata len=%d", len(b))
return
}
flags := b[0]
self.mediat.soundFormat = flags >> 4
self.mediat.soundRate = (flags >> 2) & 0x3
self.mediat.soundSize = (flags >> 1) & 0x1
self.mediat.soundType = flags & 0x1
n++
switch self.mediat.soundFormat {
case av.SOUND_AAC:
self.mediat.aacPacketType = b[1]
n++
}
return
}
func (self *Tag) parseVideoHeader(b []byte) (n int, err error) {
if len(b) < n+5 {
err = fmt.Errorf("invalid videodata len=%d", len(b))
return
}
flags := b[0]
self.mediat.frameType = flags >> 4
self.mediat.codecID = flags & 0xf
n++
if self.mediat.frameType == av.FRAME_INTER || self.mediat.frameType == av.FRAME_KEY {
self.mediat.avcPacketType = b[1]
for i := 2; i < 5; i++ {
self.mediat.compositionTime = self.mediat.compositionTime<<8 + int32(b[i])
}
n += 4
}
return
}

1
container/mp4/muxer.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package mp4

78
container/ts/crc32.go

@ -0,0 +1,78 @@ @@ -0,0 +1,78 @@
package ts
func GenCrc32(src []byte) uint32 {
crcTable := []uint32{
0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9,
0x130476dc, 0x17c56b6b, 0x1a864db2, 0x1e475005,
0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61,
0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd,
0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9,
0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75,
0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011,
0x791d4014, 0x7ddc5da3, 0x709f7b7a, 0x745e66cd,
0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039,
0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5,
0xbe2b5b58, 0xbaea46ef, 0xb7a96036, 0xb3687d81,
0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d,
0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49,
0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95,
0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1,
0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d,
0x34867077, 0x30476dc0, 0x3d044b19, 0x39c556ae,
0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072,
0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16,
0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca,
0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde,
0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02,
0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, 0x53dc6066,
0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba,
0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e,
0xbfa1b04b, 0xbb60adfc, 0xb6238b25, 0xb2e29692,
0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6,
0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a,
0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e,
0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2,
0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686,
0xd5b88683, 0xd1799b34, 0xdc3abded, 0xd8fba05a,
0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637,
0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb,
0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f,
0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53,
0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47,
0x36194d42, 0x32d850f5, 0x3f9b762c, 0x3b5a6b9b,
0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff,
0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623,
0xf12f560e, 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7,
0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b,
0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f,
0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3,
0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7,
0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b,
0x9b3660c6, 0x9ff77d71, 0x92b45ba8, 0x9675461f,
0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3,
0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640,
0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c,
0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8,
0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24,
0x119b4be9, 0x155a565e, 0x18197087, 0x1cd86d30,
0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec,
0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088,
0x2497d08d, 0x2056cd3a, 0x2d15ebe3, 0x29d4f654,
0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0,
0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c,
0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18,
0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4,
0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0,
0x9abc8bd5, 0x9e7d9662, 0x933eb0bb, 0x97ffad0c,
0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668,
0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4}
j := byte(0)
crc32 := uint32(0xFFFFFFFF)
for i := 0; i < len(src); i++ {
j = ((byte(crc32>>24) ^ src[i]) & 0xff)
crc32 = uint32(uint32(crc32<<8) ^ uint32(crcTable[j]))
}
return crc32
}

365
container/ts/muxer.go

@ -0,0 +1,365 @@ @@ -0,0 +1,365 @@
package ts
import (
"io"
"github.com/gwuhaolin/livego/av"
)
const (
tsDefaultDataLen = 184
tsPacketLen = 188
h264DefaultHZ = 90
)
const (
videoPID = 0x100
audioPID = 0x101
videoSID = 0xe0
audioSID = 0xc0
)
type Muxer struct {
videoCc byte
audioCc byte
patCc byte
pmtCc byte
pat [tsPacketLen]byte
pmt [tsPacketLen]byte
tsPacket [tsPacketLen]byte
}
func NewMuxer() *Muxer {
return &Muxer{}
}
func (self *Muxer) Mux(p *av.Packet, w io.Writer) error {
first := true
wBytes := 0
pesIndex := 0
tmpLen := byte(0)
dataLen := byte(0)
var pes pesHeader
dts := int64(p.TimeStamp) * int64(h264DefaultHZ)
pts := dts
pid := audioPID
var videoH av.VideoPacketHeader
if p.IsVideo {
pid = videoPID
videoH, _ = p.Header.(av.VideoPacketHeader)
pts = dts + int64(videoH.CompositionTime())*int64(h264DefaultHZ)
}
err := pes.packet(p, pts, dts)
if err != nil {
return err
}
pesHeaderLen := pes.len
packetBytesLen := len(p.Data) + int(pesHeaderLen)
for {
if packetBytesLen <= 0 {
break
}
if p.IsVideo {
self.videoCc++
if self.videoCc > 0xf {
self.videoCc = 0
}
} else {
self.audioCc++
if self.audioCc > 0xf {
self.audioCc = 0
}
}
i := byte(0)
//sync byte
self.tsPacket[i] = 0x47
i++
//error indicator, unit start indicator,ts priority,pid
self.tsPacket[i] = byte(pid >> 8) //pid high 5 bits
if first {
self.tsPacket[i] = self.tsPacket[i] | 0x40 //unit start indicator
}
i++
//pid low 8 bits
self.tsPacket[i] = byte(pid)
i++
//scram control, adaptation control, counter
if p.IsVideo {
self.tsPacket[i] = 0x10 | byte(self.videoCc&0x0f)
} else {
self.tsPacket[i] = 0x10 | byte(self.audioCc&0x0f)
}
i++
//关键帧需要加pcr
if first && p.IsVideo && videoH.IsKeyFrame() {
self.tsPacket[3] |= 0x20
self.tsPacket[i] = 7
i++
self.tsPacket[i] = 0x50
i++
self.writePcr(self.tsPacket[0:], i, dts)
i += 6
}
//frame data
if packetBytesLen >= tsDefaultDataLen {
dataLen = tsDefaultDataLen
if first {
dataLen -= (i - 4)
}
} else {
self.tsPacket[3] |= 0x20 //have adaptation
remainBytes := byte(0)
dataLen = byte(packetBytesLen)
if first {
remainBytes = tsDefaultDataLen - dataLen - (i - 4)
} else {
remainBytes = tsDefaultDataLen - dataLen
}
self.adaptationBufInit(self.tsPacket[i:], byte(remainBytes))
i += remainBytes
}
if first && i < tsPacketLen && pesHeaderLen > 0 {
tmpLen = tsPacketLen - i
if pesHeaderLen <= tmpLen {
tmpLen = pesHeaderLen
}
copy(self.tsPacket[i:], pes.data[pesIndex:pesIndex+int(tmpLen)])
i += tmpLen
packetBytesLen -= int(tmpLen)
dataLen -= tmpLen
pesHeaderLen -= tmpLen
pesIndex += int(tmpLen)
}
if i < tsPacketLen {
tmpLen = tsPacketLen - i
if tmpLen <= dataLen {
dataLen = tmpLen
}
copy(self.tsPacket[i:], p.Data[wBytes:wBytes+int(dataLen)])
wBytes += int(dataLen)
packetBytesLen -= int(dataLen)
}
if w != nil {
if _, err := w.Write(self.tsPacket[0:]); err != nil {
return err
}
}
first = false
}
return nil
}
//PAT return pat data
func (self *Muxer) PAT() []byte {
i := 0
remainByte := 0
tsHeader := []byte{0x47, 0x40, 0x00, 0x10, 0x00}
patHeader := []byte{0x00, 0xb0, 0x0d, 0x00, 0x01, 0xc1, 0x00, 0x00, 0x00, 0x01, 0xf0, 0x01}
if self.patCc > 0xf {
self.patCc = 0
}
tsHeader[3] |= self.patCc & 0x0f
self.patCc++
copy(self.pat[i:], tsHeader)
i += len(tsHeader)
copy(self.pat[i:], patHeader)
i += len(patHeader)
crc32Value := GenCrc32(patHeader)
self.pat[i] = byte(crc32Value >> 24)
i++
self.pat[i] = byte(crc32Value >> 16)
i++
self.pat[i] = byte(crc32Value >> 8)
i++
self.pat[i] = byte(crc32Value)
i++
remainByte = int(tsPacketLen - i)
for j := 0; j < remainByte; j++ {
self.pat[i+j] = 0xff
}
return self.pat[0:]
}
// PMT return pmt data
func (self *Muxer) PMT(soundFormat byte, hasVideo bool) []byte {
i := int(0)
j := int(0)
var progInfo []byte
remainBytes := int(0)
tsHeader := []byte{0x47, 0x50, 0x01, 0x10, 0x00}
pmtHeader := []byte{0x02, 0xb0, 0xff, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x00}
if !hasVideo {
pmtHeader[9] = 0x01
progInfo = []byte{0x0f, 0xe1, 0x01, 0xf0, 0x00}
} else {
progInfo = []byte{0x1b, 0xe1, 0x00, 0xf0, 0x00, //h264 or h265*
0x0f, 0xe1, 0x01, 0xf0, 0x00, //mp3 or aac
}
}
pmtHeader[2] = byte(len(progInfo) + 9 + 4)
if self.pmtCc > 0xf {
self.pmtCc = 0
}
tsHeader[3] |= self.pmtCc & 0x0f
self.pmtCc++
if soundFormat == 2 ||
soundFormat == 14 {
if hasVideo {
progInfo[5] = 0x4
} else {
progInfo[0] = 0x4
}
}
copy(self.pmt[i:], tsHeader)
i += len(tsHeader)
copy(self.pmt[i:], pmtHeader)
i += len(pmtHeader)
copy(self.pmt[i:], progInfo[0:])
i += len(progInfo)
crc32Value := GenCrc32(self.pmt[5: 5+len(pmtHeader)+len(progInfo)])
self.pmt[i] = byte(crc32Value >> 24)
i++
self.pmt[i] = byte(crc32Value >> 16)
i++
self.pmt[i] = byte(crc32Value >> 8)
i++
self.pmt[i] = byte(crc32Value)
i++
remainBytes = int(tsPacketLen - i)
for j = 0; j < remainBytes; j++ {
self.pmt[i+j] = 0xff
}
return self.pmt[0:]
}
func (self *Muxer) adaptationBufInit(src []byte, remainBytes byte) {
src[0] = byte(remainBytes - 1)
if remainBytes == 1 {
} else {
src[1] = 0x00
for i := 2; i < len(src); i++ {
src[i] = 0xff
}
}
return
}
func (self *Muxer) writePcr(b []byte, i byte, pcr int64) error {
b[i] = byte(pcr >> 25)
i++
b[i] = byte((pcr >> 17) & 0xff)
i++
b[i] = byte((pcr >> 9) & 0xff)
i++
b[i] = byte((pcr >> 1) & 0xff)
i++
b[i] = byte(((pcr & 0x1) << 7) | 0x7e)
i++
b[i] = 0x00
return nil
}
type pesHeader struct {
len byte
data [tsPacketLen]byte
}
//pesPacket return pes packet
func (self *pesHeader) packet(p *av.Packet, pts, dts int64) error {
//PES header
i := 0
self.data[i] = 0x00
i++
self.data[i] = 0x00
i++
self.data[i] = 0x01
i++
sid := audioSID
if p.IsVideo {
sid = videoSID
}
self.data[i] = byte(sid)
i++
flag := 0x80
ptslen := 5
dtslen := ptslen
headerSize := ptslen
if p.IsVideo && pts != dts {
flag |= 0x40
headerSize += 5 //add dts
}
size := len(p.Data) + headerSize + 3
if size > 0xffff {
size = 0
}
self.data[i] = byte(size >> 8)
i++
self.data[i] = byte(size)
i++
self.data[i] = 0x80
i++
self.data[i] = byte(flag)
i++
self.data[i] = byte(headerSize)
i++
self.writeTs(self.data[0:], i, flag>>6, pts)
i += ptslen
if p.IsVideo && pts != dts {
self.writeTs(self.data[0:], i, 1, dts)
i += dtslen
}
self.len = byte(i)
return nil
}
func (self *pesHeader) writeTs(src []byte, i int, fb int, ts int64) {
val := uint32(0)
if ts > 0x1ffffffff {
ts -= 0x1ffffffff
}
val = uint32(fb<<4) | ((uint32(ts>>30) & 0x07) << 1) | 1
src[i] = byte(val)
i++
val = ((uint32(ts>>15) & 0x7fff) << 1) | 1
src[i] = byte(val >> 8)
i++
src[i] = byte(val)
i++
val = (uint32(ts&0x7fff) << 1) | 1
src[i] = byte(val >> 8)
i++
src[i] = byte(val)
}

51
container/ts/muxer_test.go

@ -0,0 +1,51 @@ @@ -0,0 +1,51 @@
package ts
import (
"testing"
"github.com/gwuhaolin/livego/av"
"github.com/stretchr/testify/assert"
)
type TestWriter struct {
buf []byte
count int
}
//Write write p to w.buf
func (w *TestWriter) Write(p []byte) (int, error) {
w.count++
w.buf = p
return len(p), nil
}
func TestTSEncoder(t *testing.T) {
at := assert.New(t)
m := NewMuxer()
w := &TestWriter{}
data := []byte{0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d, 0x0b, 0x6d, 0x44, 0xae, 0x81,
0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04, 0x80, 0x00, 0x5b, 0xb7,
0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00, 0x06, 0x00, 0x38,
}
p := av.Packet{
IsVideo: false,
Data: data,
}
err := m.Mux(&p, w)
at.Equal(err, nil)
at.Equal(w.count, 1)
at.Equal(w.buf, []byte{0x47, 0x41, 0x01, 0x31, 0x81, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, 0xc0, 0x00, 0x30,
0x80, 0x80, 0x05, 0x21, 0x00, 0x01, 0x00, 0x01, 0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d,
0x0b, 0x6d, 0x44, 0xae, 0x81, 0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04,
0x80, 0x00, 0x5b, 0xb7, 0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00,
0x06, 0x00, 0x38})
}

172
main.go

@ -0,0 +1,172 @@ @@ -0,0 +1,172 @@
package main
import (
"flag"
"net"
"os"
"os/signal"
"syscall"
"time"
"github.com/gwuhaolin/livego/protocol/rtmp"
"github.com/gwuhaolin/livego/protocol/hls"
"github.com/gwuhaolin/livego/protocol/httpflv"
"github.com/gwuhaolin/livego/protocol/httpopera"
"path/filepath"
"strings"
"io/ioutil"
"strconv"
"log"
)
var (
rtmpAddr = flag.String("rtmpAddr", ":1935", "The rtmp server address to bind.")
flvAddr = flag.String("flvAddr", ":8081", "the http-flv server address to bind.")
hlsAddr = flag.String("hlsAddr", ":8080", "the hls server address to bind.")
operaAddr = flag.String("operaAddr", "", "the http operation or config address to bind: 8082.")
CurDir string // save pid
)
func getParentDirectory(dirctory string) string {
return substr(dirctory, 0, strings.LastIndex(dirctory, "/"))
}
func getCurrentDirectory() string {
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
log.Fatal(err)
}
return strings.Replace(dir, "\\", "/", -1)
}
func substr(s string, pos, length int) string {
runes := []rune(s)
l := pos + length
if l > len(runes) {
l = len(runes)
}
return string(runes[pos:l])
}
func SavePid() error {
pidFilename := CurDir + "/pid/" + filepath.Base(os.Args[0]) + ".pid"
pid := os.Getpid()
return ioutil.WriteFile(pidFilename, []byte(strconv.Itoa(pid)), 0755)
}
func init() {
CurDir = getParentDirectory(getCurrentDirectory())
flag.Usage = func() {
flag.PrintDefaults()
}
flag.Parse()
}
func catchSignal() {
sig := make(chan os.Signal)
signal.Notify(sig, syscall.SIGSTOP, syscall.SIGTERM)
<-sig
log.Println("recieved signal!")
}
func startHls() *hls.Server {
hlsListen, err := net.Listen("tcp", *hlsAddr)
if err != nil {
log.Fatal(err)
}
hlsServer := hls.NewServer()
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("hls server panic: ", r)
}
}()
hlsServer.Serve(hlsListen)
}()
return hlsServer
}
func startRtmp(stream *rtmp.RtmpStream, hlsServer *hls.Server) {
rtmplisten, err := net.Listen("tcp", *rtmpAddr)
if err != nil {
log.Fatal(err)
}
rtmpServer := rtmp.NewRtmpServer(stream, hlsServer)
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("hls server panic: ", r)
}
}()
rtmpServer.Serve(rtmplisten)
}()
}
func startHTTPFlv(stream *rtmp.RtmpStream) {
flvListen, err := net.Listen("tcp", *flvAddr)
if err != nil {
log.Fatal(err)
}
hdlServer := httpflv.NewServer(stream)
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("hls server panic: ", r)
}
}()
hdlServer.Serve(flvListen)
}()
}
func startHTTPOpera(stream *rtmp.RtmpStream) {
if *operaAddr != "" {
opListen, err := net.Listen("tcp", *operaAddr)
if err != nil {
log.Fatal(err)
}
opServer := httpopera.NewServer(stream)
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("hls server panic: ", r)
}
}()
opServer.Serve(opListen)
}()
}
}
func startLog() {
log.Println("RTMP Listen On", *rtmpAddr)
log.Println("HLS Listen On", *hlsAddr)
log.Println("HTTP-FLV Listen On", *flvAddr)
if *operaAddr != "" {
log.Println("HTTP-Operation Listen On", *operaAddr)
}
SavePid()
}
func main() {
defer func() {
if r := recover(); r != nil {
log.Println("main panic: ", r)
time.Sleep(1 * time.Second)
}
}()
stream := rtmp.NewRtmpStream()
// hls
h := startHls()
// rtmp
startRtmp(stream, h)
// http-flv
startHTTPFlv(stream)
// http-opera
startHTTPOpera(stream)
// my log
startLog()
// block
catchSignal()
}

113
parser/aac/parser.go

@ -0,0 +1,113 @@ @@ -0,0 +1,113 @@
package aac
import (
"errors"
"io"
"github.com/gwuhaolin/livego/av"
)
type mpegExtension struct {
objectType byte
sampleRate byte
}
type mpegCfgInfo struct {
objectType byte
sampleRate byte
channel byte
sbr byte
ps byte
frameLen byte
exceptionLogTs int64
extension *mpegExtension
}
var aacRates = []int{96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350}
var (
specificBufInvalid = errors.New("audio mpegspecific error")
audioBufInvalid = errors.New("audiodata invalid")
)
const (
adtsHeaderLen = 7
)
type Parser struct {
gettedSpecific bool
adtsHeader []byte
cfgInfo *mpegCfgInfo
}
func NewParser() *Parser {
return &Parser{
gettedSpecific: false,
cfgInfo: &mpegCfgInfo{},
adtsHeader: make([]byte, adtsHeaderLen),
}
}
func (self *Parser) specificInfo(src []byte) error {
if len(src) < 2 {
return specificBufInvalid
}
self.gettedSpecific = true
self.cfgInfo.objectType = (src[0] >> 3) & 0xff
self.cfgInfo.sampleRate = ((src[0] & 0x07) << 1) | src[1]>>7
self.cfgInfo.channel = (src[1] >> 3) & 0x0f
return nil
}
func (self *Parser) adts(src []byte, w io.Writer) error {
if len(src) <= 0 || !self.gettedSpecific {
return audioBufInvalid
}
frameLen := uint16(len(src)) + 7
//first write adts header
self.adtsHeader[0] = 0xff
self.adtsHeader[1] = 0xf1
self.adtsHeader[2] &= 0x00
self.adtsHeader[2] = self.adtsHeader[2] | (self.cfgInfo.objectType-1)<<6
self.adtsHeader[2] = self.adtsHeader[2] | (self.cfgInfo.sampleRate)<<2
self.adtsHeader[3] &= 0x00
self.adtsHeader[3] = self.adtsHeader[3] | (self.cfgInfo.channel<<2)<<4
self.adtsHeader[3] = self.adtsHeader[3] | byte((frameLen<<3)>>14)
self.adtsHeader[4] &= 0x00
self.adtsHeader[4] = self.adtsHeader[4] | byte((frameLen<<5)>>8)
self.adtsHeader[5] &= 0x00
self.adtsHeader[5] = self.adtsHeader[5] | byte(((frameLen<<13)>>13)<<5)
self.adtsHeader[5] = self.adtsHeader[5] | (0x7C<<1)>>3
self.adtsHeader[6] = 0xfc
if _, err := w.Write(self.adtsHeader[0:]); err != nil {
return err
}
if _, err := w.Write(src); err != nil {
return err
}
return nil
}
func (self *Parser) SampleRate() int {
rate := 44100
if self.cfgInfo.sampleRate <= byte(len(aacRates)-1) {
rate = aacRates[self.cfgInfo.sampleRate]
}
return rate
}
func (self *Parser) Parse(b []byte, packetType uint8, w io.Writer) (err error) {
switch packetType {
case av.AAC_SEQHDR:
err = self.specificInfo(b)
case av.AAC_RAW:
err = self.adts(b, w)
}
return
}

232
parser/h264/parser.go

@ -0,0 +1,232 @@ @@ -0,0 +1,232 @@
package h264
import (
"bytes"
"errors"
"io"
)
const (
i_frame byte = 0
p_frame byte = 1
b_frame byte = 2
)
const (
nalu_type_not_define byte = 0
nalu_type_slice byte = 1 //slice_layer_without_partioning_rbsp() sliceheader
nalu_type_dpa byte = 2 // slice_data_partition_a_layer_rbsp( ), slice_header
nalu_type_dpb byte = 3 // slice_data_partition_b_layer_rbsp( )
nalu_type_dpc byte = 4 // slice_data_partition_c_layer_rbsp( )
nalu_type_idr byte = 5 // slice_layer_without_partitioning_rbsp( ),sliceheader
nalu_type_sei byte = 6 //sei_rbsp( )
nalu_type_sps byte = 7 //seq_parameter_set_rbsp( )
nalu_type_pps byte = 8 //pic_parameter_set_rbsp( )
nalu_type_aud byte = 9 // access_unit_delimiter_rbsp( )
nalu_type_eoesq byte = 10 //end_of_seq_rbsp( )
nalu_type_eostream byte = 11 //end_of_stream_rbsp( )
nalu_type_filler byte = 12 //filler_data_rbsp( )
)
const (
naluBytesLen int = 4
maxSpsPpsLen int = 2 * 1024
)
var (
decDataNil = errors.New("dec buf is nil")
spsDataError = errors.New("sps data error")
ppsHeaderError = errors.New("pps header error")
ppsDataError = errors.New("pps data error")
naluHeaderInvalid = errors.New("nalu header invalid")
videoDataInvalid = errors.New("video data not match")
dataSizeNotMatch = errors.New("data size not match")
naluBodyLenError = errors.New("nalu body len error")
)
var startCode = []byte{0x00, 0x00, 0x00, 0x01}
var naluAud = []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0}
type Parser struct {
frameType byte
specificInfo []byte
pps *bytes.Buffer
}
type sequenceHeader struct {
configVersion byte //8bits
avcProfileIndication byte //8bits
profileCompatility byte //8bits
avcLevelIndication byte //8bits
reserved1 byte //6bits
naluLen byte //2bits
reserved2 byte //3bits
spsNum byte //5bits
ppsNum byte //8bits
spsLen int
ppsLen int
}
func NewParser() *Parser {
return &Parser{
pps: bytes.NewBuffer(make([]byte, maxSpsPpsLen)),
}
}
//return value 1:sps, value2 :pps
func (self *Parser) parseSpecificInfo(src []byte) error {
if len(src) < 9 {
return decDataNil
}
sps := []byte{}
pps := []byte{}
var seq sequenceHeader
seq.configVersion = src[0]
seq.avcProfileIndication = src[1]
seq.profileCompatility = src[2]
seq.avcLevelIndication = src[3]
seq.reserved1 = src[4] & 0xfc
seq.naluLen = src[4]&0x03 + 1
seq.reserved2 = src[5] >> 5
//get sps
seq.spsNum = src[5] & 0x1f
seq.spsLen = int(src[6])<<8 | int(src[7])
if len(src[8:]) < seq.spsLen || seq.spsLen <= 0 {
return spsDataError
}
sps = append(sps, startCode...)
sps = append(sps, src[8:(8 + seq.spsLen)]...)
//get pps
tmpBuf := src[(8 + seq.spsLen):]
if len(tmpBuf) < 4 {
return ppsHeaderError
}
seq.ppsNum = tmpBuf[0]
seq.ppsLen = int(0)<<16 | int(tmpBuf[1])<<8 | int(tmpBuf[2])
if len(tmpBuf[3:]) < seq.ppsLen || seq.ppsLen <= 0 {
return ppsDataError
}
pps = append(pps, startCode...)
pps = append(pps, tmpBuf[3:]...)
self.specificInfo = append(self.specificInfo, sps...)
self.specificInfo = append(self.specificInfo, pps...)
return nil
}
func (self *Parser) isNaluHeader(src []byte) bool {
if len(src) < naluBytesLen {
return false
}
return src[0] == 0x00 &&
src[1] == 0x00 &&
src[2] == 0x00 &&
src[3] == 0x01
}
func (self *Parser) naluSize(src []byte) (int, error) {
if len(src) < naluBytesLen {
return 0, errors.New("nalusizedata invalid")
}
buf := src[:naluBytesLen]
size := int(0)
for i := 0; i < len(buf); i++ {
size = size<<8 + int(buf[i])
}
return size, nil
}
func (self *Parser) getAnnexbH264(src []byte, w io.Writer) error {
dataSize := len(src)
if dataSize < naluBytesLen {
return videoDataInvalid
}
self.pps.Reset()
_, err := w.Write(naluAud)
if err != nil {
return err
}
index := 0
nalLen := 0
hasSpsPps := false
hasWriteSpsPps := false
for dataSize > 0 {
nalLen, err = self.naluSize(src[index:])
if err != nil {
return dataSizeNotMatch
}
index += naluBytesLen
dataSize -= naluBytesLen
if dataSize >= nalLen && len(src[index:]) >= nalLen && nalLen > 0 {
nalType := src[index] & 0x1f
switch nalType {
case nalu_type_aud:
case nalu_type_idr:
if !hasWriteSpsPps {
hasWriteSpsPps = true
if !hasSpsPps {
if _, err := w.Write(self.specificInfo); err != nil {
return err
}
} else {
if _, err := w.Write(self.pps.Bytes()); err != nil {
return err
}
}
}
fallthrough
case nalu_type_slice:
fallthrough
case nalu_type_sei:
_, err := w.Write(startCode)
if err != nil {
return err
}
_, err = w.Write(src[index: index+nalLen])
if err != nil {
return err
}
case nalu_type_sps:
fallthrough
case nalu_type_pps:
hasSpsPps = true
_, err := self.pps.Write(startCode)
if err != nil {
return err
}
_, err = self.pps.Write(src[index: index+nalLen])
if err != nil {
return err
}
}
index += nalLen
dataSize -= nalLen
} else {
return naluBodyLenError
}
}
return nil
}
func (self *Parser) Parse(b []byte, isSeq bool, w io.Writer) (err error) {
switch isSeq {
case true:
err = self.parseSpecificInfo(b)
case false:
// is annexb
if self.isNaluHeader(b) {
_, err = w.Write(b)
} else {
err = self.getAnnexbH264(b, w)
}
}
return
}

91
parser/h264/parser_test.go

@ -0,0 +1,91 @@ @@ -0,0 +1,91 @@
package h264
import (
"bytes"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestH264SeqDemux(t *testing.T) {
at := assert.New(t)
seq := []byte{
0x01, 0x4d, 0x00, 0x1e, 0xff, 0xe1, 0x00, 0x17, 0x67, 0x4d, 0x00,
0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, 0x28, 0x2f,
0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x01, 0x00,
0x04, 0x68, 0xde, 0x31, 0x12,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(seq, true, w)
at.Equal(err, nil)
at.Equal(d.specificInfo, []byte{0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00,
0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28, 0x28, 0x2f,
0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68, 0xde, 0x31, 0x12})
}
func TestH264AnnexbDemux(t *testing.T) {
at := assert.New(t)
nalu := []byte{
0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28,
0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68,
0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x01, 0x65, 0x23,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(nalu, false, w)
at.Equal(err, nil)
at.Equal(w.Len(), 41)
}
func TestH264NalueSizeException(t *testing.T) {
at := assert.New(t)
nalu := []byte{
0x00, 0x00, 0x10,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(nalu, false, w)
at.Equal(err, errors.New("video data not match"))
}
func TestH264Mp4Demux(t *testing.T) {
at := assert.New(t)
nalu := []byte{
0x00, 0x00, 0x00, 0x17, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28,
0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x04, 0x68,
0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x02, 0x65, 0x23,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(nalu, false, w)
at.Equal(err, nil)
at.Equal(w.Len(), 47)
at.Equal(w.Bytes(), []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0, 0x00, 0x00, 0x00, 0x01, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28,
0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00, 0x01, 0x68,
0xde, 0x31, 0x12, 0x00, 0x00, 0x00, 0x01, 0x65, 0x23})
}
func TestH264Mp4DemuxException1(t *testing.T) {
at := assert.New(t)
nalu := []byte{
0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(nalu, false, w)
at.Equal(err, naluBodyLenError)
}
func TestH264Mp4DemuxException2(t *testing.T) {
at := assert.New(t)
nalu := []byte{
0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x17, 0x67, 0x4d, 0x00, 0x1e, 0xab, 0x40, 0x5a, 0x12, 0x6c, 0x09, 0x28, 0x28,
0x28, 0x2f, 0x80, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x61, 0xa8, 0x4a, 0x00, 0x00, 0x00,
}
d := NewParser()
w := bytes.NewBuffer(nil)
err := d.Parse(nalu, false, w)
at.Equal(err, naluBodyLenError)
}

41
parser/mp3/parser.go

@ -0,0 +1,41 @@ @@ -0,0 +1,41 @@
package mp3
import "errors"
type Parser struct {
samplingFrequency int
}
func NewParser() *Parser {
return &Parser{}
}
// sampling_frequency - indicates the sampling frequency, according to the following table.
// '00' 44.1 kHz
// '01' 48 kHz
// '10' 32 kHz
// '11' reserved
var mp3Rates = []int{44100, 48000, 32000}
var (
errMp3DataInvalid = errors.New("mp3data invalid")
errIndexInvalid = errors.New("invalid rate index")
)
func (self *Parser) Parse(src []byte) error {
if len(src) < 3 {
return errMp3DataInvalid
}
index := (src[2] >> 2) & 0x3
if index <= byte(len(mp3Rates)-1) {
self.samplingFrequency = mp3Rates[index]
return nil
}
return errIndexInvalid
}
func (self *Parser) SampleRate() int {
if self.samplingFrequency == 0 {
self.samplingFrequency = 44100
}
return self.samplingFrequency
}

68
parser/parser.go

@ -0,0 +1,68 @@ @@ -0,0 +1,68 @@
package parser
import (
"errors"
"io"
"github.com/gwuhaolin/livego/parser/mp3"
"github.com/gwuhaolin/livego/parser/aac"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/parser/h264"
)
var (
errNoAudio = errors.New("demuxer no audio")
)
type CodecParser struct {
aac *aac.Parser
mp3 *mp3.Parser
h264 *h264.Parser
}
func NewCodecParser() *CodecParser {
return &CodecParser{}
}
func (self *CodecParser) SampleRate() (int, error) {
if self.aac == nil && self.mp3 == nil {
return 0, errNoAudio
}
if self.aac != nil {
return self.aac.SampleRate(), nil
}
return self.mp3.SampleRate(), nil
}
func (self *CodecParser) Parse(p *av.Packet, w io.Writer) (err error) {
switch p.IsVideo {
case true:
f, ok := p.Header.(av.VideoPacketHeader)
if ok {
if f.CodecID() == av.VIDEO_H264 {
if self.h264 == nil {
self.h264 = h264.NewParser()
}
err = self.h264.Parse(p.Data, f.IsSeq(), w)
}
}
case false:
f, ok := p.Header.(av.AudioPacketHeader)
if ok {
switch f.SoundFormat() {
case av.SOUND_AAC:
if self.aac == nil {
self.aac = aac.NewParser()
}
err = self.aac.Parse(p.Data, f.AACPacketType(), w)
case av.SOUND_MP3:
if self.mp3 == nil {
self.mp3 = mp3.NewParser()
}
err = self.mp3.Parse(p.Data)
}
}
}
return
}

50
protocol/amf/amf.go

@ -0,0 +1,50 @@ @@ -0,0 +1,50 @@
package amf
import (
"errors"
"fmt"
"io"
)
func (d *Decoder) DecodeBatch(r io.Reader, ver Version) (ret []interface{}, err error) {
var v interface{}
for {
v, err = d.Decode(r, ver)
if err != nil {
break
}
ret = append(ret, v)
}
return
}
func (d *Decoder) Decode(r io.Reader, ver Version) (interface{}, error) {
switch ver {
case 0:
return d.DecodeAmf0(r)
case 3:
return d.DecodeAmf3(r)
}
return nil, errors.New(fmt.Sprintf("decode amf: unsupported version %d", ver))
}
func (e *Encoder) EncodeBatch(w io.Writer, ver Version, val ...interface{}) (int, error) {
for _, v := range val {
if _, err := e.Encode(w, v, ver); err != nil {
return 0, err
}
}
return 0, nil
}
func (e *Encoder) Encode(w io.Writer, val interface{}, ver Version) (int, error) {
switch ver {
case AMF0:
return e.EncodeAmf0(w, val)
case AMF3:
return e.EncodeAmf3(w, val)
}
return 0, Error("encode amf: unsupported version %d", ver)
}

206
protocol/amf/amf_test.go

@ -0,0 +1,206 @@ @@ -0,0 +1,206 @@
package amf
import (
"bytes"
"errors"
"fmt"
"reflect"
"testing"
"time"
)
func EncodeAndDecode(val interface{}, ver Version) (result interface{}, err error) {
enc := new(Encoder)
dec := new(Decoder)
buf := new(bytes.Buffer)
_, err = enc.Encode(buf, val, ver)
if err != nil {
return nil, errors.New(fmt.Sprintf("error in encode: %s", err))
}
result, err = dec.Decode(buf, ver)
if err != nil {
return nil, errors.New(fmt.Sprintf("error in decode: %s", err))
}
return
}
func Compare(val interface{}, ver Version, name string, t *testing.T) {
result, err := EncodeAndDecode(val, ver)
if err != nil {
t.Errorf("%s: %s", name, err)
}
if !reflect.DeepEqual(val, result) {
val_v := reflect.ValueOf(val)
result_v := reflect.ValueOf(result)
t.Errorf("%s: comparison failed between %+v (%s) and %+v (%s)", name, val, val_v.Type(), result, result_v.Type())
Dump("expected", val)
Dump("got", result)
}
// if val != result {
// t.Errorf("%s: comparison failed between %+v and %+v", name, val, result)
// }
}
func TestAmf0Number(t *testing.T) {
Compare(float64(3.14159), 0, "amf0 number float", t)
Compare(float64(124567890), 0, "amf0 number high", t)
Compare(float64(-34.2), 0, "amf0 number negative", t)
}
func TestAmf0String(t *testing.T) {
Compare("a pup!", 0, "amf0 string simple", t)
Compare("日本語", 0, "amf0 string utf8", t)
}
func TestAmf0Boolean(t *testing.T) {
Compare(true, 0, "amf0 boolean true", t)
Compare(false, 0, "amf0 boolean false", t)
}
func TestAmf0Null(t *testing.T) {
Compare(nil, 0, "amf0 boolean nil", t)
}
func TestAmf0Object(t *testing.T) {
obj := make(Object)
obj["dog"] = "alfie"
obj["coffee"] = true
obj["drugs"] = false
obj["pi"] = 3.14159
res, err := EncodeAndDecode(obj, 0)
if err != nil {
t.Errorf("amf0 object: %s", err)
}
result, ok := res.(Object)
if ok != true {
t.Errorf("amf0 object conversion failed")
}
if result["dog"] != "alfie" {
t.Errorf("amf0 object string: comparison failed")
}
if result["coffee"] != true {
t.Errorf("amf0 object true: comparison failed")
}
if result["drugs"] != false {
t.Errorf("amf0 object false: comparison failed")
}
if result["pi"] != float64(3.14159) {
t.Errorf("amf0 object float: comparison failed")
}
}
func TestAmf0Array(t *testing.T) {
arr := [5]float64{1, 2, 3, 4, 5}
res, err := EncodeAndDecode(arr, 0)
if err != nil {
t.Error("amf0 object: %s", err)
}
result, ok := res.(Array)
if ok != true {
t.Errorf("amf0 array conversion failed")
}
for i := 0; i < len(arr); i++ {
if arr[i] != result[i] {
t.Errorf("amf0 array %d comparison failed: %v / %v", i, arr[i], result[i])
}
}
}
func TestAmf3Integer(t *testing.T) {
Compare(int32(0), 3, "amf3 integer zero", t)
Compare(int32(1245), 3, "amf3 integer low", t)
Compare(int32(123456), 3, "amf3 integer high", t)
}
func TestAmf3Double(t *testing.T) {
Compare(float64(3.14159), 3, "amf3 double float", t)
Compare(float64(1234567890), 3, "amf3 double high", t)
Compare(float64(-12345), 3, "amf3 double negative", t)
}
func TestAmf3String(t *testing.T) {
Compare("a pup!", 0, "amf0 string simple", t)
Compare("日本語", 0, "amf0 string utf8", t)
}
func TestAmf3Boolean(t *testing.T) {
Compare(true, 3, "amf3 boolean true", t)
Compare(false, 3, "amf3 boolean false", t)
}
func TestAmf3Null(t *testing.T) {
Compare(nil, 3, "amf3 boolean nil", t)
}
func TestAmf3Date(t *testing.T) {
t1 := time.Unix(time.Now().Unix(), 0).UTC() // nanoseconds discarded
t2 := time.Date(1983, 9, 4, 12, 4, 8, 0, time.UTC)
Compare(t1, 3, "amf3 date now", t)
Compare(t2, 3, "amf3 date earlier", t)
}
func TestAmf3Array(t *testing.T) {
obj := make(Object)
obj["key"] = "val"
var arr Array
arr = append(arr, "amf")
arr = append(arr, float64(2))
arr = append(arr, -34.95)
arr = append(arr, true)
arr = append(arr, false)
res, err := EncodeAndDecode(arr, 3)
if err != nil {
t.Error("amf3 object: %s", err)
}
result, ok := res.(Array)
if ok != true {
t.Errorf("amf3 array conversion failed: %+v", res)
}
for i := 0; i < len(arr); i++ {
if arr[i] != result[i] {
t.Errorf("amf3 array %d comparison failed: %v / %v", i, arr[i], result[i])
}
}
}
func TestAmf3ByteArray(t *testing.T) {
enc := new(Encoder)
dec := new(Decoder)
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00}
enc.EncodeAmf3ByteArray(buf, expect, true)
result, err := dec.DecodeAmf3ByteArray(buf, true)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(result, expect) != 0 {
t.Errorf("expected: %+v, got %+v", expect, buf)
}
}

105
protocol/amf/const.go

@ -0,0 +1,105 @@ @@ -0,0 +1,105 @@
package amf
import (
"io"
)
const (
AMF0 = 0x00
AMF3 = 0x03
)
const (
AMF0_NUMBER_MARKER = 0x00
AMF0_BOOLEAN_MARKER = 0x01
AMF0_STRING_MARKER = 0x02
AMF0_OBJECT_MARKER = 0x03
AMF0_MOVIECLIP_MARKER = 0x04
AMF0_NULL_MARKER = 0x05
AMF0_UNDEFINED_MARKER = 0x06
AMF0_REFERENCE_MARKER = 0x07
AMF0_ECMA_ARRAY_MARKER = 0x08
AMF0_OBJECT_END_MARKER = 0x09
AMF0_STRICT_ARRAY_MARKER = 0x0a
AMF0_DATE_MARKER = 0x0b
AMF0_LONG_STRING_MARKER = 0x0c
AMF0_UNSUPPORTED_MARKER = 0x0d
AMF0_RECORDSET_MARKER = 0x0e
AMF0_XML_DOCUMENT_MARKER = 0x0f
AMF0_TYPED_OBJECT_MARKER = 0x10
AMF0_ACMPLUS_OBJECT_MARKER = 0x11
)
const (
AMF0_BOOLEAN_FALSE = 0x00
AMF0_BOOLEAN_TRUE = 0x01
AMF0_STRING_MAX = 65535
AMF3_INTEGER_MAX = 536870911
)
const (
AMF3_UNDEFINED_MARKER = 0x00
AMF3_NULL_MARKER = 0x01
AMF3_FALSE_MARKER = 0x02
AMF3_TRUE_MARKER = 0x03
AMF3_INTEGER_MARKER = 0x04
AMF3_DOUBLE_MARKER = 0x05
AMF3_STRING_MARKER = 0x06
AMF3_XMLDOC_MARKER = 0x07
AMF3_DATE_MARKER = 0x08
AMF3_ARRAY_MARKER = 0x09
AMF3_OBJECT_MARKER = 0x0a
AMF3_XMLSTRING_MARKER = 0x0b
AMF3_BYTEARRAY_MARKER = 0x0c
)
type ExternalHandler func(*Decoder, io.Reader) (interface{}, error)
type Decoder struct {
refCache []interface{}
stringRefs []string
objectRefs []interface{}
traitRefs []Trait
externalHandlers map[string]ExternalHandler
}
func NewDecoder() *Decoder {
return &Decoder{
externalHandlers: make(map[string]ExternalHandler),
}
}
func (d *Decoder) RegisterExternalHandler(name string, f ExternalHandler) {
d.externalHandlers[name] = f
}
type Encoder struct {
}
type Version uint8
type Array []interface{}
type Object map[string]interface{}
type TypedObject struct {
Type string
Object Object
}
type Trait struct {
Type string
Externalizable bool
Dynamic bool
Properties []string
}
func NewTrait() *Trait {
return &Trait{}
}
func NewTypedObject() *TypedObject {
return &TypedObject{
Type: "",
Object: make(Object),
}
}

335
protocol/amf/decoder_amf0.go

@ -0,0 +1,335 @@ @@ -0,0 +1,335 @@
package amf
import (
"encoding/binary"
"io"
)
// amf0 polymorphic router
func (d *Decoder) DecodeAmf0(r io.Reader) (interface{}, error) {
marker, err := ReadMarker(r)
if err != nil {
return nil, err
}
switch marker {
case AMF0_NUMBER_MARKER:
return d.DecodeAmf0Number(r, false)
case AMF0_BOOLEAN_MARKER:
return d.DecodeAmf0Boolean(r, false)
case AMF0_STRING_MARKER:
return d.DecodeAmf0String(r, false)
case AMF0_OBJECT_MARKER:
return d.DecodeAmf0Object(r, false)
case AMF0_MOVIECLIP_MARKER:
return nil, Error("decode amf0: unsupported type movieclip")
case AMF0_NULL_MARKER:
return d.DecodeAmf0Null(r, false)
case AMF0_UNDEFINED_MARKER:
return d.DecodeAmf0Undefined(r, false)
case AMF0_REFERENCE_MARKER:
return nil, Error("decode amf0: unsupported type reference")
case AMF0_ECMA_ARRAY_MARKER:
return d.DecodeAmf0EcmaArray(r, false)
case AMF0_STRICT_ARRAY_MARKER:
return d.DecodeAmf0StrictArray(r, false)
case AMF0_DATE_MARKER:
return d.DecodeAmf0Date(r, false)
case AMF0_LONG_STRING_MARKER:
return d.DecodeAmf0LongString(r, false)
case AMF0_UNSUPPORTED_MARKER:
return d.DecodeAmf0Unsupported(r, false)
case AMF0_RECORDSET_MARKER:
return nil, Error("decode amf0: unsupported type recordset")
case AMF0_XML_DOCUMENT_MARKER:
return d.DecodeAmf0XmlDocument(r, false)
case AMF0_TYPED_OBJECT_MARKER:
return d.DecodeAmf0TypedObject(r, false)
case AMF0_ACMPLUS_OBJECT_MARKER:
return d.DecodeAmf3(r)
}
return nil, Error("decode amf0: unsupported type %d", marker)
}
// marker: 1 byte 0x00
// format: 8 byte big endian float64
func (d *Decoder) DecodeAmf0Number(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_NUMBER_MARKER); err != nil {
return
}
err = binary.Read(r, binary.BigEndian, &result)
if err != nil {
return float64(0), Error("amf0 decode: unable to read number: %s", err)
}
return
}
// marker: 1 byte 0x01
// format: 1 byte, 0x00 = false, 0x01 = true
func (d *Decoder) DecodeAmf0Boolean(r io.Reader, decodeMarker bool) (result bool, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_BOOLEAN_MARKER); err != nil {
return
}
var b byte
if b, err = ReadByte(r); err != nil {
return
}
if b == AMF0_BOOLEAN_FALSE {
return false, nil
} else if b == AMF0_BOOLEAN_TRUE {
return true, nil
}
return false, Error("decode amf0: unexpected value %v for boolean", b)
}
// marker: 1 byte 0x02
// format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0String(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_STRING_MARKER); err != nil {
return
}
var length uint16
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return "", Error("decode amf0: unable to decode string length: %s", err)
}
var bytes = make([]byte, length)
if bytes, err = ReadBytes(r, int(length)); err != nil {
return "", Error("decode amf0: unable to decode string value: %s", err)
}
return string(bytes), nil
}
// marker: 1 byte 0x03
// format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0Object(r io.Reader, decodeMarker bool) (Object, error) {
if err := AssertMarker(r, decodeMarker, AMF0_OBJECT_MARKER); err != nil {
return nil, err
}
result := make(Object)
d.refCache = append(d.refCache, result)
for {
key, err := d.DecodeAmf0String(r, false)
if err != nil {
return nil, err
}
if key == "" {
if err = AssertMarker(r, true, AMF0_OBJECT_END_MARKER); err != nil {
return nil, Error("decode amf0: expected object end marker: %s", err)
}
break
}
value, err := d.DecodeAmf0(r)
if err != nil {
return nil, Error("decode amf0: unable to decode object value: %s", err)
}
result[key] = value
}
return result, nil
}
// marker: 1 byte 0x05
// no additional data
func (d *Decoder) DecodeAmf0Null(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_NULL_MARKER)
return
}
// marker: 1 byte 0x06
// no additional data
func (d *Decoder) DecodeAmf0Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_UNDEFINED_MARKER)
return
}
// marker: 1 byte 0x07
// format: 2 byte big endian uint16
/*
func (d *Decoder) DecodeAmf0Reference(r io.Reader, decodeMarker bool) (interface{}, error) {
if err := AssertMarker(r, decodeMarker, AMF0_REFERENCE_MARKER); err != nil {
return nil, err
}
var err error
var ref uint16
err = binary.Read(r, binary.BigEndian, &ref)
if err != nil {
return nil, Error("decode amf0: unable to decode reference id: %s", err)
}
if int(ref) > len(d.refCache) {
return nil, Error("decode amf0: bad reference %d (current length %d)", ref, len(d.refCache))
}
result := d.refCache[ref]
return result, nil
}
*/
// marker: 1 byte 0x08
// format:
// - 4 byte big endian uint32 with length of associative array
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0EcmaArray(r io.Reader, decodeMarker bool) (Object, error) {
if err := AssertMarker(r, decodeMarker, AMF0_ECMA_ARRAY_MARKER); err != nil {
return nil, err
}
var length uint32
err := binary.Read(r, binary.BigEndian, &length)
result, err := d.DecodeAmf0Object(r, false)
if err != nil {
return nil, Error("decode amf0: unable to decode ecma array object: %s", err)
}
return result, nil
}
// marker: 1 byte 0x0a
// format:
// - 4 byte big endian uint32 to determine length of associative array
// - n (length) encoded values
func (d *Decoder) DecodeAmf0StrictArray(r io.Reader, decodeMarker bool) (result Array, err error) {
if err := AssertMarker(r, decodeMarker, AMF0_STRICT_ARRAY_MARKER); err != nil {
return nil, err
}
var length uint32
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, Error("decode amf0: unable to decode strict array length: %s", err)
}
d.refCache = append(d.refCache, result)
for i := uint32(0); i < length; i++ {
tmp, err := d.DecodeAmf0(r)
if err != nil {
return nil, Error("decode amf0: unable to decode strict array object: %s", err)
}
result = append(result, tmp)
}
return result, nil
}
// marker: 1 byte 0x0b
// format:
// - normal number format:
// - 8 byte big endian float64
// - 2 byte unused
func (d *Decoder) DecodeAmf0Date(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_DATE_MARKER); err != nil {
return
}
if result, err = d.DecodeAmf0Number(r, false); err != nil {
return float64(0), Error("decode amf0: unable to decode float in date: %s", err)
}
if _, err = ReadBytes(r, 2); err != nil {
return float64(0), Error("decode amf0: unable to read 2 trail bytes in date: %s", err)
}
return
}
// marker: 1 byte 0x0c
// format:
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0LongString(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_LONG_STRING_MARKER); err != nil {
return
}
var length uint32
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return "", Error("decode amf0: unable to decode long string length: %s", err)
}
var bytes = make([]byte, length)
if bytes, err = ReadBytes(r, int(length)); err != nil {
return "", Error("decode amf0: unable to decode long string value: %s", err)
}
return string(bytes), nil
}
// marker: 1 byte 0x0d
// no additional data
func (d *Decoder) DecodeAmf0Unsupported(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_UNSUPPORTED_MARKER)
return
}
// marker: 1 byte 0x0f
// format:
// - normal long string format
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0XmlDocument(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_XML_DOCUMENT_MARKER); err != nil {
return
}
return d.DecodeAmf0LongString(r, false)
}
// marker: 1 byte 0x10
// format:
// - normal string format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0TypedObject(r io.Reader, decodeMarker bool) (TypedObject, error) {
result := *new(TypedObject)
err := AssertMarker(r, decodeMarker, AMF0_TYPED_OBJECT_MARKER)
if err != nil {
return result, err
}
d.refCache = append(d.refCache, result)
result.Type, err = d.DecodeAmf0String(r, false)
if err != nil {
return result, Error("decode amf0: typed object unable to determine type: %s", err)
}
result.Object, err = d.DecodeAmf0Object(r, false)
if err != nil {
return result, Error("decode amf0: typed object unable to determine object: %s", err)
}
return result, nil
}

588
protocol/amf/decoder_amf0_test.go

@ -0,0 +1,588 @@ @@ -0,0 +1,588 @@
package amf
import (
"bytes"
"testing"
)
func TestDecodeAmf0Number(t *testing.T) {
buf := bytes.NewReader([]byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33})
expect := float64(1.2)
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test number interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Number(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test number interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Number(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0BooleanTrue(t *testing.T) {
buf := bytes.NewReader([]byte{0x01, 0x01})
expect := true
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Boolean(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Boolean(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0BooleanFalse(t *testing.T) {
buf := bytes.NewReader([]byte{0x01, 0x00})
expect := false
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Boolean(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Boolean(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0String(t *testing.T) {
buf := bytes.NewReader([]byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0String(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0String(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0Object(t *testing.T) {
buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test object interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Object(buf, true)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test object interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Object(buf, false)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
}
func TestDecodeAmf0Null(t *testing.T) {
buf := bytes.NewReader([]byte{0x05})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test null interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Null(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf0Undefined(t *testing.T) {
buf := bytes.NewReader([]byte{0x06})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test undefined interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Undefined(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
/*
func TestDecodeReference(t *testing.T) {
buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x00, 0x00, 0x00, 0x00, 0x09})
dec := &Decoder{}
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
_, ok2 := obj["foo"].(Object)
if ok2 != true {
t.Errorf("expected foo value to cast to object")
}
}
*/
func TestDecodeAmf0EcmaArray(t *testing.T) {
buf := bytes.NewReader([]byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test ecma array interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0EcmaArray(buf, true)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test ecma array interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0EcmaArray(buf, false)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to ecma array")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
}
func TestDecodeAmf0StrictArray(t *testing.T) {
buf := bytes.NewReader([]byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
arr, ok := got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
// Test strict array interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0StrictArray(buf, true)
if err != nil {
t.Errorf("%s", err)
}
arr, ok = got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
// Test strict array interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0StrictArray(buf, false)
if err != nil {
t.Errorf("%s", err)
}
arr, ok = got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
}
func TestDecodeAmf0Date(t *testing.T) {
buf := bytes.NewReader([]byte{0x0b, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
expect := float64(5)
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test date interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Date(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test date interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Date(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0LongString(t *testing.T) {
buf := bytes.NewReader([]byte{0x0c, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0LongString(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0LongString(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0Unsupported(t *testing.T) {
buf := bytes.NewReader([]byte{0x0d})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test unsupported interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Unsupported(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf0XmlDocument(t *testing.T) {
buf := bytes.NewReader([]byte{0x0f, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0XmlDocument(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0XmlDocument(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0TypedObject(t *testing.T) {
buf := bytes.NewReader([]byte{
0x10, 0x00, 0x0F, 'o', 'r', 'g',
'.', 'a', 'm', 'f', '.', 'A',
'S', 'C', 'l', 'a', 's', 's',
0x00, 0x03, 'b', 'a', 'z', 0x05,
0x00, 0x03, 'f', 'o', 'o', 0x02,
0x00, 0x03, 'b', 'a', 'r', 0x00,
0x00, 0x09,
})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok := got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
// Test typed object interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0TypedObject(buf, true)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok = got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
// Test typed object interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0TypedObject(buf, false)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok = got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
}

496
protocol/amf/decoder_amf3.go

@ -0,0 +1,496 @@ @@ -0,0 +1,496 @@
package amf
import (
"encoding/binary"
"io"
"time"
)
// amf3 polymorphic router
func (d *Decoder) DecodeAmf3(r io.Reader) (interface{}, error) {
marker, err := ReadMarker(r)
if err != nil {
return nil, err
}
switch marker {
case AMF3_UNDEFINED_MARKER:
return d.DecodeAmf3Undefined(r, false)
case AMF3_NULL_MARKER:
return d.DecodeAmf3Null(r, false)
case AMF3_FALSE_MARKER:
return d.DecodeAmf3False(r, false)
case AMF3_TRUE_MARKER:
return d.DecodeAmf3True(r, false)
case AMF3_INTEGER_MARKER:
return d.DecodeAmf3Integer(r, false)
case AMF3_DOUBLE_MARKER:
return d.DecodeAmf3Double(r, false)
case AMF3_STRING_MARKER:
return d.DecodeAmf3String(r, false)
case AMF3_XMLDOC_MARKER:
return d.DecodeAmf3Xml(r, false)
case AMF3_DATE_MARKER:
return d.DecodeAmf3Date(r, false)
case AMF3_ARRAY_MARKER:
return d.DecodeAmf3Array(r, false)
case AMF3_OBJECT_MARKER:
return d.DecodeAmf3Object(r, false)
case AMF3_XMLSTRING_MARKER:
return d.DecodeAmf3Xml(r, false)
case AMF3_BYTEARRAY_MARKER:
return d.DecodeAmf3ByteArray(r, false)
}
return nil, Error("decode amf3: unsupported type %d", marker)
}
// marker: 1 byte 0x00
// no additional data
func (d *Decoder) DecodeAmf3Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF3_UNDEFINED_MARKER)
return
}
// marker: 1 byte 0x01
// no additional data
func (d *Decoder) DecodeAmf3Null(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF3_NULL_MARKER)
return
}
// marker: 1 byte 0x02
// no additional data
func (d *Decoder) DecodeAmf3False(r io.Reader, decodeMarker bool) (result bool, err error) {
err = AssertMarker(r, decodeMarker, AMF3_FALSE_MARKER)
result = false
return
}
// marker: 1 byte 0x03
// no additional data
func (d *Decoder) DecodeAmf3True(r io.Reader, decodeMarker bool) (result bool, err error) {
err = AssertMarker(r, decodeMarker, AMF3_TRUE_MARKER)
result = true
return
}
// marker: 1 byte 0x04
func (d *Decoder) DecodeAmf3Integer(r io.Reader, decodeMarker bool) (result int32, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_INTEGER_MARKER); err != nil {
return
}
var u29 uint32
u29, err = d.decodeU29(r)
if err != nil {
return
}
result = int32(u29)
if result > 0xfffffff {
result = int32(u29 - 0x20000000)
}
return
}
// marker: 1 byte 0x05
func (d *Decoder) DecodeAmf3Double(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_DOUBLE_MARKER); err != nil {
return
}
err = binary.Read(r, binary.BigEndian, &result)
if err != nil {
return float64(0), Error("amf3 decode: unable to read double: %s", err)
}
return
}
// marker: 1 byte 0x06
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (d *Decoder) DecodeAmf3String(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_STRING_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return "", Error("amf3 decode: unable to decode string reference and length: %s", err)
}
if isRef {
result = d.stringRefs[refVal]
return
}
buf := make([]byte, refVal)
_, err = r.Read(buf)
if err != nil {
return "", Error("amf3 decode: unable to read string: %s", err)
}
result = string(buf)
if result != "" {
d.stringRefs = append(d.stringRefs, result)
}
return
}
// marker: 1 byte 0x08
// format:
// - u29 reference int, if reference, no more data
// - timestamp double
func (d *Decoder) DecodeAmf3Date(r io.Reader, decodeMarker bool) (result time.Time, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_DATE_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode date reference and length: %s", err)
}
if isRef {
res, ok := d.objectRefs[refVal].(time.Time)
if ok != true {
return result, Error("amf3 decode: unable to extract time from date object references")
}
return res, err
}
var u64 float64
err = binary.Read(r, binary.BigEndian, &u64)
if err != nil {
return result, Error("amf3 decode: unable to read double: %s", err)
}
result = time.Unix(int64(u64/1000), 0).UTC()
d.objectRefs = append(d.objectRefs, result)
return
}
// marker: 1 byte 0x09
// format:
// - u29 reference int. if reference, no more data.
// - string representing associative array if present
// - n values (length of u29)
func (d *Decoder) DecodeAmf3Array(r io.Reader, decodeMarker bool) (result Array, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_ARRAY_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode array reference and length: %s", err)
}
if isRef {
objRefId := refVal >> 1
res, ok := d.objectRefs[objRefId].(Array)
if ok != true {
return result, Error("amf3 decode: unable to extract array from object references")
}
return res, err
}
var key string
key, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read key for array: %s", err)
}
if key != "" {
return result, Error("amf3 decode: array key is not empty, can't handle associative array")
}
for i := uint32(0); i < refVal; i++ {
tmp, err := d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: array element could not be decoded: %s", err)
}
result = append(result, tmp)
}
d.objectRefs = append(d.objectRefs, result)
return
}
// marker: 1 byte 0x09
// format: oh dear god
func (d *Decoder) DecodeAmf3Object(r io.Reader, decodeMarker bool) (result interface{}, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_OBJECT_MARKER); err != nil {
return nil, err
}
// decode the initial u29
isRef, refVal, err := d.decodeReferenceInt(r)
if err != nil {
return nil, Error("amf3 decode: unable to decode object reference and length: %s", err)
}
// if this is a object reference only, grab it and return it
if isRef {
objRefId := refVal >> 1
return d.objectRefs[objRefId], nil
}
// each type has traits that are cached, if the peer sent a reference
// then we'll need to look it up and use it.
var trait Trait
traitIsRef := (refVal & 0x01) == 0
if traitIsRef {
traitRef := refVal >> 1
trait = d.traitRefs[traitRef]
} else {
// build a new trait from what's left of the given u29
trait = *NewTrait()
trait.Externalizable = (refVal & 0x02) != 0
trait.Dynamic = (refVal & 0x04) != 0
var cls string
cls, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read trait type for object: %s", err)
}
trait.Type = cls
// traits have property keys, encoded as amf3 strings
propLength := refVal >> 3
for i := uint32(0); i < propLength; i++ {
tmp, err := d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read trait property for object: %s", err)
}
trait.Properties = append(trait.Properties, tmp)
}
d.traitRefs = append(d.traitRefs, trait)
}
d.objectRefs = append(d.objectRefs, result)
// objects can be externalizable, meaning that the system has no concrete understanding of
// their properties or how they are encoded. in that case, we need to find and delegate behavior
// to the right object.
if trait.Externalizable {
switch trait.Type {
case "DSA": // AsyncMessageExt
result, err = d.decodeAsyncMessageExt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dsa: %s", err)
}
case "DSK": // AcknowledgeMessageExt
result, err = d.decodeAcknowledgeMessageExt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dsk: %s", err)
}
case "flex.messaging.io.ArrayCollection":
result, err = d.decodeArrayCollection(r)
if err != nil {
return result, Error("amf3 decode: unable to decode ac: %s", err)
}
// store an extra reference to array collection container
d.objectRefs = append(d.objectRefs, result)
default:
fn, ok := d.externalHandlers[trait.Type]
if ok {
result, err = fn(d, r)
if err != nil {
return result, Error("amf3 decode: unable to call external decoder for type %s: %s", trait.Type, err)
}
} else {
return result, Error("amf3 decode: unable to decode external type %s, no handler", trait.Type)
}
}
return result, err
}
var key string
var val interface{}
var obj Object
obj = make(Object)
// non-externalizable objects have property keys in traits, iterate through them
// and add the read values to the object
for _, key = range trait.Properties {
val, err = d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: unable to decode object property: %s", err)
}
obj[key] = val
}
// if an object is dynamic, it can have extra key/value data at the end. in this case,
// read keys until we get an empty one.
if trait.Dynamic {
for {
key, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to decode dynamic key: %s", err)
}
if key == "" {
break
}
val, err = d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dynamic value: %s", err)
}
obj[key] = val
}
}
result = obj
return
}
// marker: 1 byte 0x07 or 0x0b
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (d *Decoder) DecodeAmf3Xml(r io.Reader, decodeMarker bool) (result string, err error) {
if decodeMarker {
var marker byte
marker, err = ReadMarker(r)
if err != nil {
return "", err
}
if (marker != AMF3_XMLDOC_MARKER) && (marker != AMF3_XMLSTRING_MARKER) {
return "", Error("decode assert marker failed: expected %v or %v, got %v", AMF3_XMLDOC_MARKER, AMF3_XMLSTRING_MARKER, marker)
}
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return "", Error("amf3 decode: unable to decode xml reference and length: %s", err)
}
if isRef {
var ok bool
buf := d.objectRefs[refVal]
result, ok = buf.(string)
if ok != true {
return "", Error("amf3 decode: cannot coerce object reference into xml string")
}
return
}
buf := make([]byte, refVal)
_, err = r.Read(buf)
if err != nil {
return "", Error("amf3 decode: unable to read xml string: %s", err)
}
result = string(buf)
if result != "" {
d.objectRefs = append(d.objectRefs, result)
}
return
}
// marker: 1 byte 0x0c
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read.
func (d *Decoder) DecodeAmf3ByteArray(r io.Reader, decodeMarker bool) (result []byte, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_BYTEARRAY_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode byte array reference and length: %s", err)
}
if isRef {
var ok bool
result, ok = d.objectRefs[refVal].([]byte)
if ok != true {
return result, Error("amf3 decode: unable to convert object ref to bytes")
}
return
}
result = make([]byte, refVal)
_, err = r.Read(result)
if err != nil {
return result, Error("amf3 decode: unable to read bytearray: %s", err)
}
d.objectRefs = append(d.objectRefs, result)
return
}
func (d *Decoder) decodeU29(r io.Reader) (result uint32, err error) {
var b byte
for i := 0; i < 3; i++ {
b, err = ReadByte(r)
if err != nil {
return
}
result = (result << 7) + uint32(b&0x7F)
if (b & 0x80) == 0 {
return
}
}
b, err = ReadByte(r)
if err != nil {
return
}
result = ((result << 8) + uint32(b))
return
}
func (d *Decoder) decodeReferenceInt(r io.Reader) (isRef bool, refVal uint32, err error) {
u29, err := d.decodeU29(r)
if err != nil {
return false, 0, Error("amf3 decode: unable to decode reference int: %s", err)
}
isRef = u29&0x01 == 0
refVal = u29 >> 1
return
}

127
protocol/amf/decoder_amf3_external.go

@ -0,0 +1,127 @@ @@ -0,0 +1,127 @@
package amf
import (
"fmt"
"io"
"math"
)
// Abstract external boilerplate
func (d *Decoder) decodeAbstractMessage(r io.Reader) (result Object, err error) {
result = make(Object)
if err = d.decodeExternal(r, &result,
[]string{"body", "clientId", "destination", "headers", "messageId", "timeStamp", "timeToLive"},
[]string{"clientIdBytes", "messageIdBytes"}); err != nil {
return result, Error("unable to decode abstract external: %s", err)
}
return
}
// DSA
func (d *Decoder) decodeAsyncMessageExt(r io.Reader) (result Object, err error) {
return d.decodeAsyncMessage(r)
}
func (d *Decoder) decodeAsyncMessage(r io.Reader) (result Object, err error) {
result, err = d.decodeAbstractMessage(r)
if err != nil {
return result, Error("unable to decode abstract for async: %s", err)
}
if err = d.decodeExternal(r, &result, []string{"correlationId", "correlationIdBytes"}); err != nil {
return result, Error("unable to decode async external: %s", err)
}
return
}
// DSK
func (d *Decoder) decodeAcknowledgeMessageExt(r io.Reader) (result Object, err error) {
return d.decodeAcknowledgeMessage(r)
}
func (d *Decoder) decodeAcknowledgeMessage(r io.Reader) (result Object, err error) {
result, err = d.decodeAsyncMessage(r)
if err != nil {
return result, Error("unable to decode async for ack: %s", err)
}
if err = d.decodeExternal(r, &result); err != nil {
return result, Error("unable to decode ack external: %s", err)
}
return
}
// flex.messaging.io.ArrayCollection
func (d *Decoder) decodeArrayCollection(r io.Reader) (interface{}, error) {
result, err := d.DecodeAmf3(r)
if err != nil {
return result, Error("cannot decode child of array collection: %s", err)
}
return result, nil
}
func (d *Decoder) decodeExternal(r io.Reader, obj *Object, fieldSets ...[]string) (err error) {
var flagSet []uint8
var reservedPosition uint8
var fieldNames []string
flagSet, err = readFlags(r)
if err != nil {
return Error("unable to read flags: %s", err)
}
for i, flags := range flagSet {
if i < len(fieldSets) {
fieldNames = fieldSets[i]
} else {
fieldNames = []string{}
}
reservedPosition = uint8(len(fieldNames))
for p, field := range fieldNames {
flagBit := uint8(math.Exp2(float64(p)))
if (flags & flagBit) != 0 {
tmp, err := d.DecodeAmf3(r)
if err != nil {
return Error("unable to decode external field %s %d %d (%#v): %s", field, i, p, flagSet, err)
}
(*obj)[field] = tmp
}
}
if (flags >> reservedPosition) != 0 {
for j := reservedPosition; j < 6; j++ {
if ((flags >> j) & 0x01) != 0 {
field := fmt.Sprintf("extra_%d_%d", i, j)
tmp, err := d.DecodeAmf3(r)
if err != nil {
return Error("unable to decode post-external field %d %d (%#v): %s", i, j, flagSet, err)
}
(*obj)[field] = tmp
}
}
}
}
return
}
func readFlags(r io.Reader) (result []uint8, err error) {
for {
flag, err := ReadByte(r)
if err != nil {
return result, Error("unable to read flags: %s", err)
}
result = append(result, flag)
if (flag & 0x80) == 0 {
break
}
}
return
}

220
protocol/amf/decoder_amf3_test.go

@ -0,0 +1,220 @@ @@ -0,0 +1,220 @@
package amf
import (
"bytes"
"testing"
)
type u29TestCase struct {
value uint32
expect []byte
}
var u29TestCases = []u29TestCase{
{1, []byte{0x01}},
{2, []byte{0x02}},
{127, []byte{0x7F}},
{128, []byte{0x81, 0x00}},
{255, []byte{0x81, 0x7F}},
{256, []byte{0x82, 0x00}},
{0x3FFF, []byte{0xFF, 0x7F}},
{0x4000, []byte{0x81, 0x80, 0x00}},
{0x7FFF, []byte{0x81, 0xFF, 0x7F}},
{0x8000, []byte{0x82, 0x80, 0x00}},
{0x1FFFFF, []byte{0xFF, 0xFF, 0x7F}},
{0x200000, []byte{0x80, 0xC0, 0x80, 0x00}},
{0x3FFFFF, []byte{0x80, 0xFF, 0xFF, 0xFF}},
{0x400000, []byte{0x81, 0x80, 0x80, 0x00}},
{0x0FFFFFFF, []byte{0xBF, 0xFF, 0xFF, 0xFF}},
}
func TestDecodeAmf3Undefined(t *testing.T) {
buf := bytes.NewReader([]byte{0x00})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf3Null(t *testing.T) {
buf := bytes.NewReader([]byte{0x01})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf3False(t *testing.T) {
buf := bytes.NewReader([]byte{0x02})
expect := false
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3True(t *testing.T) {
buf := bytes.NewReader([]byte{0x03})
expect := true
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeU29(t *testing.T) {
dec := new(Decoder)
for _, tc := range u29TestCases {
buf := bytes.NewBuffer(tc.expect)
n, err := dec.decodeU29(buf)
if err != nil {
t.Errorf("DecodeAmf3Integer error: %s", err)
}
if n != tc.value {
t.Errorf("DecodeAmf3Integer expect n %x got %x", tc.value, n)
}
}
}
func TestDecodeAmf3Integer(t *testing.T) {
dec := new(Decoder)
buf := bytes.NewReader([]byte{0x04, 0xFF, 0xFF, 0x7F})
expect := int32(2097151)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
buf.Seek(0, 0)
got, err = dec.DecodeAmf3Integer(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
buf.Seek(1, 0)
got, err = dec.DecodeAmf3Integer(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3Double(t *testing.T) {
buf := bytes.NewReader([]byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33})
expect := float64(1.2)
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3String(t *testing.T) {
buf := bytes.NewReader([]byte{0x06, 0x07, 'f', 'o', 'o'})
expect := "foo"
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3Array(t *testing.T) {
buf := bytes.NewReader([]byte{0x09, 0x13, 0x01,
0x06, 0x03, '1',
0x06, 0x03, '2',
0x06, 0x03, '3',
0x06, 0x03, '4',
0x06, 0x03, '5',
0x06, 0x03, '6',
0x06, 0x03, '7',
0x06, 0x03, '8',
0x06, 0x03, '9',
})
dec := new(Decoder)
expect := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}
got, err := dec.DecodeAmf3Array(buf, true)
if err != nil {
t.Errorf("err: %s", err)
}
for i, v := range expect {
if got[i] != v {
t.Error("expected array element %d to be %v, got %v", i, v, got[i])
}
}
}
func TestDecodeAmf3Object(t *testing.T) {
buf := bytes.NewReader([]byte{
0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a',
'm', 'f', '.', 'A', 'S', 'C', 'l', 'a',
's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f',
'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r',
})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("err: %s", err)
}
to, ok := got.(Object)
if ok != true {
t.Error("unable to cast object as typed object")
}
if to["foo"] != "bar" {
t.Error("expected foo to be bar, got: %+v", to["foo"])
}
if to["baz"] != nil {
t.Error("expected baz to be nil, got: %+v", to["baz"])
}
}

308
protocol/amf/encoder_amf0.go

@ -0,0 +1,308 @@ @@ -0,0 +1,308 @@
package amf
import (
"encoding/binary"
"io"
"reflect"
)
// amf0 polymorphic router
func (e *Encoder) EncodeAmf0(w io.Writer, val interface{}) (int, error) {
if val == nil {
return e.EncodeAmf0Null(w, true)
}
v := reflect.ValueOf(val)
if !v.IsValid() {
return e.EncodeAmf0Null(w, true)
}
switch v.Kind() {
case reflect.String:
str := v.String()
if len(str) <= AMF0_STRING_MAX {
return e.EncodeAmf0String(w, str, true)
} else {
return e.EncodeAmf0LongString(w, str, true)
}
case reflect.Bool:
return e.EncodeAmf0Boolean(w, v.Bool(), true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return e.EncodeAmf0Number(w, float64(v.Int()), true)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return e.EncodeAmf0Number(w, float64(v.Uint()), true)
case reflect.Float32, reflect.Float64:
return e.EncodeAmf0Number(w, float64(v.Float()), true)
case reflect.Array, reflect.Slice:
length := v.Len()
arr := make(Array, length)
for i := 0; i < length; i++ {
arr[i] = v.Index(int(i)).Interface()
}
return e.EncodeAmf0StrictArray(w, arr, true)
case reflect.Map:
obj, ok := val.(Object)
if ok != true {
return 0, Error("encode amf0: unable to create object from map")
}
return e.EncodeAmf0Object(w, obj, true)
}
if _, ok := val.(TypedObject); ok {
return 0, Error("encode amf0: unsupported type typed object")
}
return 0, Error("encode amf0: unsupported type %s", v.Type())
}
// marker: 1 byte 0x00
// format: 8 byte big endian float64
func (e *Encoder) EncodeAmf0Number(w io.Writer, val float64, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_NUMBER_MARKER); err != nil {
return
}
n += 1
}
err = binary.Write(w, binary.BigEndian, &val)
if err != nil {
return
}
n += 8
return
}
// marker: 1 byte 0x01
// format: 1 byte, 0x00 = false, 0x01 = true
func (e *Encoder) EncodeAmf0Boolean(w io.Writer, val bool, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_BOOLEAN_MARKER); err != nil {
return
}
n += 1
}
var m int
buf := make([]byte, 1)
if val {
buf[0] = AMF0_BOOLEAN_TRUE
} else {
buf[0] = AMF0_BOOLEAN_FALSE
}
m, err = w.Write(buf)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x02
// format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
func (e *Encoder) EncodeAmf0String(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint16(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode string length: %s", err)
}
n += 2
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf0: unable to encode string value: %s", err)
}
n += m
return
}
// marker: 1 byte 0x03
// format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (e *Encoder) EncodeAmf0Object(w io.Writer, val Object, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_OBJECT_MARKER); err != nil {
return
}
n += 1
}
var m int
for k, v := range val {
m, err = e.EncodeAmf0String(w, k, false)
if err != nil {
return n, Error("encode amf0: unable to encode object key: %s", err)
}
n += m
m, err = e.EncodeAmf0(w, v)
if err != nil {
return n, Error("encode amf0: unable to encode object value: %s", err)
}
n += m
}
m, err = e.EncodeAmf0String(w, "", false)
if err != nil {
return n, Error("encode amf0: unable to encode object empty string: %s", err)
}
n += m
err = WriteMarker(w, AMF0_OBJECT_END_MARKER)
if err != nil {
return n, Error("encode amf0: unable to object end marker: %s", err)
}
n += 1
return
}
// marker: 1 byte 0x05
// no additional data
func (e *Encoder) EncodeAmf0Null(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_NULL_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x06
// no additional data
func (e *Encoder) EncodeAmf0Undefined(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_UNDEFINED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x08
// format:
// - 4 byte big endian uint32 with length of associative array
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (e *Encoder) EncodeAmf0EcmaArray(w io.Writer, val Object, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_ECMA_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode ecma array length: %s", err)
}
n += 4
m, err = e.EncodeAmf0Object(w, val, false)
if err != nil {
return n, Error("encode amf0: unable to encode ecma array object: %s", err)
}
n += m
return
}
// marker: 1 byte 0x0a
// format:
// - 4 byte big endian uint32 to determine length of associative array
// - n (length) encoded values
func (e *Encoder) EncodeAmf0StrictArray(w io.Writer, val Array, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_STRICT_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode strict array length: %s", err)
}
n += 4
for _, v := range val {
m, err = e.EncodeAmf0(w, v)
if err != nil {
return n, Error("encode amf0: unable to encode strict array element: %s", err)
}
n += m
}
return
}
// marker: 1 byte 0x0c
// format:
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (e *Encoder) EncodeAmf0LongString(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_LONG_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode long string length: %s", err)
}
n += 4
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf0: unable to encode long string value: %s", err)
}
n += m
return
}
// marker: 1 byte 0x0d
// no additional data
func (e *Encoder) EncodeAmf0Unsupported(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_UNSUPPORTED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x11
func (e *Encoder) EncodeAmf0Amf3Marker(w io.Writer) error {
return WriteMarker(w, AMF0_ACMPLUS_OBJECT_MARKER)
}

212
protocol/amf/encoder_amf0_test.go

@ -0,0 +1,212 @@ @@ -0,0 +1,212 @@
package amf
import (
"bytes"
"encoding/binary"
"testing"
)
func TestEncodeAmf0Number(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, float64(1.2))
if err != nil {
t.Errorf("%s", err)
}
if n != 9 {
t.Errorf("expected to write 9 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0BooleanTrue(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x01}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if n != 2 {
t.Errorf("expected to write 2 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0BooleanFalse(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x00}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if n != 2 {
t.Errorf("expected to write 2 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0String(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, "foo")
if err != nil {
t.Errorf("%s", err)
}
if n != 6 {
t.Errorf("expected to write 6 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0Object(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}
enc := new(Encoder)
obj := make(Object)
obj["foo"] = "bar"
n, err := enc.EncodeAmf0(buf, obj)
if err != nil {
t.Errorf("%s", err)
}
if n != 15 {
t.Errorf("expected to write 15 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0EcmaArray(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}
enc := new(Encoder)
obj := make(Object)
obj["foo"] = "bar"
_, err := enc.EncodeAmf0EcmaArray(buf, obj, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0StrictArray(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05}
enc := new(Encoder)
arr := make(Array, 3)
arr[0] = float64(5)
arr[1] = "foo"
arr[2] = nil
_, err := enc.EncodeAmf0StrictArray(buf, arr, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0Null(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x05}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, nil)
if err != nil {
t.Errorf("%s", err)
}
if n != 1 {
t.Errorf("expected to write 1 byte, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0LongString(t *testing.T) {
buf := new(bytes.Buffer)
testBytes := []byte("12345678")
tbuf := new(bytes.Buffer)
for i := 0; i < 65536; i++ {
tbuf.Write(testBytes)
}
enc := new(Encoder)
_, err := enc.EncodeAmf0(buf, string(tbuf.Bytes()))
if err != nil {
t.Errorf("%s", err)
}
mbuf := make([]byte, 1)
_, err = buf.Read(mbuf)
if err != nil {
t.Errorf("error reading header")
}
if mbuf[0] != 0x0c {
t.Errorf("marker mismatch")
}
var length uint32
err = binary.Read(buf, binary.BigEndian, &length)
if err != nil {
t.Errorf("error reading buffer")
}
if length != (65536 * 8) {
t.Errorf("expected length to be %d, got %d", (65536 * 8), length)
}
tmpBuf := make([]byte, 8)
counter := 0
for buf.Len() > 0 {
n, err := buf.Read(tmpBuf)
if err != nil {
t.Fatalf("test long string result check, read data(%d) error: %s, n: %d", counter, err, n)
}
if n != 8 {
t.Fatalf("test long string result check, read data(%d) n: %d", counter, n)
}
if !bytes.Equal(testBytes, tmpBuf) {
t.Fatalf("test long string result check, read data % x", tmpBuf)
}
counter++
}
}

431
protocol/amf/encoder_amf3.go

@ -0,0 +1,431 @@ @@ -0,0 +1,431 @@
package amf
import (
"encoding/binary"
"io"
"reflect"
"sort"
"time"
)
// amf3 polymorphic router
func (e *Encoder) EncodeAmf3(w io.Writer, val interface{}) (int, error) {
if val == nil {
return e.EncodeAmf3Null(w, true)
}
v := reflect.ValueOf(val)
if !v.IsValid() {
return e.EncodeAmf3Null(w, true)
}
switch v.Kind() {
case reflect.String:
return e.EncodeAmf3String(w, v.String(), true)
case reflect.Bool:
if v.Bool() {
return e.EncodeAmf3True(w, true)
} else {
return e.EncodeAmf3False(w, true)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
n := v.Int()
if n >= 0 && n <= AMF3_INTEGER_MAX {
return e.EncodeAmf3Integer(w, uint32(n), true)
} else {
return e.EncodeAmf3Double(w, float64(n), true)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
n := v.Uint()
if n <= AMF3_INTEGER_MAX {
return e.EncodeAmf3Integer(w, uint32(n), true)
} else {
return e.EncodeAmf3Double(w, float64(n), true)
}
case reflect.Int64:
return e.EncodeAmf3Double(w, float64(v.Int()), true)
case reflect.Uint64:
return e.EncodeAmf3Double(w, float64(v.Uint()), true)
case reflect.Float32, reflect.Float64:
return e.EncodeAmf3Double(w, float64(v.Float()), true)
case reflect.Array, reflect.Slice:
length := v.Len()
arr := make(Array, length)
for i := 0; i < length; i++ {
arr[i] = v.Index(int(i)).Interface()
}
return e.EncodeAmf3Array(w, arr, true)
case reflect.Map:
obj, ok := val.(Object)
if ok != true {
return 0, Error("encode amf3: unable to create object from map")
}
to := *new(TypedObject)
to.Object = obj
return e.EncodeAmf3Object(w, to, true)
}
if tm, ok := val.(time.Time); ok {
return e.EncodeAmf3Date(w, tm, true)
}
if to, ok := val.(TypedObject); ok {
return e.EncodeAmf3Object(w, to, true)
}
return 0, Error("encode amf3: unsupported type %s", v.Type())
}
// marker: 1 byte 0x00
// no additional data
func (e *Encoder) EncodeAmf3Undefined(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_UNDEFINED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x01
// no additional data
func (e *Encoder) EncodeAmf3Null(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_NULL_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x02
// no additional data
func (e *Encoder) EncodeAmf3False(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_FALSE_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x03
// no additional data
func (e *Encoder) EncodeAmf3True(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_TRUE_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x04
func (e *Encoder) EncodeAmf3Integer(w io.Writer, val uint32, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_INTEGER_MARKER); err != nil {
return
}
n += 1
}
var m int
m, err = e.encodeAmf3Uint29(w, val)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x05
func (e *Encoder) EncodeAmf3Double(w io.Writer, val float64, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_DOUBLE_MARKER); err != nil {
return
}
n += 1
}
err = binary.Write(w, binary.BigEndian, &val)
if err != nil {
return
}
n += 8
return
}
// marker: 1 byte 0x06
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (e *Encoder) EncodeAmf3String(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
m, err = e.encodeAmf3Utf8(w, val)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x08
// format:
// - u29 reference int, if reference, no more data
// - timestamp double
func (e *Encoder) EncodeAmf3Date(w io.Writer, val time.Time, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_DATE_MARKER); err != nil {
return
}
n += 1
}
if err = WriteMarker(w, 0x01); err != nil {
return n, Error("amf3 encode: cannot encode u29 for array: %s", err)
}
n += 1
u64 := float64(val.Unix()) * 1000.0
err = binary.Write(w, binary.BigEndian, &u64)
if err != nil {
return n, Error("amf3 encode: unable to write date double: %s", err)
}
n += 8
return
}
// marker: 1 byte 0x09
// format:
// - u29 reference int. if reference, no more data.
// - string representing associative array if present
// - n values (length of u29)
func (e *Encoder) EncodeAmf3Array(w io.Writer, val Array, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
u29 := uint32(length<<1) | 0x01
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for array: %s", err)
}
n += m
m, err = e.encodeAmf3Utf8(w, "")
if err != nil {
return n, Error("amf3 encode: cannot encode empty string for array: %s", err)
}
n += m
for _, v := range val {
m, err := e.EncodeAmf3(w, v)
if err != nil {
return n, Error("amf3 encode: cannot encode array element: %s", err)
}
n += m
}
return
}
// marker: 1 byte 0x0a
// format: ugh
func (e *Encoder) EncodeAmf3Object(w io.Writer, val TypedObject, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_OBJECT_MARKER); err != nil {
return
}
n += 1
}
m := 0
trait := *NewTrait()
trait.Type = val.Type
trait.Dynamic = false
trait.Externalizable = false
for k, _ := range val.Object {
trait.Properties = append(trait.Properties, k)
}
sort.Strings(trait.Properties)
var u29 uint32 = 0x03
if trait.Dynamic {
u29 |= 0x02 << 2
}
if trait.Externalizable {
u29 |= 0x01 << 2
}
u29 |= uint32(len(trait.Properties)) << 4
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode trait header for object: %s", err)
}
n += m
m, err = e.encodeAmf3Utf8(w, trait.Type)
if err != nil {
return n, Error("amf3 encode: cannot encode trait type for object: %s", err)
}
n += m
for _, prop := range trait.Properties {
m, err = e.encodeAmf3Utf8(w, prop)
if err != nil {
return n, Error("amf3 encode: cannot encode trait property for object: %s", err)
}
n += m
}
if trait.Externalizable {
return n, Error("amf3 encode: cannot encode externalizable object")
}
for _, prop := range trait.Properties {
m, err = e.EncodeAmf3(w, val.Object[prop])
if err != nil {
return n, Error("amf3 encode: cannot encode sealed object value: %s", err)
}
n += m
}
if trait.Dynamic {
for k, v := range val.Object {
var foundProp bool = false
for _, prop := range trait.Properties {
if prop == k {
foundProp = true
break
}
}
if foundProp != true {
m, err = e.encodeAmf3Utf8(w, k)
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object property key: %s", err)
}
n += m
m, err = e.EncodeAmf3(w, v)
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object value: %s", err)
}
n += m
}
m, err = e.encodeAmf3Utf8(w, "")
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object ending marker string: %s", err)
}
n += m
}
}
return
}
// marker: 1 byte 0x0c
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read .
func (e *Encoder) EncodeAmf3ByteArray(w io.Writer, val []byte, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_BYTEARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
u29 := (length << 1) | 1
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for bytearray: %s", err)
}
n += m
m, err = w.Write(val)
if err != nil {
return n, Error("encode amf3: unable to encode bytearray value: %s", err)
}
n += m
return
}
func (e *Encoder) encodeAmf3Utf8(w io.Writer, val string) (n int, err error) {
length := uint32(len(val))
u29 := uint32(length<<1) | 0x01
var m int
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for string: %s", err)
}
n += m
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf3: unable to encode string value: %s", err)
}
n += m
return
}
func (e *Encoder) encodeAmf3Uint29(w io.Writer, val uint32) (n int, err error) {
if val <= 0x0000007F {
err = WriteByte(w, byte(val))
if err == nil {
n += 1
}
} else if val <= 0x00003FFF {
n, err = w.Write([]byte{byte(val>>7 | 0x80), byte(val & 0x7F)})
} else if val <= 0x001FFFFF {
n, err = w.Write([]byte{byte(val>>14 | 0x80), byte(val>>7&0x7F | 0x80), byte(val & 0x7F)})
} else if val <= 0x1FFFFFFF {
n, err = w.Write([]byte{byte(val>>22 | 0x80), byte(val>>15&0x7F | 0x80), byte(val>>8&0x7F | 0x80), byte(val)})
} else {
return n, Error("amf3 encode: cannot encode u29 with value %d (out of range)", val)
}
return
}

199
protocol/amf/encoder_amf3_test.go

@ -0,0 +1,199 @@ @@ -0,0 +1,199 @@
package amf
import (
"bytes"
"testing"
)
func TestEncodeAmf3EmptyString(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x01}
_, err := enc.EncodeAmf3String(buf, "", false)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Undefined(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x00}
_, err := enc.EncodeAmf3Undefined(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Null(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x01}
_, err := enc.EncodeAmf3(buf, nil)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3False(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x02}
_, err := enc.EncodeAmf3(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3True(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x03}
_, err := enc.EncodeAmf3(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Integer(t *testing.T) {
enc := new(Encoder)
for _, tc := range u29TestCases {
buf := new(bytes.Buffer)
_, err := enc.EncodeAmf3Integer(buf, tc.value, false)
if err != nil {
t.Errorf("EncodeAmf3Integer error: %s", err)
}
got := buf.Bytes()
if !bytes.Equal(tc.expect, got) {
t.Errorf("EncodeAmf3Integer expect n %x got %x", tc.value, got)
}
}
buf := new(bytes.Buffer)
expect := []byte{0x04, 0x80, 0xFF, 0xFF, 0xFF}
n, err := enc.EncodeAmf3(buf, uint32(4194303))
if err != nil {
t.Errorf("%s", err)
}
if n != 5 {
t.Errorf("expected to write 5 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Double(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}
_, err := enc.EncodeAmf3(buf, float64(1.2))
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3String(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x06, 0x07, 'f', 'o', 'o'}
_, err := enc.EncodeAmf3(buf, "foo")
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Array(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x09, 0x13, 0x01,
0x06, 0x03, '1',
0x06, 0x03, '2',
0x06, 0x03, '3',
0x06, 0x03, '4',
0x06, 0x03, '5',
0x06, 0x03, '6',
0x06, 0x03, '7',
0x06, 0x03, '8',
0x06, 0x03, '9',
}
arr := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}
_, err := enc.EncodeAmf3(buf, arr)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Object(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{
0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a',
'm', 'f', '.', 'A', 'S', 'C', 'l', 'a',
's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f',
'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r',
}
to := *NewTypedObject()
to.Type = "org.amf.ASClass"
to.Object["foo"] = "bar"
to.Object["baz"] = nil
_, err := enc.EncodeAmf3(buf, to)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer:\n%#v\ngot:\n%#v", expect, buf.Bytes())
}
}

70
protocol/amf/metadata.go

@ -0,0 +1,70 @@ @@ -0,0 +1,70 @@
package amf
import (
"bytes"
"fmt"
"log"
)
const (
ADD = 0x0
DEL = 0x3
)
const (
SetDataFrame string = "@setDataFrame"
OnMetaData string = "onMetaData"
)
var setFrameFrame []byte
func init() {
b := bytes.NewBuffer(nil)
encoder := &Encoder{}
if _, err := encoder.Encode(b, SetDataFrame, AMF0); err != nil {
log.Fatal(err)
}
setFrameFrame = b.Bytes()
}
func MetaDataReform(p []byte, flag uint8) ([]byte, error) {
r := bytes.NewReader(p)
decoder := &Decoder{}
switch flag {
case ADD:
v, err := decoder.Decode(r, AMF0)
if err != nil {
return nil, err
}
switch v.(type) {
case string:
vv := v.(string)
if vv != SetDataFrame {
tmplen := len(setFrameFrame)
b := make([]byte, tmplen+len(p))
copy(b, setFrameFrame)
copy(b[tmplen:], p)
p = b
}
default:
return nil, fmt.Errorf("setFrameFrame error")
}
case DEL:
v, err := decoder.Decode(r, AMF0)
if err != nil {
return nil, err
}
switch v.(type) {
case string:
vv := v.(string)
if vv == SetDataFrame {
p = p[len(setFrameFrame):]
}
default:
return nil, fmt.Errorf("metadata error")
}
default:
return nil, fmt.Errorf("invalid flag:%d", flag)
}
return p, nil
}

92
protocol/amf/util.go

@ -0,0 +1,92 @@ @@ -0,0 +1,92 @@
package amf
import (
"encoding/json"
"errors"
"fmt"
"io"
)
func DumpBytes(label string, buf []byte, size int) {
fmt.Printf("Dumping %s (%d bytes):\n", label, size)
for i := 0; i < size; i++ {
fmt.Printf("0x%02x ", buf[i])
}
fmt.Printf("\n")
}
func Dump(label string, val interface{}) error {
json, err := json.MarshalIndent(val, "", " ")
if err != nil {
return Error("Error dumping %s: %s", label, err)
}
fmt.Printf("Dumping %s:\n%s\n", label, json)
return nil
}
func Error(f string, v ...interface{}) error {
return errors.New(fmt.Sprintf(f, v...))
}
func WriteByte(w io.Writer, b byte) (err error) {
bytes := make([]byte, 1)
bytes[0] = b
_, err = WriteBytes(w, bytes)
return
}
func WriteBytes(w io.Writer, bytes []byte) (int, error) {
return w.Write(bytes)
}
func ReadByte(r io.Reader) (byte, error) {
bytes, err := ReadBytes(r, 1)
if err != nil {
return 0x00, err
}
return bytes[0], nil
}
func ReadBytes(r io.Reader, n int) ([]byte, error) {
bytes := make([]byte, n)
m, err := r.Read(bytes)
if err != nil {
return bytes, err
}
if m != n {
return bytes, fmt.Errorf("decode read bytes failed: expected %d got %d", m, n)
}
return bytes, nil
}
func WriteMarker(w io.Writer, m byte) error {
return WriteByte(w, m)
}
func ReadMarker(r io.Reader) (byte, error) {
return ReadByte(r)
}
func AssertMarker(r io.Reader, checkMarker bool, m byte) error {
if checkMarker == false {
return nil
}
marker, err := ReadMarker(r)
if err != nil {
return err
}
if marker != m {
return Error("decode assert marker failed: expected %v got %v", m, marker)
}
return nil
}

1
protocol/dash/dash.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package dash

29
protocol/hls/align.go

@ -0,0 +1,29 @@ @@ -0,0 +1,29 @@
package hls
const (
syncms = 2 // ms
)
type align struct {
frameNum uint64
frameBase uint64
}
func (self *align) align(dts *uint64, inc uint32) {
aFrameDts := *dts
estPts := self.frameBase + self.frameNum*uint64(inc)
var dPts uint64
if estPts >= aFrameDts {
dPts = estPts - aFrameDts
} else {
dPts = aFrameDts - estPts
}
if dPts <= uint64(syncms)*h264_default_hz {
self.frameNum++
*dts = estPts
return
}
self.frameNum = 1
self.frameBase = aFrameDts
}

44
protocol/hls/audio_cache.go

@ -0,0 +1,44 @@ @@ -0,0 +1,44 @@
package hls
import "bytes"
const (
cache_max_frames byte = 6
audio_cache_len int = 10 * 1024
)
type audioCache struct {
soundFormat byte
num byte
offset int
pts uint64
buf *bytes.Buffer
}
func newAudioCache() *audioCache {
return &audioCache{
buf: bytes.NewBuffer(make([]byte, audio_cache_len)),
}
}
func (self *audioCache) Cache(src []byte, pts uint64) bool {
if self.num == 0 {
self.offset = 0
self.pts = pts
self.buf.Reset()
}
self.buf.Write(src)
self.offset += len(src)
self.num++
return false
}
func (self *audioCache) GetFrame() (int, uint64, []byte) {
self.num = 0
return self.offset, self.pts, self.buf.Bytes()
}
func (self *audioCache) CacheNum() byte {
return self.num
}

413
protocol/hls/hls.go

@ -0,0 +1,413 @@ @@ -0,0 +1,413 @@
package hls
import (
"bytes"
"errors"
"fmt"
"net"
"net/http"
"path"
"strconv"
"strings"
"time"
"github.com/gwuhaolin/livego/utils/cmap"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/container/flv"
"github.com/gwuhaolin/livego/container/ts"
"github.com/gwuhaolin/livego/parser"
"log"
)
const (
duration = 3000
)
var (
ErrNoPublisher = errors.New("No publisher")
ErrInvalidReq = errors.New("invalid req url path")
ErrNoSupportVideoCodec = errors.New("no support video codec")
ErrNoSupportAudioCodec = errors.New("no support audio codec")
)
var crossdomainxml = []byte(`<?xml version="1.0" ?>
<cross-domain-policy>
<allow-access-from domain="*" />
<allow-http-request-headers-from domain="*" headers="*"/>
</cross-domain-policy>`)
type Server struct {
l net.Listener
conns cmap.ConcurrentMap
}
func NewServer() *Server {
ret := &Server{
conns: cmap.New(),
}
go ret.checkStop()
return ret
}
func (self *Server) Serve(l net.Listener) error {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
self.handle(w, r)
})
self.l = l
http.Serve(l, mux)
return nil
}
func (self *Server) GetWriter(info av.Info) av.WriteCloser {
var s *Source
ok := self.conns.Has(info.Key)
if !ok {
log.Println("new hls source")
s = NewSource(info)
self.conns.Set(info.Key, s)
} else {
v, _ := self.conns.Get(info.Key)
s = v.(*Source)
}
return s
}
func (self *Server) getConn(key string) *Source {
v, ok := self.conns.Get(key)
if !ok {
return nil
}
return v.(*Source)
}
func (self *Server) checkStop() {
for {
<-time.After(5 * time.Second)
for item := range self.conns.IterBuffered() {
v := item.Val.(*Source)
if !v.Alive() {
log.Println("check stop and remove: ", v.Info())
self.conns.Remove(item.Key)
}
}
}
}
func (self *Server) handle(w http.ResponseWriter, r *http.Request) {
if path.Base(r.URL.Path) == "crossdomain.xml" {
w.Header().Set("Content-Type", "application/xml")
w.Write(crossdomainxml)
return
}
switch path.Ext(r.URL.Path) {
case ".m3u8":
key, _ := self.parseM3u8(r.URL.Path)
conn := self.getConn(key)
if conn == nil {
http.Error(w, ErrNoPublisher.Error(), http.StatusForbidden)
return
}
tsCache := conn.GetCacheInc()
body, err := tsCache.GenM3U8PlayList()
if err != nil {
log.Println("GenM3U8PlayList error: ", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Content-Type", "application/x-mpegURL")
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.Write(body)
case ".ts":
key, _ := self.parseTs(r.URL.Path)
conn := self.getConn(key)
if conn == nil {
http.Error(w, ErrNoPublisher.Error(), http.StatusForbidden)
return
}
tsCache := conn.GetCacheInc()
item, err := tsCache.GetItem(r.URL.Path)
if err != nil {
log.Println("GetItem error: ", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Content-Type", "video/mp2ts")
w.Header().Set("Content-Length", strconv.Itoa(len(item.Data)))
w.Write(item.Data)
}
}
func (self *Server) parseM3u8(pathstr string) (key string, err error) {
pathstr = strings.TrimLeft(pathstr, "/")
key = strings.TrimRight(pathstr, path.Ext(pathstr))
return
}
func (self *Server) parseTs(pathstr string) (key string, err error) {
pathstr = strings.TrimLeft(pathstr, "/")
paths := strings.SplitN(pathstr, "/", 3)
if len(paths) != 3 {
err = fmt.Errorf("invalid path=%s", pathstr)
return
}
key = paths[0] + "/" + paths[1]
return
}
const (
videoHZ = 90000
aacSampleLen = 1024
maxQueueNum = 512
h264_default_hz uint64 = 90
)
type Source struct {
av.RWBaser
seq int
info av.Info
bwriter *bytes.Buffer
btswriter *bytes.Buffer
demuxer *flv.Demuxer
muxer *ts.Muxer
pts, dts uint64
stat *status
align *align
cache *audioCache
tsCache *TSCacheItem
tsparser *parser.CodecParser
closed bool
packetQueue chan av.Packet
}
func NewSource(info av.Info) *Source {
info.Inter = true
s := &Source{
info: info,
align: &align{},
stat: newStatus(),
RWBaser: av.NewRWBaser(time.Second * 10),
cache: newAudioCache(),
demuxer: flv.NewDemuxer(),
muxer: ts.NewMuxer(),
tsCache: NewTSCacheItem(info.Key),
tsparser: parser.NewCodecParser(),
bwriter: bytes.NewBuffer(make([]byte, 100*1024)),
packetQueue: make(chan av.Packet, maxQueueNum),
}
go func() {
err := s.SendPacket()
if err != nil {
log.Println("send packet error: ", err)
s.closed = true
}
}()
return s
}
func (self *Source) GetCacheInc() *TSCacheItem {
return self.tsCache
}
func (self *Source) DropPacket(pktQue chan av.Packet, info av.Info) {
glog.Errorf("[%v] packet queue max!!!", info)
for i := 0; i < maxQueueNum-84; i++ {
tmpPkt, ok := <-pktQue
// try to don't drop audio
if ok && tmpPkt.IsAudio {
if len(pktQue) > maxQueueNum-2 {
<-pktQue
} else {
pktQue <- tmpPkt
}
}
if ok && tmpPkt.IsVideo {
videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader)
// dont't drop sps config and dont't drop key frame
if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) {
pktQue <- tmpPkt
}
if len(pktQue) > maxQueueNum-10 {
<-pktQue
}
}
}
log.Println("packet queue len: ", len(pktQue))
}
func (self *Source) Write(p av.Packet) error {
self.SetPreTime()
if len(self.packetQueue) >= maxQueueNum-24 {
self.DropPacket(self.packetQueue, self.info)
} else {
self.packetQueue <- p
}
return nil
}
func (self *Source) SendPacket() error {
defer func() {
glog.Infof("[%v] hls sender stop", self.info)
if r := recover(); r != nil {
log.Println("hls SendPacket panic: ", r)
}
}()
glog.Infof("[%v] hls sender start", self.info)
for {
if self.closed {
return errors.New("closed")
}
p, ok := <-self.packetQueue
if ok {
if p.IsMetadata {
continue
}
err := self.demuxer.Demux(&p)
if err == flv.ErrAvcEndSEQ {
log.Println(err)
continue
} else {
if err != nil {
log.Println(err)
return err
}
}
compositionTime, isSeq, err := self.parse(&p)
if err != nil {
log.Println(err)
}
if err != nil || isSeq {
continue
}
if self.btswriter != nil {
self.stat.update(p.IsVideo, p.TimeStamp)
self.calcPtsDts(p.IsVideo, p.TimeStamp, uint32(compositionTime))
self.tsMux(&p)
}
} else {
return errors.New("closed")
}
}
}
func (self *Source) Info() (ret av.Info) {
return self.info
}
func (self *Source) cleanup() {
close(self.packetQueue)
self.bwriter = nil
self.btswriter = nil
self.cache = nil
self.tsCache = nil
}
func (self *Source) Close(err error) {
log.Println("hls source closed: ", self.info)
if !self.closed {
self.cleanup()
}
self.closed = true
}
func (self *Source) cut() {
newf := true
if self.btswriter == nil {
self.btswriter = bytes.NewBuffer(nil)
} else if self.btswriter != nil && self.stat.durationMs() >= duration {
self.flushAudio()
self.seq++
filename := fmt.Sprintf("/%s/%d.ts", self.info.Key, time.Now().Unix())
item := NewTSItem(filename, int(self.stat.durationMs()), self.seq, self.btswriter.Bytes())
self.tsCache.SetItem(filename, item)
self.btswriter.Reset()
self.stat.resetAndNew()
} else {
newf = false
}
if newf {
self.btswriter.Write(self.muxer.PAT())
self.btswriter.Write(self.muxer.PMT(av.SOUND_AAC, true))
}
}
func (self *Source) parse(p *av.Packet) (int32, bool, error) {
var compositionTime int32
var ah av.AudioPacketHeader
var vh av.VideoPacketHeader
if p.IsVideo {
vh = p.Header.(av.VideoPacketHeader)
if vh.CodecID() != av.VIDEO_H264 {
return compositionTime, false, ErrNoSupportVideoCodec
}
compositionTime = vh.CompositionTime()
if vh.IsKeyFrame() && vh.IsSeq() {
return compositionTime, true, self.tsparser.Parse(p, self.bwriter)
}
} else {
ah = p.Header.(av.AudioPacketHeader)
if ah.SoundFormat() != av.SOUND_AAC {
return compositionTime, false, ErrNoSupportAudioCodec
}
if ah.AACPacketType() == av.AAC_SEQHDR {
return compositionTime, true, self.tsparser.Parse(p, self.bwriter)
}
}
self.bwriter.Reset()
if err := self.tsparser.Parse(p, self.bwriter); err != nil {
return compositionTime, false, err
}
p.Data = self.bwriter.Bytes()
if p.IsVideo && vh.IsKeyFrame() {
self.cut()
}
return compositionTime, false, nil
}
func (self *Source) calcPtsDts(isVideo bool, ts, compositionTs uint32) {
self.dts = uint64(ts) * h264_default_hz
if isVideo {
self.pts = self.dts + uint64(compositionTs)*h264_default_hz
} else {
sampleRate, _ := self.tsparser.SampleRate()
self.align.align(&self.dts, uint32(videoHZ*aacSampleLen/sampleRate))
self.pts = self.dts
}
}
func (self *Source) flushAudio() error {
return self.muxAudio(1)
}
func (self *Source) muxAudio(limit byte) error {
if self.cache.CacheNum() < limit {
return nil
}
var p av.Packet
_, pts, buf := self.cache.GetFrame()
p.Data = buf
p.TimeStamp = uint32(pts / h264_default_hz)
return self.muxer.Mux(&p, self.btswriter)
}
func (self *Source) tsMux(p *av.Packet) error {
if p.IsVideo {
return self.muxer.Mux(p, self.btswriter)
} else {
self.cache.Cache(p.Data, self.pts)
return self.muxAudio(cache_max_frames)
}
}

43
protocol/hls/status.go

@ -0,0 +1,43 @@ @@ -0,0 +1,43 @@
package hls
import "time"
type status struct {
hasVideo bool
seqId int64
createdAt time.Time
segBeginAt time.Time
hasSetFirstTs bool
firstTimestamp int64
lastTimestamp int64
}
func newStatus() *status {
return &status{
seqId: 0,
hasSetFirstTs: false,
segBeginAt: time.Now(),
}
}
func (t *status) update(isVideo bool, timestamp uint32) {
if isVideo {
t.hasVideo = true
}
if !t.hasSetFirstTs {
t.hasSetFirstTs = true
t.firstTimestamp = int64(timestamp)
}
t.lastTimestamp = int64(timestamp)
}
func (t *status) resetAndNew() {
t.seqId++
t.hasVideo = false
t.createdAt = time.Now()
t.hasSetFirstTs = false
}
func (t *status) durationMs() int64 {
return t.lastTimestamp - t.firstTimestamp
}

127
protocol/hls/ts_cache.go

@ -0,0 +1,127 @@ @@ -0,0 +1,127 @@
package hls
import (
"bytes"
"container/list"
"errors"
"fmt"
"sync"
)
type TSCache struct {
entrys map[string]*TSCacheItem
}
func NewTSCache() *TSCache {
return &TSCache{
entrys: make(map[string]*TSCacheItem),
}
}
func (self *TSCache) Set(key string, e *TSCacheItem) {
v, ok := self.entrys[key]
if !ok {
self.entrys[key] = e
}
if v.ID() != e.ID() {
self.entrys[key] = e
}
}
func (self *TSCache) Get(key string) *TSCacheItem {
v := self.entrys[key]
return v
}
const (
maxTSCacheNum = 3
)
var (
ErrNoKey = errors.New("No key for cache")
)
type TSCacheItem struct {
id string
num int
lock sync.RWMutex
ll *list.List
lm map[string]TSItem
}
func NewTSCacheItem(id string) *TSCacheItem {
return &TSCacheItem{
id: id,
ll: list.New(),
num: maxTSCacheNum,
lm: make(map[string]TSItem),
}
}
func (self *TSCacheItem) ID() string {
return self.id
}
// TODO: found data race, fix it
func (self *TSCacheItem) GenM3U8PlayList() ([]byte, error) {
var seq int
var getSeq bool
var maxDuration int
m3u8body := bytes.NewBuffer(nil)
for e := self.ll.Front(); e != nil; e = e.Next() {
key := e.Value.(string)
v, ok := self.lm[key]
if ok {
if v.Duration > maxDuration {
maxDuration = v.Duration
}
if !getSeq {
getSeq = true
seq = v.SeqNum
}
fmt.Fprintf(m3u8body, "#EXTINF:%.3f,\n%s\n", float64(v.Duration)/float64(1000), v.Name)
}
}
w := bytes.NewBuffer(nil)
fmt.Fprintf(w,
"#EXTM3U\n#EXT-X-VERSION:3\n#EXT-X-ALLOW-CACHE:NO\n#EXT-X-TARGETDURATION:%d\n#EXT-X-MEDIA-SEQUENCE:%d\n\n",
maxDuration/1000+1, seq)
w.Write(m3u8body.Bytes())
return w.Bytes(), nil
}
func (self *TSCacheItem) SetItem(key string, item TSItem) {
if self.ll.Len() == self.num {
e := self.ll.Front()
self.ll.Remove(e)
k := e.Value.(string)
delete(self.lm, k)
}
self.lm[key] = item
self.ll.PushBack(key)
}
func (self *TSCacheItem) GetItem(key string) (TSItem, error) {
item, ok := self.lm[key]
if !ok {
return item, ErrNoKey
}
return item, nil
}
type TSItem struct {
Name string
SeqNum int
Duration int
Data []byte
}
func NewTSItem(name string, duration, seqNum int, b []byte) TSItem {
var item TSItem
item.Name = name
item.SeqNum = seqNum
item.Duration = duration
item.Data = make([]byte, len(b))
copy(item.Data, b)
return item
}

274
protocol/httpflv/http_flv.go

@ -0,0 +1,274 @@ @@ -0,0 +1,274 @@
package httpflv
import (
"encoding/json"
"net"
"net/http"
"strings"
"time"
"errors"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/utils/uid"
"github.com/gwuhaolin/livego/protocol/amf"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/utils/pio"
"log"
"github.com/gwuhaolin/livego/protocol/rtmp"
)
type Server struct {
handler av.Handler
}
type stream struct {
Key string `json:"key"`
Id string `json:"id"`
}
type streams struct {
Publishers []stream `json:"publishers"`
Players []stream `json:"players"`
}
func NewServer(h av.Handler) *Server {
return &Server{
handler: h,
}
}
func (self *Server) Serve(l net.Listener) error {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
self.handleConn(w, r)
})
mux.HandleFunc("/streams", func(w http.ResponseWriter, r *http.Request) {
self.getStream(w, r)
})
http.Serve(l, mux)
return nil
}
func (s *Server) getStream(w http.ResponseWriter, r *http.Request) {
rtmpStream := s.handler.(*rtmp.RtmpStream)
if rtmpStream == nil {
return
}
msgs := new(streams)
for item := range rtmpStream.GetStreams().IterBuffered() {
if s, ok := item.Val.(*rtmp.Stream); ok {
if s.GetReader() != nil {
msg := stream{item.Key, s.GetReader().Info().UID}
msgs.Publishers = append(msgs.Publishers, msg)
}
}
}
for item := range rtmpStream.GetStreams().IterBuffered() {
ws := item.Val.(*rtmp.Stream).GetWs()
for s := range ws.IterBuffered() {
if pw, ok := s.Val.(*rtmp.PackWriterCloser); ok {
if pw.GetWriter() != nil {
msg := stream{item.Key, pw.GetWriter().Info().UID}
msgs.Players = append(msgs.Players, msg)
}
}
}
}
resp, _ := json.Marshal(msgs)
w.Header().Set("Content-Type", "application/json")
w.Write(resp)
}
func (self *Server) handleConn(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
log.Println("http flv handleConn panic: ", r)
}
}()
url := r.URL.String()
u := r.URL.Path
if pos := strings.LastIndex(u, "."); pos < 0 || u[pos:] != ".flv" {
http.Error(w, "invalid path", http.StatusBadRequest)
return
}
path := strings.TrimSuffix(strings.TrimLeft(u, "/"), ".flv")
paths := strings.SplitN(path, "/", 2)
log.Println("url:", u, "path:", path, "paths:", paths)
if len(paths) != 2 {
http.Error(w, "invalid path", http.StatusBadRequest)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
writer := NewFLVWriter(paths[0], paths[1], url, w)
self.handler.HandleWriter(writer)
writer.Wait()
}
const (
headerLen = 11
maxQueueNum = 1024
)
type FLVWriter struct {
Uid string
av.RWBaser
app, title, url string
buf []byte
closed bool
closedChan chan struct{}
ctx http.ResponseWriter
packetQueue chan av.Packet
}
func NewFLVWriter(app, title, url string, ctx http.ResponseWriter) *FLVWriter {
ret := &FLVWriter{
Uid: uid.NEWID(),
app: app,
title: title,
url: url,
ctx: ctx,
RWBaser: av.NewRWBaser(time.Second * 10),
closedChan: make(chan struct{}),
buf: make([]byte, headerLen),
packetQueue: make(chan av.Packet, maxQueueNum),
}
ret.ctx.Write([]byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00, 0x09})
pio.PutI32BE(ret.buf[:4], 0)
ret.ctx.Write(ret.buf[:4])
go func() {
err := ret.SendPacket()
if err != nil {
log.Println("SendPacket error:", err)
ret.closed = true
}
}()
return ret
}
func (self *FLVWriter) DropPacket(pktQue chan av.Packet, info av.Info) {
glog.Errorf("[%v] packet queue max!!!", info)
for i := 0; i < maxQueueNum-84; i++ {
tmpPkt, ok := <-pktQue
if ok && tmpPkt.IsVideo {
videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader)
// dont't drop sps config and dont't drop key frame
if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) {
log.Println("insert keyframe to queue")
pktQue <- tmpPkt
}
if len(pktQue) > maxQueueNum-10 {
<-pktQue
}
// drop other packet
<-pktQue
}
// try to don't drop audio
if ok && tmpPkt.IsAudio {
log.Println("insert audio to queue")
pktQue <- tmpPkt
}
}
log.Println("packet queue len: ", len(pktQue))
}
func (self *FLVWriter) Write(p av.Packet) error {
if !self.closed {
if len(self.packetQueue) >= maxQueueNum-24 {
self.DropPacket(self.packetQueue, self.Info())
} else {
self.packetQueue <- p
}
return nil
} else {
return errors.New("closed")
}
}
// func (self *FLVWriter) Write(p av.Packet) error {
func (self *FLVWriter) SendPacket() error {
for {
p, ok := <-self.packetQueue
if ok {
self.RWBaser.SetPreTime()
h := self.buf[:headerLen]
typeID := av.TAG_VIDEO
if !p.IsVideo {
if p.IsMetadata {
var err error
typeID = av.TAG_SCRIPTDATAAMF0
p.Data, err = amf.MetaDataReform(p.Data, amf.DEL)
if err != nil {
return err
}
} else {
typeID = av.TAG_AUDIO
}
}
dataLen := len(p.Data)
timestamp := p.TimeStamp
timestamp += self.BaseTimeStamp()
self.RWBaser.RecTimeStamp(timestamp, uint32(typeID))
preDataLen := dataLen + headerLen
timestampbase := timestamp & 0xffffff
timestampExt := timestamp >> 24 & 0xff
pio.PutU8(h[0:1], uint8(typeID))
pio.PutI24BE(h[1:4], int32(dataLen))
pio.PutI24BE(h[4:7], int32(timestampbase))
pio.PutU8(h[7:8], uint8(timestampExt))
if _, err := self.ctx.Write(h); err != nil {
return err
}
if _, err := self.ctx.Write(p.Data); err != nil {
return err
}
pio.PutI32BE(h[:4], int32(preDataLen))
if _, err := self.ctx.Write(h[:4]); err != nil {
return err
}
} else {
return errors.New("closed")
}
}
return nil
}
func (self *FLVWriter) Wait() {
select {
case <-self.closedChan:
return
}
}
func (self *FLVWriter) Close(error) {
log.Println("http flv closed")
if !self.closed {
close(self.packetQueue)
close(self.closedChan)
}
self.closed = true
}
func (self *FLVWriter) Info() (ret av.Info) {
ret.UID = self.Uid
ret.URL = self.url
ret.Key = self.app + "/" + self.title
ret.Inter = true
return
}

232
protocol/httpopera/http_opera.go

@ -0,0 +1,232 @@ @@ -0,0 +1,232 @@
package httpopera
import (
"encoding/json"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/utils/uid"
"github.com/gwuhaolin/livego/av"
"log"
"github.com/gwuhaolin/livego/protocol/rtmp"
)
type Response struct {
w http.ResponseWriter
Status int `json:"status"`
Message string `json:"message"`
}
func (r *Response) SendJson() (int, error) {
resp, _ := json.Marshal(r)
r.w.Header().Set("Content-Type", "application/json")
return r.w.Write(resp)
}
type Operation struct {
Method string `json:"method"`
URL string `json:"url"`
Stop bool `json:"stop"`
}
type OperationChange struct {
Method string `json:"method"`
SourceURL string `json:"source_url"`
TargetURL string `json:"target_url"`
Stop bool `json:"stop"`
}
type Server struct {
handler av.Handler
}
func NewServer(h av.Handler) *Server {
return &Server{
handler: h,
}
}
func (s *Server) Serve(l net.Listener) error {
mux := http.NewServeMux()
mux.HandleFunc("/rtmp/operation", func(w http.ResponseWriter, r *http.Request) {
s.handleOpera(w, r)
})
// mux.HandleFunc("/rtmp/operation/change", func(w http.ResponseWriter, r *http.Request) {
// s.handleOperaChange(w, r)
// })
http.Serve(l, mux)
return nil
}
// handleOpera, 拉流和推流的http api
// @Path: /rtmp/operation
// @Method: POST
// @Param: json
// method string, "push" or "pull"
// url string
// stop bool
// @Example,
// curl -v -H "Content-Type: application/json" -X POST --data \
// '{"method":"pull","url":"rtmp://127.0.0.1:1935/live/test"}' \
// http://localhost:8087/rtmp/operation
func (s *Server) handleOpera(w http.ResponseWriter, r *http.Request) {
rep := &Response{
w: w,
}
if r.Method != "POST" {
rep.Status = 14000
rep.Message = "bad request method"
rep.SendJson()
return
} else {
result, err := ioutil.ReadAll(r.Body)
if err != nil {
rep.Status = 15000
rep.Message = "read request body error"
rep.SendJson()
return
}
r.Body.Close()
glog.Infof("post body: %s\n", result)
var op Operation
err = json.Unmarshal(result, &op)
if err != nil {
rep.Status = 12000
rep.Message = "parse json body failed"
rep.SendJson()
return
}
switch op.Method {
case "push":
s.Push(op.URL, op.Stop)
case "pull":
s.Pull(op.URL, op.Stop)
}
rep.Status = 10000
rep.Message = op.Method + " " + op.URL + " success"
rep.SendJson()
}
}
func (s *Server) Push(uri string, stop bool) error {
rtmpClient := rtmp.NewRtmpClient(s.handler, nil)
return rtmpClient.Dial(uri, av.PUBLISH)
// return nil
}
func (s *Server) Pull(uri string, stop bool) error {
rtmpClient := rtmp.NewRtmpClient(s.handler, nil)
return rtmpClient.Dial(uri, av.PLAY)
// return nil
}
// TODO:
// handleOperaChange, 拉流和推流的http api,支持自定义路径
// @Path: /rtmp/operation/change
// @Method: POST
// @Param: json
// method string, "push" or "pull"
// url string
// stop bool
// @Example,
// curl -v -H "Content-Type: application/json" -X POST --data \
// '{"method":"pull","url":"rtmp://127.0.0.1:1935/live/test"}' \
// http://localhost:8087/rtmp/operation
// func (s *Server) handleOperaChange(w http.ResponseWriter, r *http.Request) {
// rep := &Response{
// w: w,
// }
// if r.Method != "POST" {
// rep.Status = 14000
// rep.Message = "bad request method"
// rep.SendJson()
// return
// } else {
// result, err := ioutil.ReadAll(r.Body)
// if err != nil {
// rep.Status = 15000
// rep.Message = "read request body error"
// rep.SendJson()
// return
// }
// r.Body.Close()
// glog.Infof("post body: %s\n", result)
// var op OperationChange
// err = json.Unmarshal(result, &op)
// if err != nil {
// rep.Status = 12000
// rep.Message = "parse json body failed"
// rep.SendJson()
// return
// }
// switch op.Method {
// case "push":
// s.PushChange(op.SourceURL, op.TargetURL, op.Stop)
// case "pull":
// s.PullChange(op.SourceURL, op.TargetURL, op.Stop)
// }
// rep.Status = 10000
// rep.Message = op.Method + " from" + op.SourceURL + "to " + op.TargetURL + " success"
// rep.SendJson()
// }
// }
// pushChange suri to turi
// func (s *Server) PushChange(suri, turi string, stop bool) error {
// if !stop {
// sinfo := parseURL(suri)
// tinfo := parseURL(turi)
// rtmpClient := rtmp.NewRtmpClient(s.handler, nil)
// return rtmpClient.Dial(turi, av.PUBLISH)
// } else {
// sinfo := parseURL(suri)
// tinfo := parseURL(turi)
// s.delStream(sinfo.Key, true)
// return nil
// }
// return nil
// }
// pullChange
// func (s *Server) PullChange(suri, turi string, stop bool) error {
// if !stop {
// rtmpStreams, ok := s.handler.(*rtmp.RtmpStream)
// if ok {
// streams := rtmpStreams.GetStreams()
// rtmpClient := rtmp.NewRtmpClient(s.handler, nil)
// return rtmpClient.Dial(turi, av.PLAY)
// }
// } else {
// info := parseURL(suri)
// s.delStream(info.Key, false)
// return nil
// }
// return nil
// }
func parseURL(URL string) (ret av.Info) {
ret.UID = uid.NEWID()
ret.URL = URL
_url, err := url.Parse(URL)
if err != nil {
log.Println(err)
}
ret.Key = strings.TrimLeft(_url.Path, "/")
ret.Inter = true
return
}

1
protocol/kcpts/kcp_ts.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package kcpts

1
protocol/private/protocol.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package private

79
protocol/rtmp/cache/cache.go vendored

@ -0,0 +1,79 @@ @@ -0,0 +1,79 @@
package cache
import (
"flag"
"github.com/gwuhaolin/livego/av"
)
var (
gopNum = flag.Int("gopNum", 1, "gop num")
)
type Cache struct {
gop *GopCache
videoSeq *SpecialCache
audioSeq *SpecialCache
metadata *SpecialCache
}
func NewCache() *Cache {
return &Cache{
gop: NewGopCache(*gopNum),
videoSeq: NewSpecialCache(),
audioSeq: NewSpecialCache(),
metadata: NewSpecialCache(),
}
}
func (self *Cache) Write(p av.Packet) {
if p.IsMetadata {
self.metadata.Write(p)
return
} else {
if !p.IsVideo {
ah, ok := p.Header.(av.AudioPacketHeader)
if ok {
if ah.SoundFormat() == av.SOUND_AAC &&
ah.AACPacketType() == av.AAC_SEQHDR {
self.audioSeq.Write(p)
return
} else {
return
}
}
} else {
vh, ok := p.Header.(av.VideoPacketHeader)
if ok {
if vh.IsSeq() {
self.videoSeq.Write(p)
return
}
} else {
return
}
}
}
self.gop.Write(p)
}
func (self *Cache) Send(w av.WriteCloser) error {
if err := self.metadata.Send(w); err != nil {
return err
}
if err := self.videoSeq.Send(w); err != nil {
return err
}
if err := self.audioSeq.Send(w); err != nil {
return err
}
if err := self.gop.Send(w); err != nil {
return err
}
return nil
}

120
protocol/rtmp/cache/gop.go vendored

@ -0,0 +1,120 @@ @@ -0,0 +1,120 @@
package cache
import (
"errors"
"github.com/gwuhaolin/livego/av"
)
var (
maxGOPCap int = 1024
ErrGopTooBig = errors.New("gop to big")
)
type array struct {
index int
packets []av.Packet
}
func newArray() *array {
ret := &array{
index: 0,
packets: make([]av.Packet, 0, maxGOPCap),
}
return ret
}
func (self *array) reset() {
self.index = 0
self.packets = self.packets[:0]
}
func (self *array) write(packet av.Packet) error {
if self.index >= maxGOPCap {
return ErrGopTooBig
}
self.packets = append(self.packets, packet)
self.index++
return nil
}
func (self *array) send(w av.WriteCloser) error {
var err error
for i := 0; i < self.index; i++ {
packet := self.packets[i]
if err = w.Write(packet); err != nil {
return err
}
}
return err
}
type GopCache struct {
start bool
num int
count int
nextindex int
gops []*array
}
func NewGopCache(num int) *GopCache {
return &GopCache{
count: num,
gops: make([]*array, num),
}
}
func (self *GopCache) writeToArray(chunk av.Packet, startNew bool) error {
var ginc *array
if startNew {
ginc = self.gops[self.nextindex]
if ginc == nil {
ginc = newArray()
self.num++
self.gops[self.nextindex] = ginc
} else {
ginc.reset()
}
self.nextindex = (self.nextindex + 1) % self.count
} else {
ginc = self.gops[(self.nextindex+1)%self.count]
}
ginc.write(chunk)
return nil
}
func (self *GopCache) Write(p av.Packet) {
var ok bool
if p.IsVideo {
vh := p.Header.(av.VideoPacketHeader)
if vh.IsKeyFrame() && !vh.IsSeq() {
ok = true
}
}
if ok || self.start {
self.start = true
self.writeToArray(p, ok)
}
}
func (self *GopCache) sendTo(w av.WriteCloser) error {
var err error
pos := (self.nextindex + 1) % self.count
for i := 0; i < self.num; i++ {
index := (pos - self.num + 1) + i
if index < 0 {
index += self.count
}
g := self.gops[index]
err = g.send(w)
if err != nil {
return err
}
}
return nil
}
func (self *GopCache) Send(w av.WriteCloser) error {
return self.sendTo(w)
}

46
protocol/rtmp/cache/special.go vendored

@ -0,0 +1,46 @@ @@ -0,0 +1,46 @@
package cache
import (
"bytes"
"log"
"github.com/gwuhaolin/livego/protocol/amf"
"github.com/gwuhaolin/livego/av"
)
const (
SetDataFrame string = "@setDataFrame"
OnMetaData string = "onMetaData"
)
var setFrameFrame []byte
func init() {
b := bytes.NewBuffer(nil)
encoder := &amf.Encoder{}
if _, err := encoder.Encode(b, SetDataFrame, amf.AMF0); err != nil {
log.Fatal(err)
}
setFrameFrame = b.Bytes()
}
type SpecialCache struct {
full bool
p av.Packet
}
func NewSpecialCache() *SpecialCache {
return &SpecialCache{}
}
func (self *SpecialCache) Write(p av.Packet) {
self.p = p
self.full = true
}
func (self *SpecialCache) Send(w av.WriteCloser) error {
if !self.full {
return nil
}
return w.Write(self.p)
}

225
protocol/rtmp/core/chunk_stream.go

@ -0,0 +1,225 @@ @@ -0,0 +1,225 @@
package core
import (
"encoding/binary"
"fmt"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/utils/pool"
)
type ChunkStream struct {
Format uint32
CSID uint32
Timestamp uint32
Length uint32
TypeID uint32
StreamID uint32
timeDelta uint32
exted bool
index uint32
remain uint32
got bool
tmpFromat uint32
Data []byte
}
func (self *ChunkStream) full() bool {
return self.got
}
func (self *ChunkStream) new(pool *pool.Pool) {
self.got = false
self.index = 0
self.remain = self.Length
self.Data = pool.Get(int(self.Length))
}
func (self *ChunkStream) writeHeader(w *ReadWriter) error {
//Chunk Basic Header
h := self.Format << 6
switch {
case self.CSID < 64:
h |= self.CSID
w.WriteUintBE(h, 1)
case self.CSID-64 < 256:
h |= 0
w.WriteUintBE(h, 1)
w.WriteUintLE(self.CSID-64, 1)
case self.CSID-64 < 65536:
h |= 1
w.WriteUintBE(h, 1)
w.WriteUintLE(self.CSID-64, 2)
}
//Chunk Message Header
ts := self.Timestamp
if self.Format == 3 {
goto END
}
if self.Timestamp > 0xffffff {
ts = 0xffffff
}
w.WriteUintBE(ts, 3)
if self.Format == 2 {
goto END
}
if self.Length > 0xffffff {
return fmt.Errorf("length=%d", self.Length)
}
w.WriteUintBE(self.Length, 3)
w.WriteUintBE(self.TypeID, 1)
if self.Format == 1 {
goto END
}
w.WriteUintLE(self.StreamID, 4)
END:
//Extended Timestamp
if ts >= 0xffffff {
w.WriteUintBE(self.Timestamp, 4)
}
return w.WriteError()
}
func (self *ChunkStream) writeChunk(w *ReadWriter, chunkSize int) error {
if self.TypeID == av.TAG_AUDIO {
self.CSID = 4
} else if self.TypeID == av.TAG_VIDEO ||
self.TypeID == av.TAG_SCRIPTDATAAMF0 ||
self.TypeID == av.TAG_SCRIPTDATAAMF3 {
self.CSID = 6
}
totalLen := uint32(0)
numChunks := (self.Length / uint32(chunkSize))
for i := uint32(0); i <= numChunks; i++ {
if totalLen == self.Length {
break
}
if i == 0 {
self.Format = uint32(0)
} else {
self.Format = uint32(3)
}
if err := self.writeHeader(w); err != nil {
return err
}
inc := uint32(chunkSize)
start := uint32(i) * uint32(chunkSize)
if uint32(len(self.Data))-start <= inc {
inc = uint32(len(self.Data)) - start
}
totalLen += inc
end := start + inc
buf := self.Data[start:end]
if _, err := w.Write(buf); err != nil {
return err
}
}
return nil
}
func (self *ChunkStream) readChunk(r *ReadWriter, chunkSize uint32, pool *pool.Pool) error {
if self.remain != 0 && self.tmpFromat != 3 {
return fmt.Errorf("inlaid remin = %d", self.remain)
}
switch self.CSID {
case 0:
id, _ := r.ReadUintLE(1)
self.CSID = id + 64
case 1:
id, _ := r.ReadUintLE(2)
self.CSID = id + 64
}
switch self.tmpFromat {
case 0:
self.Format = self.tmpFromat
self.Timestamp, _ = r.ReadUintBE(3)
self.Length, _ = r.ReadUintBE(3)
self.TypeID, _ = r.ReadUintBE(1)
self.StreamID, _ = r.ReadUintLE(4)
if self.Timestamp == 0xffffff {
self.Timestamp, _ = r.ReadUintBE(4)
self.exted = true
} else {
self.exted = false
}
self.new(pool)
case 1:
self.Format = self.tmpFromat
timeStamp, _ := r.ReadUintBE(3)
self.Length, _ = r.ReadUintBE(3)
self.TypeID, _ = r.ReadUintBE(1)
if timeStamp == 0xffffff {
timeStamp, _ = r.ReadUintBE(4)
self.exted = true
} else {
self.exted = false
}
self.timeDelta = timeStamp
self.Timestamp += timeStamp
self.new(pool)
case 2:
self.Format = self.tmpFromat
timeStamp, _ := r.ReadUintBE(3)
if timeStamp == 0xffffff {
timeStamp, _ = r.ReadUintBE(4)
self.exted = true
} else {
self.exted = false
}
self.timeDelta = timeStamp
self.Timestamp += timeStamp
self.new(pool)
case 3:
if self.remain == 0 {
switch self.Format {
case 0:
if self.exted {
timestamp, _ := r.ReadUintBE(4)
self.Timestamp = timestamp
}
case 1, 2:
var timedet uint32
if self.exted {
timedet, _ = r.ReadUintBE(4)
} else {
timedet = self.timeDelta
}
self.Timestamp += timedet
}
self.new(pool)
} else {
if self.exted {
b, err := r.Peek(4)
if err != nil {
return err
}
tmpts := binary.BigEndian.Uint32(b)
if tmpts == self.Timestamp {
r.Discard(4)
}
}
}
default:
return fmt.Errorf("invalid format=%d", self.Format)
}
size := int(self.remain)
if size > int(chunkSize) {
size = int(chunkSize)
}
buf := self.Data[self.index: self.index+uint32(size)]
if _, err := r.Read(buf); err != nil {
return err
}
self.index += uint32(size)
self.remain -= uint32(size)
if self.remain == 0 {
self.got = true
}
return r.readError
}

97
protocol/rtmp/core/chunk_stream_test.go

@ -0,0 +1,97 @@ @@ -0,0 +1,97 @@
package core
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"github.com/gwuhaolin/livego/utils/pool"
)
func TestChunkRead1(t *testing.T) {
at := assert.New(t)
data := []byte{
0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00,
}
data1 := make([]byte, 128)
data2 := make([]byte, 51)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data2...)
rw := NewReadWriter(bytes.NewBuffer(data), 1024)
chunkinc := &ChunkStream{}
for {
h, _ := rw.ReadUintBE(1)
chunkinc.tmpFromat = h >> 6
chunkinc.CSID = h & 0x3f
chunkinc.readChunk(rw, 128, pool.NewPool())
if chunkinc.remain == 0 {
break
}
}
at.Equal(int(chunkinc.Length), 307)
at.Equal(int(chunkinc.TypeID), 9)
at.Equal(int(chunkinc.StreamID), 1)
at.Equal(len(chunkinc.Data), 307)
at.Equal(int(chunkinc.remain), 0)
data = []byte{
0x06, 0xff, 0xff, 0xff, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
}
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, []byte{0x00, 0x00, 0x00, 0x05}...)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data2...)
rw = NewReadWriter(bytes.NewBuffer(data), 1024)
chunkinc = &ChunkStream{}
h, _ := rw.ReadUintBE(1)
chunkinc.tmpFromat = h >> 6
chunkinc.CSID = h & 0x3f
chunkinc.readChunk(rw, 128, pool.NewPool())
h, _ = rw.ReadUintBE(1)
chunkinc.tmpFromat = h >> 6
chunkinc.CSID = h & 0x3f
chunkinc.readChunk(rw, 128, pool.NewPool())
h, _ = rw.ReadUintBE(1)
chunkinc.tmpFromat = h >> 6
chunkinc.CSID = h & 0x3f
chunkinc.readChunk(rw, 128, pool.NewPool())
at.Equal(int(chunkinc.Length), 307)
at.Equal(int(chunkinc.TypeID), 9)
at.Equal(int(chunkinc.StreamID), 1)
at.Equal(len(chunkinc.Data), 307)
at.Equal(chunkinc.exted, true)
at.Equal(int(chunkinc.Timestamp), 5)
at.Equal(int(chunkinc.remain), 0)
}
func TestWriteChunk(t *testing.T) {
at := assert.New(t)
chunkinc := &ChunkStream{}
chunkinc.Length = 307
chunkinc.TypeID = 9
chunkinc.CSID = 4
chunkinc.Timestamp = 40
chunkinc.Data = make([]byte, 307)
bf := bytes.NewBuffer(nil)
w := NewReadWriter(bf, 1024)
err := chunkinc.writeChunk(w, 128)
w.Flush()
at.Equal(err, nil)
at.Equal(len(bf.Bytes()), 321)
}

207
protocol/rtmp/core/conn.go

@ -0,0 +1,207 @@ @@ -0,0 +1,207 @@
package core
import (
"encoding/binary"
"net"
"time"
"github.com/gwuhaolin/livego/utils/pool"
"github.com/gwuhaolin/livego/utils/pio"
)
const (
_ = iota
idSetChunkSize
idAbortMessage
idAck
idUserControlMessages
idWindowAckSize
idSetPeerBandwidth
)
type Conn struct {
net.Conn
chunkSize uint32
remoteChunkSize uint32
windowAckSize uint32
remoteWindowAckSize uint32
received uint32
ackReceived uint32
rw *ReadWriter
pool *pool.Pool
chunks map[uint32]ChunkStream
}
func NewConn(c net.Conn, bufferSize int) *Conn {
return &Conn{
Conn: c,
chunkSize: 128,
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
pool: pool.NewPool(),
rw: NewReadWriter(c, bufferSize),
chunks: make(map[uint32]ChunkStream),
}
}
func (self *Conn) Read(c *ChunkStream) error {
for {
h, _ := self.rw.ReadUintBE(1)
// if err != nil {
// log.Println("read from conn error: ", err)
// return err
// }
format := h >> 6
csid := h & 0x3f
cs, ok := self.chunks[csid]
if !ok {
cs = ChunkStream{}
self.chunks[csid] = cs
}
cs.tmpFromat = format
cs.CSID = csid
err := cs.readChunk(self.rw, self.remoteChunkSize, self.pool)
if err != nil {
return err
}
self.chunks[csid] = cs
if cs.full() {
*c = cs
break
}
}
self.handleControlMsg(c)
self.ack(c.Length)
return nil
}
func (self *Conn) Write(c *ChunkStream) error {
if c.TypeID == idSetChunkSize {
self.chunkSize = binary.BigEndian.Uint32(c.Data)
}
return c.writeChunk(self.rw, int(self.chunkSize))
}
func (self *Conn) Flush() error {
return self.rw.Flush()
}
func (self *Conn) Close() error {
return self.Conn.Close()
}
func (self *Conn) RemoteAddr() net.Addr {
return self.Conn.RemoteAddr()
}
func (self *Conn) LocalAddr() net.Addr {
return self.Conn.LocalAddr()
}
func (self *Conn) SetDeadline(t time.Time) error {
return self.Conn.SetDeadline(t)
}
func (self *Conn) NewAck(size uint32) ChunkStream {
return initControlMsg(idAck, 4, size)
}
func (self *Conn) NewSetChunkSize(size uint32) ChunkStream {
return initControlMsg(idSetChunkSize, 4, size)
}
func (self *Conn) NewWindowAckSize(size uint32) ChunkStream {
return initControlMsg(idWindowAckSize, 4, size)
}
func (self *Conn) NewSetPeerBandwidth(size uint32) ChunkStream {
ret := initControlMsg(idSetPeerBandwidth, 5, size)
ret.Data[4] = 2
return ret
}
func (self *Conn) handleControlMsg(c *ChunkStream) {
if c.TypeID == idSetChunkSize {
self.remoteChunkSize = binary.BigEndian.Uint32(c.Data)
} else if c.TypeID == idWindowAckSize {
self.remoteWindowAckSize = binary.BigEndian.Uint32(c.Data)
}
}
func (self *Conn) ack(size uint32) {
self.received += uint32(size)
self.ackReceived += uint32(size)
if self.received >= 0xf0000000 {
self.received = 0
}
if self.ackReceived >= self.remoteWindowAckSize {
cs := self.NewAck(self.ackReceived)
cs.writeChunk(self.rw, int(self.chunkSize))
self.ackReceived = 0
}
}
func initControlMsg(id, size, value uint32) ChunkStream {
ret := ChunkStream{
Format: 0,
CSID: 2,
TypeID: id,
StreamID: 0,
Length: size,
Data: make([]byte, size),
}
pio.PutU32BE(ret.Data[:size], value)
return ret
}
const (
streamBegin uint32 = 0
streamEOF uint32 = 1
streamDry uint32 = 2
setBufferLen uint32 = 3
streamIsRecorded uint32 = 4
pingRequest uint32 = 6
pingResponse uint32 = 7
)
/*
+------------------------------+-------------------------
| Event Type ( 2- bytes ) | Event Data
+------------------------------+-------------------------
Pay load for the User Control Message.
*/
func (self *Conn) userControlMsg(eventType, buflen uint32) ChunkStream {
var ret ChunkStream
buflen += 2
ret = ChunkStream{
Format: 0,
CSID: 2,
TypeID: 4,
StreamID: 1,
Length: buflen,
Data: make([]byte, buflen),
}
ret.Data[0] = byte(eventType >> 8 & 0xff)
ret.Data[1] = byte(eventType & 0xff)
return ret
}
func (self *Conn) SetBegin() {
ret := self.userControlMsg(streamBegin, 4)
for i := 0; i < 4; i++ {
ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff)
}
self.Write(&ret)
}
func (self *Conn) SetRecorded() {
ret := self.userControlMsg(streamIsRecorded, 4)
for i := 0; i < 4; i++ {
ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff)
}
self.Write(&ret)
}

287
protocol/rtmp/core/conn_client.go

@ -0,0 +1,287 @@ @@ -0,0 +1,287 @@
package core
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"net"
neturl "net/url"
"strings"
"github.com/gwuhaolin/livego/protocol/amf"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/av"
"log"
)
var (
respResult = "_result"
respError = "_error"
onStatus = "onStatus"
publishStart = "NetStream.Publish.Start"
playStart = "NetStream.Play.Start"
connectSuccess = "NetConnection.Connect.Success"
)
var (
ErrFail = errors.New("respone err")
)
type ConnClient struct {
done bool
transID int
url string
tcurl string
app string
title string
query string
curcmdName string
streamid uint32
conn *Conn
encoder *amf.Encoder
decoder *amf.Decoder
bytesw *bytes.Buffer
}
func NewConnClient() *ConnClient {
return &ConnClient{
transID: 1,
bytesw: bytes.NewBuffer(nil),
encoder: &amf.Encoder{},
decoder: &amf.Decoder{},
}
}
func (self *ConnClient) readRespMsg() error {
var err error
var rc ChunkStream
for {
if err = self.conn.Read(&rc); err != nil {
return err
}
switch rc.TypeID {
case 20, 17:
r := bytes.NewReader(rc.Data)
vs, err := self.decoder.DecodeBatch(r, amf.AMF0)
if err != nil && err != io.EOF {
return err
}
for k, v := range vs {
switch v.(type) {
case string:
switch self.curcmdName {
case cmdConnect, cmdCreateStream:
if v.(string) != respResult {
return ErrFail
}
case cmdPublish:
if v.(string) != onStatus {
return ErrFail
}
}
case float64:
switch self.curcmdName {
case cmdConnect, cmdCreateStream:
id := int(v.(float64))
if k == 1 {
if id != self.transID {
return ErrFail
}
} else if k == 3 {
self.streamid = uint32(id)
}
case cmdPublish:
if int(v.(float64)) != 0 {
return ErrFail
}
}
case amf.Object:
objmap := v.(amf.Object)
switch self.curcmdName {
case cmdConnect:
code, ok := objmap["code"]
if ok && code.(string) != connectSuccess {
return ErrFail
}
case cmdPublish:
code, ok := objmap["code"]
if ok && code.(string) != publishStart {
return ErrFail
}
}
}
}
return nil
}
}
}
func (self *ConnClient) writeMsg(args ...interface{}) error {
self.bytesw.Reset()
for _, v := range args {
if _, err := self.encoder.Encode(self.bytesw, v, amf.AMF0); err != nil {
return err
}
}
msg := self.bytesw.Bytes()
c := ChunkStream{
Format: 0,
CSID: 3,
Timestamp: 0,
TypeID: 20,
StreamID: self.streamid,
Length: uint32(len(msg)),
Data: msg,
}
self.conn.Write(&c)
return self.conn.Flush()
}
func (self *ConnClient) writeConnectMsg() error {
event := make(amf.Object)
event["app"] = self.app
event["type"] = "nonprivate"
event["flashVer"] = "FMS.3.1"
event["tcUrl"] = self.tcurl
self.curcmdName = cmdConnect
if err := self.writeMsg(cmdConnect, self.transID, event); err != nil {
return err
}
return self.readRespMsg()
}
func (self *ConnClient) writeCreateStreamMsg() error {
self.transID++
self.curcmdName = cmdCreateStream
if err := self.writeMsg(cmdCreateStream, self.transID, nil); err != nil {
return err
}
return self.readRespMsg()
}
func (self *ConnClient) writePublishMsg() error {
self.transID++
self.curcmdName = cmdPublish
if err := self.writeMsg(cmdPublish, self.transID, nil, self.title, publishLive); err != nil {
return err
}
return self.readRespMsg()
}
func (self *ConnClient) writePlayMsg() error {
self.transID++
self.curcmdName = cmdPlay
if err := self.writeMsg(cmdPlay, 0, nil, self.title); err != nil {
return err
}
return self.readRespMsg()
}
func (self *ConnClient) Start(url string, method string) error {
u, err := neturl.Parse(url)
if err != nil {
return err
}
self.url = url
path := strings.TrimLeft(u.Path, "/")
ps := strings.SplitN(path, "/", 2)
if len(ps) != 2 {
return fmt.Errorf("u path err: %s", path)
}
self.app = ps[0]
self.title = ps[1]
self.query = u.RawQuery
self.tcurl = "rtmp://" + u.Host + "/" + self.app
port := ":1935"
host := u.Host
localIP := ":0"
var remoteIP string
if strings.Index(host, ":") != -1 {
host, port, err = net.SplitHostPort(host)
if err != nil {
return err
}
port = ":" + port
}
ips, err := net.LookupIP(host)
glog.Infof("ips: %v, host: %v", ips, host)
if err != nil {
log.Println(err)
return err
}
remoteIP = ips[rand.Intn(len(ips))].String()
if strings.Index(remoteIP, ":") == -1 {
remoteIP += port
}
local, err := net.ResolveTCPAddr("tcp", localIP)
if err != nil {
log.Println(err)
return err
}
log.Println("remoteIP: ", remoteIP)
remote, err := net.ResolveTCPAddr("tcp", remoteIP)
if err != nil {
log.Println(err)
return err
}
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
log.Println(err)
return err
}
log.Println("connection:", "local:", conn.LocalAddr(), "remote:", conn.RemoteAddr())
self.conn = NewConn(conn, 4*1024)
if err := self.conn.HandshakeClient(); err != nil {
return err
}
if err := self.writeConnectMsg(); err != nil {
return err
}
if err := self.writeCreateStreamMsg(); err != nil {
return err
}
if method == av.PUBLISH {
if err := self.writePublishMsg(); err != nil {
return err
}
} else if method == av.PLAY {
if err := self.writePlayMsg(); err != nil {
return err
}
}
return nil
}
func (self *ConnClient) Write(c ChunkStream) error {
if c.TypeID == av.TAG_SCRIPTDATAAMF0 ||
c.TypeID == av.TAG_SCRIPTDATAAMF3 {
var err error
if c.Data, err = amf.MetaDataReform(c.Data, amf.ADD); err != nil {
return err
}
c.Length = uint32(len(c.Data))
}
return self.conn.Write(&c)
}
func (self *ConnClient) Read(c *ChunkStream) (err error) {
return self.conn.Read(c)
}
func (self *ConnClient) GetInfo() (app string, name string, url string) {
app = self.app
name = self.title
url = self.url
return
}
func (self *ConnClient) Close(err error) {
self.conn.Close()
}

353
protocol/rtmp/core/conn_server.go

@ -0,0 +1,353 @@ @@ -0,0 +1,353 @@
package core
import (
"bytes"
"errors"
"io"
"github.com/gwuhaolin/livego/protocol/amf"
"github.com/gwuhaolin/livego/av"
"log"
)
var (
publishLive = "live"
publishRecord = "record"
publishAppend = "append"
)
var (
ErrReq = errors.New("req error")
)
var (
cmdConnect = "connect"
cmdFcpublish = "FCPublish"
cmdReleaseStream = "releaseStream"
cmdCreateStream = "createStream"
cmdPublish = "publish"
cmdFCUnpublish = "FCUnpublish"
cmdDeleteStream = "deleteStream"
cmdPlay = "play"
)
type ConnectInfo struct {
App string `amf:"app" json:"app"`
Flashver string `amf:"flashVer" json:"flashVer"`
SwfUrl string `amf:"swfUrl" json:"swfUrl"`
TcUrl string `amf:"tcUrl" json:"tcUrl"`
Fpad bool `amf:"fpad" json:"fpad"`
AudioCodecs int `amf:"audioCodecs" json:"audioCodecs"`
VideoCodecs int `amf:"videoCodecs" json:"videoCodecs"`
VideoFunction int `amf:"videoFunction" json:"videoFunction"`
PageUrl string `amf:"pageUrl" json:"pageUrl"`
ObjectEncoding int `amf:"objectEncoding" json:"objectEncoding"`
}
type ConnectResp struct {
FMSVer string `amf:"fmsVer"`
Capabilities int `amf:"capabilities"`
}
type ConnectEvent struct {
Level string `amf:"level"`
Code string `amf:"code"`
Description string `amf:"description"`
ObjectEncoding int `amf:"objectEncoding"`
}
type PublishInfo struct {
Name string
Type string
}
type ConnServer struct {
done bool
streamID int
isPublisher bool
conn *Conn
transactionID int
ConnInfo ConnectInfo
PublishInfo PublishInfo
decoder *amf.Decoder
encoder *amf.Encoder
bytesw *bytes.Buffer
}
func NewConnServer(conn *Conn) *ConnServer {
return &ConnServer{
conn: conn,
streamID: 1,
bytesw: bytes.NewBuffer(nil),
decoder: &amf.Decoder{},
encoder: &amf.Encoder{},
}
}
func (self *ConnServer) writeMsg(csid, streamID uint32, args ...interface{}) error {
self.bytesw.Reset()
for _, v := range args {
if _, err := self.encoder.Encode(self.bytesw, v, amf.AMF0); err != nil {
return err
}
}
msg := self.bytesw.Bytes()
c := ChunkStream{
Format: 0,
CSID: csid,
Timestamp: 0,
TypeID: 20,
StreamID: streamID,
Length: uint32(len(msg)),
Data: msg,
}
self.conn.Write(&c)
return self.conn.Flush()
}
func (self *ConnServer) connect(vs []interface{}) error {
for _, v := range vs {
switch v.(type) {
case string:
case float64:
id := int(v.(float64))
if id != 1 {
return ErrReq
}
self.transactionID = id
case amf.Object:
obimap := v.(amf.Object)
if app, ok := obimap["app"]; ok {
self.ConnInfo.App = app.(string)
}
if flashVer, ok := obimap["flashVer"]; ok {
self.ConnInfo.Flashver = flashVer.(string)
}
if tcurl, ok := obimap["tcUrl"]; ok {
self.ConnInfo.TcUrl = tcurl.(string)
}
if encoding, ok := obimap["objectEncoding"]; ok {
self.ConnInfo.ObjectEncoding = int(encoding.(float64))
}
}
}
return nil
}
func (self *ConnServer) releaseStream(vs []interface{}) error {
return nil
}
func (self *ConnServer) fcPublish(vs []interface{}) error {
return nil
}
func (self *ConnServer) connectResp(cur *ChunkStream) error {
c := self.conn.NewWindowAckSize(2500000)
self.conn.Write(&c)
c = self.conn.NewSetPeerBandwidth(2500000)
self.conn.Write(&c)
c = self.conn.NewSetChunkSize(uint32(1024))
self.conn.Write(&c)
resp := make(amf.Object)
resp["fmsVer"] = "FMS/3,0,1,123"
resp["capabilities"] = 31
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetConnection.Connect.Success"
event["description"] = "Connection succeeded."
event["objectEncoding"] = self.ConnInfo.ObjectEncoding
return self.writeMsg(cur.CSID, cur.StreamID, "_result", self.transactionID, resp, event)
}
func (self *ConnServer) createStream(vs []interface{}) error {
for _, v := range vs {
switch v.(type) {
case string:
case float64:
self.transactionID = int(v.(float64))
case amf.Object:
}
}
return nil
}
func (self *ConnServer) createStreamResp(cur *ChunkStream) error {
return self.writeMsg(cur.CSID, cur.StreamID, "_result", self.transactionID, nil, self.streamID)
}
func (self *ConnServer) publishOrPlay(vs []interface{}) error {
for k, v := range vs {
switch v.(type) {
case string:
if k == 2 {
self.PublishInfo.Name = v.(string)
} else if k == 3 {
self.PublishInfo.Type = v.(string)
}
case float64:
id := int(v.(float64))
self.transactionID = id
case amf.Object:
}
}
return nil
}
func (self *ConnServer) publishResp(cur *ChunkStream) error {
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetStream.Publish.Start"
event["description"] = "Start publising."
return self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event)
}
func (self *ConnServer) playResp(cur *ChunkStream) error {
self.conn.SetRecorded()
self.conn.SetBegin()
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetStream.Play.Reset"
event["description"] = "Playing and resetting stream."
if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Play.Start"
event["description"] = "Started playing stream."
if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Data.Start"
event["description"] = "Started playing stream."
if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Play.PublishNotify"
event["description"] = "Started playing notify."
if err := self.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
return self.conn.Flush()
}
func (self *ConnServer) handleCmdMsg(c *ChunkStream) error {
amfType := amf.AMF0
if c.TypeID == 17 {
c.Data = c.Data[1:]
}
r := bytes.NewReader(c.Data)
vs, err := self.decoder.DecodeBatch(r, amf.Version(amfType))
if err != nil && err != io.EOF {
return err
}
// glog.Infof("rtmp req: %#v", vs)
switch vs[0].(type) {
case string:
switch vs[0].(string) {
case cmdConnect:
if err = self.connect(vs[1:]); err != nil {
return err
}
if err = self.connectResp(c); err != nil {
return err
}
case cmdCreateStream:
if err = self.createStream(vs[1:]); err != nil {
return err
}
if err = self.createStreamResp(c); err != nil {
return err
}
case cmdPublish:
if err = self.publishOrPlay(vs[1:]); err != nil {
return err
}
if err = self.publishResp(c); err != nil {
return err
}
self.done = true
self.isPublisher = true
log.Println("handle publish req done")
case cmdPlay:
if err = self.publishOrPlay(vs[1:]); err != nil {
return err
}
if err = self.playResp(c); err != nil {
return err
}
self.done = true
self.isPublisher = false
log.Println("handle play req done")
case cmdFcpublish:
self.fcPublish(vs)
case cmdReleaseStream:
self.releaseStream(vs)
case cmdFCUnpublish:
case cmdDeleteStream:
default:
log.Println("no support command=", vs[0].(string))
}
}
return nil
}
func (self *ConnServer) ReadMsg() error {
var c ChunkStream
for {
if err := self.conn.Read(&c); err != nil {
return err
}
switch c.TypeID {
case 20, 17:
if err := self.handleCmdMsg(&c); err != nil {
return err
}
}
if self.done {
break
}
}
return nil
}
func (self *ConnServer) IsPublisher() bool {
return self.isPublisher
}
func (self *ConnServer) Write(c ChunkStream) error {
if c.TypeID == av.TAG_SCRIPTDATAAMF0 ||
c.TypeID == av.TAG_SCRIPTDATAAMF3 {
var err error
if c.Data, err = amf.MetaDataReform(c.Data, amf.DEL); err != nil {
return err
}
c.Length = uint32(len(c.Data))
}
return self.conn.Write(&c)
}
func (self *ConnServer) Read(c *ChunkStream) (err error) {
return self.conn.Read(c)
}
func (self *ConnServer) GetInfo() (app string, name string, url string) {
app = self.ConnInfo.App
name = self.PublishInfo.Name
url = self.ConnInfo.TcUrl + "/" + self.PublishInfo.Name
return
}
func (self *ConnServer) Close(err error) {
self.conn.Close()
}

251
protocol/rtmp/core/conn_test.go

@ -0,0 +1,251 @@ @@ -0,0 +1,251 @@
package core
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/gwuhaolin/livego/utils/pool"
)
func TestConnReadNormal(t *testing.T) {
at := assert.New(t)
data := []byte{
0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00,
}
data1 := make([]byte, 128)
data2 := make([]byte, 51)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data2...)
conn := &Conn{
pool: pool.NewPool(),
rw: NewReadWriter(bytes.NewBuffer(data), 1024),
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
chunks: make(map[uint32]ChunkStream),
}
var c ChunkStream
err := conn.Read(&c)
at.Equal(err, nil)
at.Equal(int(c.CSID), 6)
at.Equal(int(c.Length), 307)
at.Equal(int(c.TypeID), 9)
}
//交叉读音视频数据
func TestConnCrossReading(t *testing.T) {
at := assert.New(t)
data1 := make([]byte, 128)
data2 := make([]byte, 51)
videoData := []byte{
0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00,
}
audioData := []byte{
0x04, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x08, 0x01, 0x00, 0x00, 0x00,
}
//video 1
videoData = append(videoData, data1...)
//video 2
videoData = append(videoData, 0xc6)
videoData = append(videoData, data1...)
//audio 1
videoData = append(videoData, audioData...)
videoData = append(videoData, data1...)
//audio 2
videoData = append(videoData, 0xc4)
videoData = append(videoData, data1...)
//video 3
videoData = append(videoData, 0xc6)
videoData = append(videoData, data2...)
//audio 3
videoData = append(videoData, 0xc4)
videoData = append(videoData, data2...)
conn := &Conn{
pool: pool.NewPool(),
rw: NewReadWriter(bytes.NewBuffer(videoData), 1024),
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
chunks: make(map[uint32]ChunkStream),
}
var c ChunkStream
//video 1
err := conn.Read(&c)
at.Equal(err, nil)
at.Equal(int(c.TypeID), 9)
at.Equal(len(c.Data), 307)
//audio2
err = conn.Read(&c)
at.Equal(err, nil)
at.Equal(int(c.TypeID), 8)
at.Equal(len(c.Data), 307)
err = conn.Read(&c)
at.Equal(err, io.EOF)
}
func TestSetChunksizeForWrite(t *testing.T) {
at := assert.New(t)
chunk := ChunkStream{
Format: 0,
CSID: 2,
Timestamp: 0,
Length: 4,
StreamID: 1,
TypeID: idSetChunkSize,
Data: []byte{0x00, 0x00, 0x00, 0x96},
}
buf := bytes.NewBuffer(nil)
rw := NewReadWriter(buf, 1024)
conn := &Conn{
pool: pool.NewPool(),
rw: rw,
chunkSize: 128,
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
chunks: make(map[uint32]ChunkStream),
}
audio := ChunkStream{
Format: 0,
CSID: 4,
Timestamp: 40,
Length: 133,
StreamID: 1,
TypeID: 0x8,
}
audio.Data = make([]byte, 133)
audio.Data = audio.Data[:133]
audio.Data[0] = 0xff
audio.Data[128] = 0xff
err := conn.Write(&audio)
at.Equal(err, nil)
conn.Flush()
at.Equal(len(buf.Bytes()), 146)
buf.Reset()
err = conn.Write(&chunk)
at.Equal(err, nil)
conn.Flush()
buf.Reset()
err = conn.Write(&audio)
at.Equal(err, nil)
conn.Flush()
at.Equal(len(buf.Bytes()), 145)
}
func TestSetChunksize(t *testing.T) {
at := assert.New(t)
data := []byte{
0x06, 0x00, 0x00, 0x00, 0x00, 0x01, 0x33, 0x09, 0x01, 0x00, 0x00, 0x00,
}
data1 := make([]byte, 128)
data2 := make([]byte, 51)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data2...)
rw := NewReadWriter(bytes.NewBuffer(data), 1024)
conn := &Conn{
pool: pool.NewPool(),
rw: rw,
chunkSize: 128,
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
chunks: make(map[uint32]ChunkStream),
}
var c ChunkStream
err := conn.Read(&c)
at.Equal(err, nil)
at.Equal(int(c.TypeID), 9)
at.Equal(int(c.CSID), 6)
at.Equal(int(c.StreamID), 1)
at.Equal(len(c.Data), 307)
//设置chunksize
chunkBuf := []byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x96}
conn.rw = NewReadWriter(bytes.NewBuffer(chunkBuf), 1024)
err = conn.Read(&c)
at.Equal(err, nil)
data = data[:12]
data[7] = 0x8
data1 = make([]byte, 150)
data2 = make([]byte, 7)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data1...)
data = append(data, 0xc6)
data = append(data, data2...)
conn.rw = NewReadWriter(bytes.NewBuffer(data), 1024)
err = conn.Read(&c)
at.Equal(err, nil)
at.Equal(len(c.Data), 307)
err = conn.Read(&c)
at.Equal(err, io.EOF)
}
func TestConnWrite(t *testing.T) {
at := assert.New(t)
wr := bytes.NewBuffer(nil)
readWriter := NewReadWriter(wr, 128)
conn := &Conn{
pool: pool.NewPool(),
rw: readWriter,
chunkSize: 128,
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
chunks: make(map[uint32]ChunkStream),
}
c1 := ChunkStream{
Length: 3,
TypeID: 8,
CSID: 3,
Timestamp: 40,
Data: []byte{0x01, 0x02, 0x03},
}
err := conn.Write(&c1)
at.Equal(err, nil)
conn.Flush()
at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0x28, 0x0, 0x0, 0x3, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3})
//for type 1
wr.Reset()
c1 = ChunkStream{
Length: 4,
TypeID: 8,
CSID: 3,
Timestamp: 80,
Data: []byte{0x01, 0x02, 0x03, 0x4},
}
err = conn.Write(&c1)
at.Equal(err, nil)
conn.Flush()
at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0x50, 0x0, 0x0, 0x4, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3, 0x4})
//for type 2
wr.Reset()
c1.Timestamp = 160
err = conn.Write(&c1)
at.Equal(err, nil)
conn.Flush()
at.Equal(wr.Bytes(), []byte{0x4, 0x0, 0x0, 0xa0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2, 0x3, 0x4})
}

207
protocol/rtmp/core/handshake.go

@ -0,0 +1,207 @@ @@ -0,0 +1,207 @@
package core
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"time"
"github.com/gwuhaolin/livego/utils/pio"
)
var (
timeout = 5 * time.Second
)
var (
hsClientFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsServerFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ',
'S', 'e', 'r', 'v', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsClientPartialKey = hsClientFullKey[:30]
hsServerPartialKey = hsServerFullKey[:36]
)
func hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) {
h := hmac.New(sha256.New, key)
if gap <= 0 {
h.Write(src)
} else {
h.Write(src[:gap])
h.Write(src[gap+32:])
}
return h.Sum(nil)
}
func hsCalcDigestPos(p []byte, base int) (pos int) {
for i := 0; i < 4; i++ {
pos += int(p[base+i])
}
pos = (pos % 728) + base + 4
return
}
func hsFindDigest(p []byte, key []byte, base int) int {
gap := hsCalcDigestPos(p, base)
digest := hsMakeDigest(key, p, gap)
if bytes.Compare(p[gap:gap+32], digest) != 0 {
return -1
}
return gap
}
func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) {
var pos int
if pos = hsFindDigest(p, peerkey, 772); pos == -1 {
if pos = hsFindDigest(p, peerkey, 8); pos == -1 {
return
}
}
ok = true
digest = hsMakeDigest(key, p[pos:pos+32], -1)
return
}
func hsCreate01(p []byte, time uint32, ver uint32, key []byte) {
p[0] = 3
p1 := p[1:]
rand.Read(p1[8:])
pio.PutU32BE(p1[0:4], time)
pio.PutU32BE(p1[4:8], ver)
gap := hsCalcDigestPos(p1, 8)
digest := hsMakeDigest(key, p1, gap)
copy(p1[gap:], digest)
}
func hsCreate2(p []byte, key []byte) {
rand.Read(p)
gap := len(p) - 32
digest := hsMakeDigest(key, p, gap)
copy(p[gap:], digest)
}
func (self *Conn) HandshakeClient() (err error) {
var random [(1 + 1536*2) * 2]byte
C0C1C2 := random[:1536*2+1]
C0 := C0C1C2[:1]
C0C1 := C0C1C2[:1536+1]
C2 := C0C1C2[1536+1:]
S0S1S2 := random[1536*2+1:]
C0[0] = 3
// > C0C1
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = self.rw.Write(C0C1); err != nil {
return
}
self.Conn.SetDeadline(time.Now().Add(timeout))
if err = self.rw.Flush(); err != nil {
return
}
// < S0S1S2
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(self.rw, S0S1S2); err != nil {
return
}
S1 := S0S1S2[1: 1536+1]
if ver := pio.U32BE(S1[4:8]); ver != 0 {
C2 = S1
} else {
C2 = S1
}
// > C2
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = self.rw.Write(C2); err != nil {
return
}
self.Conn.SetDeadline(time.Time{})
return
}
func (self *Conn) HandshakeServer() (err error) {
var random [(1 + 1536*2) * 2]byte
C0C1C2 := random[:1536*2+1]
C0 := C0C1C2[:1]
C1 := C0C1C2[1: 1536+1]
C0C1 := C0C1C2[:1536+1]
C2 := C0C1C2[1536+1:]
S0S1S2 := random[1536*2+1:]
S0 := S0S1S2[:1]
S1 := S0S1S2[1: 1536+1]
S0S1 := S0S1S2[:1536+1]
S2 := S0S1S2[1536+1:]
// < C0C1
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(self.rw, C0C1); err != nil {
return
}
self.Conn.SetDeadline(time.Now().Add(timeout))
if C0[0] != 3 {
err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0])
return
}
S0[0] = 3
clitime := pio.U32BE(C1[0:4])
srvtime := clitime
srvver := uint32(0x0d0e0a0d)
cliver := pio.U32BE(C1[4:8])
if cliver != 0 {
var ok bool
var digest []byte
if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok {
err = fmt.Errorf("rtmp: handshake server: C1 invalid")
return
}
hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey)
hsCreate2(S2, digest)
} else {
copy(S1, C2)
copy(S2, C1)
}
// > S0S1S2
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = self.rw.Write(S0S1S2); err != nil {
return
}
self.Conn.SetDeadline(time.Now().Add(timeout))
if err = self.rw.Flush(); err != nil {
return
}
// < C2
self.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(self.rw, C2); err != nil {
return
}
self.Conn.SetDeadline(time.Time{})
return
}

114
protocol/rtmp/core/read_writer.go

@ -0,0 +1,114 @@ @@ -0,0 +1,114 @@
package core
import (
"bufio"
"io"
)
type ReadWriter struct {
*bufio.ReadWriter
readError error
writeError error
}
func NewReadWriter(rw io.ReadWriter, bufSize int) *ReadWriter {
return &ReadWriter{
ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(rw, bufSize), bufio.NewWriterSize(rw, bufSize)),
}
}
func (rw *ReadWriter) Read(p []byte) (int, error) {
if rw.readError != nil {
return 0, rw.readError
}
n, err := io.ReadAtLeast(rw.ReadWriter, p, len(p))
rw.readError = err
return n, err
}
func (rw *ReadWriter) ReadError() error {
return rw.readError
}
func (rw *ReadWriter) ReadUintBE(n int) (uint32, error) {
if rw.readError != nil {
return 0, rw.readError
}
ret := uint32(0)
for i := 0; i < n; i++ {
b, err := rw.ReadByte()
if err != nil {
rw.readError = err
return 0, err
}
ret = ret<<8 + uint32(b)
}
return ret, nil
}
func (rw *ReadWriter) ReadUintLE(n int) (uint32, error) {
if rw.readError != nil {
return 0, rw.readError
}
ret := uint32(0)
for i := 0; i < n; i++ {
b, err := rw.ReadByte()
if err != nil {
rw.readError = err
return 0, err
}
ret += uint32(b) << uint32(i*8)
}
return ret, nil
}
func (rw *ReadWriter) Flush() error {
if rw.writeError != nil {
return rw.writeError
}
if rw.ReadWriter.Writer.Buffered() == 0 {
return nil
}
return rw.ReadWriter.Flush()
}
func (rw *ReadWriter) Write(p []byte) (int, error) {
if rw.writeError != nil {
return 0, rw.writeError
}
return rw.ReadWriter.Write(p)
}
func (rw *ReadWriter) WriteError() error {
return rw.writeError
}
func (rw *ReadWriter) WriteUintBE(v uint32, n int) error {
if rw.writeError != nil {
return rw.writeError
}
for i := 0; i < n; i++ {
b := byte(v>>uint32((n-i-1)<<3)) & 0xff
if err := rw.WriteByte(b); err != nil {
rw.writeError = err
return err
}
}
return nil
}
func (rw *ReadWriter) WriteUintLE(v uint32, n int) error {
if rw.writeError != nil {
return rw.writeError
}
for i := 0; i < n; i++ {
b := byte(v) & 0xff
if err := rw.WriteByte(b); err != nil {
rw.writeError = err
return err
}
v = v >> 8
}
return nil
}

136
protocol/rtmp/core/read_writer_test.go

@ -0,0 +1,136 @@ @@ -0,0 +1,136 @@
package core
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestReader(t *testing.T) {
at := assert.New(t)
buf := bytes.NewBufferString("abc")
r := NewReadWriter(buf, 1024)
b := make([]byte, 3)
n, err := r.Read(b)
at.Equal(err, nil)
at.Equal(r.ReadError(), nil)
at.Equal(n, 3)
n, err = r.Read(b)
at.Equal(err, io.EOF)
at.Equal(r.ReadError(), io.EOF)
buf.WriteString("123")
n, err = r.Read(b)
at.Equal(err, io.EOF)
at.Equal(r.ReadError(), io.EOF)
at.Equal(n, 0)
}
func TestReaderUintBE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x01, 0x02}},
{3, 0x010203, []byte{0x01, 0x02, 0x03}},
{4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}},
}
for _, test := range tests {
buf := bytes.NewBuffer(test.bytes)
r := NewReadWriter(buf, 1024)
n, err := r.ReadUintBE(test.i)
at.Equal(err, nil, "test %d", test.i)
at.Equal(n, test.value, "test %d", test.i)
}
}
func TestReaderUintLE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x02, 0x01}},
{3, 0x010203, []byte{0x03, 0x02, 0x01}},
{4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}},
}
for _, test := range tests {
buf := bytes.NewBuffer(test.bytes)
r := NewReadWriter(buf, 1024)
n, err := r.ReadUintLE(test.i)
at.Equal(err, nil, "test %d", test.i)
at.Equal(n, test.value, "test %d", test.i)
}
}
func TestWriter(t *testing.T) {
at := assert.New(t)
buf := bytes.NewBuffer(nil)
w := NewReadWriter(buf, 1024)
b := []byte{1, 2, 3}
n, err := w.Write(b)
at.Equal(err, nil)
at.Equal(w.WriteError(), nil)
at.Equal(n, 3)
w.writeError = io.EOF
n, err = w.Write(b)
at.Equal(err, io.EOF)
at.Equal(w.WriteError(), io.EOF)
at.Equal(n, 0)
}
func TestWriteUintBE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x01, 0x02}},
{3, 0x010203, []byte{0x01, 0x02, 0x03}},
{4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}},
}
for _, test := range tests {
buf := bytes.NewBuffer(nil)
r := NewReadWriter(buf, 1024)
err := r.WriteUintBE(test.value, test.i)
at.Equal(err, nil, "test %d", test.i)
err = r.Flush()
at.Equal(err, nil, "test %d", test.i)
at.Equal(buf.Bytes(), test.bytes, "test %d", test.i)
}
}
func TestWriteUintLE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x02, 0x01}},
{3, 0x010203, []byte{0x03, 0x02, 0x01}},
{4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}},
}
for _, test := range tests {
buf := bytes.NewBuffer(nil)
r := NewReadWriter(buf, 1024)
err := r.WriteUintLE(test.value, test.i)
at.Equal(err, nil, "test %d", test.i)
err = r.Flush()
at.Equal(err, nil, "test %d", test.i)
at.Equal(buf.Bytes(), test.bytes, "test %d", test.i)
}
}

341
protocol/rtmp/rtmp.go

@ -0,0 +1,341 @@ @@ -0,0 +1,341 @@
package rtmp
import (
"net"
"time"
"net/url"
"strings"
"errors"
"flag"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/utils/uid"
"github.com/gwuhaolin/livego/container/flv"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/protocol/rtmp/core"
"log"
)
const (
maxQueueNum = 1024
)
var (
readTimeout = flag.Int("readTimeout", 10, "read time out")
writeTimeout = flag.Int("writeTimeout", 10, "write time out")
)
type Client struct {
handler av.Handler
getter av.GetWriter
}
func NewRtmpClient(h av.Handler, getter av.GetWriter) *Client {
return &Client{
handler: h,
getter: getter,
}
}
func (self *Client) Dial(url string, method string) error {
connClient := core.NewConnClient()
if err := connClient.Start(url, method); err != nil {
return err
}
if method == av.PUBLISH {
writer := NewVirWriter(connClient)
self.handler.HandleWriter(writer)
} else if method == av.PLAY {
reader := NewVirReader(connClient)
self.handler.HandleReader(reader)
if self.getter != nil {
writer := self.getter.GetWriter(reader.Info())
self.handler.HandleWriter(writer)
}
}
return nil
}
func (self *Client) GetHandle() av.Handler {
return self.handler
}
type Server struct {
handler av.Handler
getter av.GetWriter
}
func NewRtmpServer(h av.Handler, getter av.GetWriter) *Server {
return &Server{
handler: h,
getter: getter,
}
}
func (self *Server) Serve(listener net.Listener) (err error) {
defer func() {
if r := recover(); r != nil {
log.Println("rtmp serve panic: ", r)
}
}()
for {
var netconn net.Conn
netconn, err = listener.Accept()
if err != nil {
return
}
conn := core.NewConn(netconn, 4*1024)
log.Println("new client, connect remote:", conn.RemoteAddr().String(),
"local:", conn.LocalAddr().String())
go self.handleConn(conn)
}
}
func (self *Server) handleConn(conn *core.Conn) error {
if err := conn.HandshakeServer(); err != nil {
conn.Close()
log.Println("handleConn HandshakeServer err:", err)
return err
}
connServer := core.NewConnServer(conn)
if err := connServer.ReadMsg(); err != nil {
conn.Close()
log.Println("handleConn read msg err:", err)
return err
}
if connServer.IsPublisher() {
reader := NewVirReader(connServer)
self.handler.HandleReader(reader)
glog.Infof("new publisher: %+v", reader.Info())
if self.getter != nil {
writer := self.getter.GetWriter(reader.Info())
self.handler.HandleWriter(writer)
}
} else {
writer := NewVirWriter(connServer)
glog.Infof("new player: %+v", writer.Info())
self.handler.HandleWriter(writer)
}
return nil
}
type GetInFo interface {
GetInfo() (string, string, string)
}
type StreamReadWriteCloser interface {
GetInFo
Close(error)
Write(core.ChunkStream) error
Read(c *core.ChunkStream) error
}
type VirWriter struct {
Uid string
closed bool
av.RWBaser
conn StreamReadWriteCloser
packetQueue chan av.Packet
}
func NewVirWriter(conn StreamReadWriteCloser) *VirWriter {
ret := &VirWriter{
Uid: uid.NEWID(),
conn: conn,
RWBaser: av.NewRWBaser(time.Second * time.Duration(*writeTimeout)),
packetQueue: make(chan av.Packet, maxQueueNum),
}
go ret.Check()
go func() {
err := ret.SendPacket()
if err != nil {
log.Println(err)
}
}()
return ret
}
func (self *VirWriter) Check() {
var c core.ChunkStream
for {
if err := self.conn.Read(&c); err != nil {
self.Close(err)
return
}
}
}
func (self *VirWriter) DropPacket(pktQue chan av.Packet, info av.Info) {
glog.Errorf("[%v] packet queue max!!!", info)
for i := 0; i < maxQueueNum-84; i++ {
tmpPkt, ok := <-pktQue
// try to don't drop audio
if ok && tmpPkt.IsAudio {
if len(pktQue) > maxQueueNum-2 {
log.Println("drop audio pkt")
<-pktQue
} else {
pktQue <- tmpPkt
}
}
if ok && tmpPkt.IsVideo {
videoPkt, ok := tmpPkt.Header.(av.VideoPacketHeader)
// dont't drop sps config and dont't drop key frame
if ok && (videoPkt.IsSeq() || videoPkt.IsKeyFrame()) {
pktQue <- tmpPkt
}
if len(pktQue) > maxQueueNum-10 {
log.Println("drop video pkt")
<-pktQue
}
}
}
log.Println("packet queue len: ", len(pktQue))
}
//
func (self *VirWriter) Write(p av.Packet) error {
if !self.closed {
if len(self.packetQueue) >= maxQueueNum-24 {
self.DropPacket(self.packetQueue, self.Info())
} else {
self.packetQueue <- p
}
return nil
} else {
return errors.New("closed")
}
}
func (self *VirWriter) SendPacket() error {
var cs core.ChunkStream
for {
p, ok := <-self.packetQueue
if ok {
cs.Data = p.Data
cs.Length = uint32(len(p.Data))
cs.StreamID = 1
cs.Timestamp = p.TimeStamp
cs.Timestamp += self.BaseTimeStamp()
if p.IsVideo {
cs.TypeID = av.TAG_VIDEO
} else {
if p.IsMetadata {
cs.TypeID = av.TAG_SCRIPTDATAAMF0
} else {
cs.TypeID = av.TAG_AUDIO
}
}
self.SetPreTime()
self.RecTimeStamp(cs.Timestamp, cs.TypeID)
err := self.conn.Write(cs)
if err != nil {
self.closed = true
return err
}
} else {
return errors.New("closed")
}
}
return nil
}
func (self *VirWriter) Info() (ret av.Info) {
ret.UID = self.Uid
_, _, URL := self.conn.GetInfo()
ret.URL = URL
_url, err := url.Parse(URL)
if err != nil {
log.Println(err)
}
ret.Key = strings.TrimLeft(_url.Path, "/")
ret.Inter = true
return
}
func (self *VirWriter) Close(err error) {
log.Println("player ", self.Info(), "closed: "+err.Error())
if !self.closed {
close(self.packetQueue)
}
self.closed = true
self.conn.Close(err)
}
type VirReader struct {
Uid string
av.RWBaser
demuxer *flv.Demuxer
conn StreamReadWriteCloser
}
func NewVirReader(conn StreamReadWriteCloser) *VirReader {
return &VirReader{
Uid: uid.NEWID(),
conn: conn,
RWBaser: av.NewRWBaser(time.Second * time.Duration(*writeTimeout)),
demuxer: flv.NewDemuxer(),
}
}
func (self *VirReader) Read(p *av.Packet) (err error) {
defer func() {
if r := recover(); r != nil {
log.Println("rtmp read packet panic: ", r)
}
}()
self.SetPreTime()
var cs core.ChunkStream
for {
err = self.conn.Read(&cs)
if err != nil {
return err
}
if cs.TypeID == av.TAG_AUDIO ||
cs.TypeID == av.TAG_VIDEO ||
cs.TypeID == av.TAG_SCRIPTDATAAMF0 ||
cs.TypeID == av.TAG_SCRIPTDATAAMF3 {
break
}
}
p.IsAudio = cs.TypeID == av.TAG_AUDIO
p.IsVideo = cs.TypeID == av.TAG_VIDEO
p.IsMetadata = (cs.TypeID == av.TAG_SCRIPTDATAAMF0 || cs.TypeID == av.TAG_SCRIPTDATAAMF3)
p.Data = cs.Data
p.TimeStamp = cs.Timestamp
self.demuxer.DemuxH(p)
return err
}
func (self *VirReader) Info() (ret av.Info) {
ret.UID = self.Uid
_, _, URL := self.conn.GetInfo()
ret.URL = URL
_url, err := url.Parse(URL)
if err != nil {
log.Println(err)
}
ret.Key = strings.TrimLeft(_url.Path, "/")
return
}
func (self *VirReader) Close(err error) {
log.Println("publisher ", self.Info(), "closed: "+err.Error())
self.conn.Close(err)
}

228
protocol/rtmp/stream.go

@ -0,0 +1,228 @@ @@ -0,0 +1,228 @@
package rtmp
import (
"errors"
"time"
"github.com/gwuhaolin/livego/utils/cmap"
"github.com/golang/glog"
"github.com/gwuhaolin/livego/av"
"github.com/gwuhaolin/livego/protocol/rtmp/cache"
)
var (
EmptyID = ""
)
type RtmpStream struct {
streams cmap.ConcurrentMap
}
func NewRtmpStream() *RtmpStream {
ret := &RtmpStream{
streams: cmap.New(),
}
go ret.CheckAlive()
return ret
}
func (rs *RtmpStream) HandleReader(r av.ReadCloser) {
info := r.Info()
var s *Stream
ok := rs.streams.Has(info.Key)
if !ok {
s = NewStream()
rs.streams.Set(info.Key, s)
} else {
s.TransStop()
id := s.ID()
if id != EmptyID && id != info.UID {
ns := NewStream()
s.Copy(ns)
s = ns
rs.streams.Set(info.Key, ns)
}
}
s.AddReader(r)
}
func (rs *RtmpStream) HandleWriter(w av.WriteCloser) {
info := w.Info()
var s *Stream
ok := rs.streams.Has(info.Key)
if !ok {
s = NewStream()
rs.streams.Set(info.Key, s)
} else {
item, ok := rs.streams.Get(info.Key)
if ok {
s = item.(*Stream)
s.AddWriter(w)
}
}
}
func (rs *RtmpStream) GetStreams() cmap.ConcurrentMap {
return rs.streams
}
func (rs *RtmpStream) CheckAlive() {
for {
<-time.After(5 * time.Second)
for item := range rs.streams.IterBuffered() {
v := item.Val.(*Stream)
if v.CheckAlive() == 0 {
rs.streams.Remove(item.Key)
}
}
}
}
type Stream struct {
isStart bool
cache *cache.Cache
r av.ReadCloser
ws cmap.ConcurrentMap
}
type PackWriterCloser struct {
init bool
w av.WriteCloser
}
func (p *PackWriterCloser) GetWriter() av.WriteCloser {
return p.w
}
func NewStream() *Stream {
return &Stream{
cache: cache.NewCache(),
ws: cmap.New(),
}
}
func (s *Stream) ID() string {
if s.r != nil {
return s.r.Info().UID
}
return EmptyID
}
func (s *Stream) GetReader() av.ReadCloser {
return s.r
}
func (s *Stream) GetWs() cmap.ConcurrentMap {
return s.ws
}
func (s *Stream) Copy(dst *Stream) {
for item := range s.ws.IterBuffered() {
v := item.Val.(*PackWriterCloser)
s.ws.Remove(item.Key)
v.w.CalcBaseTimestamp()
dst.AddWriter(v.w)
}
}
func (s *Stream) AddReader(r av.ReadCloser) {
s.r = r
go s.TransStart()
}
func (s *Stream) AddWriter(w av.WriteCloser) {
info := w.Info()
pw := &PackWriterCloser{w: w}
s.ws.Set(info.UID, pw)
}
func (s *Stream) TransStart() {
// debug mode don't use it
// defer func() {
// if r := recover(); r != nil {
// log.Println("rtmp TransStart panic: ", r)
// }
// }()
s.isStart = true
var p av.Packet
for {
if !s.isStart {
s.closeInter()
return
}
err := s.r.Read(&p)
if err != nil {
s.closeInter()
s.isStart = false
return
}
s.cache.Write(p)
for item := range s.ws.IterBuffered() {
v := item.Val.(*PackWriterCloser)
if !v.init {
if err = s.cache.Send(v.w); err != nil {
glog.Errorf("[%s] send cache packet error: %v, remove", v.w.Info(), err)
s.ws.Remove(item.Key)
continue
}
v.init = true
} else {
if err = v.w.Write(p); err != nil {
glog.Errorf("[%s] write packet error: %v, remove", v.w.Info(), err)
s.ws.Remove(item.Key)
}
}
}
}
}
func (s *Stream) TransStop() {
if s.isStart && s.r != nil {
s.r.Close(errors.New("stop old"))
}
s.isStart = false
}
func (s *Stream) CheckAlive() (n int) {
if s.r != nil && s.isStart {
if s.r.Alive() {
n++
} else {
s.r.Close(errors.New("read timeout"))
}
}
for item := range s.ws.IterBuffered() {
v := item.Val.(*PackWriterCloser)
if v.w != nil {
if !v.w.Alive() && s.isStart {
s.ws.Remove(item.Key)
v.w.Close(errors.New("write timeout"))
continue
}
n++
}
}
return
}
func (s *Stream) closeInter() {
if s.r != nil {
glog.Infof("[%v] publisher closed", s.r.Info())
}
for item := range s.ws.IterBuffered() {
v := item.Val.(*PackWriterCloser)
if v.w != nil {
if v.w.Info().IsInterval() {
v.w.Close(errors.New("closed"))
s.ws.Remove(item.Key)
glog.Infof("[%v] player closed and remove\n", v.w.Info())
}
}
}
}

1
protocol/rtp/rtp.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package rtp

1
protocol/rtsp/protocol.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package rtsp

1
protocol/rtsp/rtsp.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package rtsp

1
protocol/webrtc/webrtc.go

@ -0,0 +1 @@ @@ -0,0 +1 @@
package webrtc

301
utils/cmap/cmap.go

@ -0,0 +1,301 @@ @@ -0,0 +1,301 @@
package cmap
import (
"encoding/json"
"sync"
)
var SHARD_COUNT = 32
// A "thread" safe map of type string:Anything.
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
type ConcurrentMap []*ConcurrentMapShared
// A "thread" safe string to anything map.
type ConcurrentMapShared struct {
items map[string]interface{}
sync.RWMutex // Read Write mutex, guards access to internal map.
}
// Creates a new concurrent map.
func New() ConcurrentMap {
m := make(ConcurrentMap, SHARD_COUNT)
for i := 0; i < SHARD_COUNT; i++ {
m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
}
return m
}
// Returns shard under given key
func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
return m[uint(fnv32(key))%uint(SHARD_COUNT)]
}
func (m ConcurrentMap) MSet(data map[string]interface{}) {
for key, value := range data {
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
}
// Sets the given value under the specified key.
func (m *ConcurrentMap) Set(key string, value interface{}) {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
// Callback to return new element to be inserted into the map
// It is called while lock is held, therefore it MUST NOT
// try to access other keys in same map, as it can lead to deadlock since
// Go sync.RWLock is not reentrant
type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
// Insert or Update - updates existing element or inserts a new one using UpsertCb
func (m *ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
shard := m.GetShard(key)
shard.Lock()
v, ok := shard.items[key]
res = cb(ok, v, value)
shard.items[key] = res
shard.Unlock()
return res
}
// Sets the given value under the specified key if no value was associated with it.
func (m *ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
_, ok := shard.items[key]
if !ok {
shard.items[key] = value
}
shard.Unlock()
return !ok
}
// Retrieves an element from map under given key.
func (m ConcurrentMap) Get(key string) (interface{}, bool) {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// Get item from shard.
val, ok := shard.items[key]
shard.RUnlock()
return val, ok
}
// Returns the number of elements within the map.
func (m ConcurrentMap) Count() int {
count := 0
for i := 0; i < SHARD_COUNT; i++ {
shard := m[i]
shard.RLock()
count += len(shard.items)
shard.RUnlock()
}
return count
}
// Looks up an item under specified key
func (m *ConcurrentMap) Has(key string) bool {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// See if element is within shard.
_, ok := shard.items[key]
shard.RUnlock()
return ok
}
// Removes an element from the map.
func (m *ConcurrentMap) Remove(key string) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
delete(shard.items, key)
shard.Unlock()
}
// Removes an element from the map and returns it
func (m *ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
v, exists = shard.items[key]
delete(shard.items, key)
shard.Unlock()
return v, exists
}
// Checks if map is empty.
func (m *ConcurrentMap) IsEmpty() bool {
return m.Count() == 0
}
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
type Tuple struct {
Key string
Val interface{}
}
// Returns an iterator which could be used in a for range loop.
//
// Deprecated: using IterBuffered() will get a better performence
func (m ConcurrentMap) Iter() <-chan Tuple {
ch := make(chan Tuple)
go func() {
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
// Foreach shard.
for _, shard := range m {
go func(shard *ConcurrentMapShared) {
// Foreach key, value pair.
shard.RLock()
for key, val := range shard.items {
ch <- Tuple{key, val}
}
shard.RUnlock()
wg.Done()
}(shard)
}
wg.Wait()
close(ch)
}()
return ch
}
// Returns a buffered iterator which could be used in a for range loop.
func (m ConcurrentMap) IterBuffered() <-chan Tuple {
ch := make(chan Tuple, m.Count())
go func() {
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
// Foreach shard.
for _, shard := range m {
go func(shard *ConcurrentMapShared) {
// Foreach key, value pair.
shard.RLock()
for key, val := range shard.items {
ch <- Tuple{key, val}
}
shard.RUnlock()
wg.Done()
}(shard)
}
wg.Wait()
close(ch)
}()
return ch
}
// Returns all items as map[string]interface{}
func (m ConcurrentMap) Items() map[string]interface{} {
tmp := make(map[string]interface{})
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return tmp
}
// Iterator callback,called for every key,value found in
// maps. RLock is held for all calls for a given shard
// therefore callback sess consistent view of a shard,
// but not across the shards
type IterCb func(key string, v interface{})
// Callback based iterator, cheapest way to read
// all elements in a map.
func (m *ConcurrentMap) IterCb(fn IterCb) {
for idx := range *m {
shard := (*m)[idx]
shard.RLock()
for key, value := range shard.items {
fn(key, value)
}
shard.RUnlock()
}
}
// Return all keys as []string
func (m ConcurrentMap) Keys() []string {
count := m.Count()
ch := make(chan string, count)
go func() {
// Foreach shard.
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
for _, shard := range m {
go func(shard *ConcurrentMapShared) {
// Foreach key, value pair.
shard.RLock()
for key := range shard.items {
ch <- key
}
shard.RUnlock()
wg.Done()
}(shard)
}
wg.Wait()
close(ch)
}()
// Generate keys
keys := make([]string, 0, count)
for k := range ch {
keys = append(keys, k)
}
return keys
}
//Reviles ConcurrentMap "private" variables to json marshal.
func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
// Create a temporary map, which will hold all item spread across shards.
tmp := make(map[string]interface{})
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return json.Marshal(tmp)
}
func fnv32(key string) uint32 {
hash := uint32(2166136261)
const prime32 = uint32(16777619)
for i := 0; i < len(key); i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}
// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal
// will probably won't know which to type to unmarshal into, in such case
// we'll end up with a value of type map[string]interface{}, In most cases this isn't
// out value type, this is why we've decided to remove this functionality.
// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
// // Reverse process of Marshal.
// tmp := make(map[string]interface{})
// // Unmarshal into a single map.
// if err := json.Unmarshal(b, &tmp); err != nil {
// return nil
// }
// // foreach key,value pair in temporary map insert into our concurrent map.
// for key, val := range tmp {
// m.Set(key, val)
// }
// return nil
// }

3
utils/pio/pio.go

@ -0,0 +1,3 @@ @@ -0,0 +1,3 @@
package pio
var RecommendBufioSize = 1024 * 64

121
utils/pio/reader.go

@ -0,0 +1,121 @@ @@ -0,0 +1,121 @@
package pio
func U8(b []byte) (i uint8) {
return b[0]
}
func U16BE(b []byte) (i uint16) {
i = uint16(b[0])
i <<= 8
i |= uint16(b[1])
return
}
func I16BE(b []byte) (i int16) {
i = int16(b[0])
i <<= 8
i |= int16(b[1])
return
}
func I24BE(b []byte) (i int32) {
i = int32(int8(b[0]))
i <<= 8
i |= int32(b[1])
i <<= 8
i |= int32(b[2])
return
}
func U24BE(b []byte) (i uint32) {
i = uint32(b[0])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[2])
return
}
func I32BE(b []byte) (i int32) {
i = int32(int8(b[0]))
i <<= 8
i |= int32(b[1])
i <<= 8
i |= int32(b[2])
i <<= 8
i |= int32(b[3])
return
}
func U32LE(b []byte) (i uint32) {
i = uint32(b[3])
i <<= 8
i |= uint32(b[2])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[0])
return
}
func U32BE(b []byte) (i uint32) {
i = uint32(b[0])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[2])
i <<= 8
i |= uint32(b[3])
return
}
func U40BE(b []byte) (i uint64) {
i = uint64(b[0])
i <<= 8
i |= uint64(b[1])
i <<= 8
i |= uint64(b[2])
i <<= 8
i |= uint64(b[3])
i <<= 8
i |= uint64(b[4])
return
}
func U64BE(b []byte) (i uint64) {
i = uint64(b[0])
i <<= 8
i |= uint64(b[1])
i <<= 8
i |= uint64(b[2])
i <<= 8
i |= uint64(b[3])
i <<= 8
i |= uint64(b[4])
i <<= 8
i |= uint64(b[5])
i <<= 8
i |= uint64(b[6])
i <<= 8
i |= uint64(b[7])
return
}
func I64BE(b []byte) (i int64) {
i = int64(int8(b[0]))
i <<= 8
i |= int64(b[1])
i <<= 8
i |= int64(b[2])
i <<= 8
i |= int64(b[3])
i <<= 8
i |= int64(b[4])
i <<= 8
i |= int64(b[5])
i <<= 8
i |= int64(b[6])
i <<= 8
i |= int64(b[7])
return
}

87
utils/pio/writer.go

@ -0,0 +1,87 @@ @@ -0,0 +1,87 @@
package pio
func PutU8(b []byte, v uint8) {
b[0] = v
}
func PutI16BE(b []byte, v int16) {
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func PutU16BE(b []byte, v uint16) {
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func PutI24BE(b []byte, v int32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
func PutU24BE(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
func PutI32BE(b []byte, v int32) {
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func PutU32BE(b []byte, v uint32) {
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func PutU32LE(b []byte, v uint32) {
b[3] = byte(v >> 24)
b[2] = byte(v >> 16)
b[1] = byte(v >> 8)
b[0] = byte(v)
}
func PutU40BE(b []byte, v uint64) {
b[0] = byte(v >> 32)
b[1] = byte(v >> 24)
b[2] = byte(v >> 16)
b[3] = byte(v >> 8)
b[4] = byte(v)
}
func PutU48BE(b []byte, v uint64) {
b[0] = byte(v >> 40)
b[1] = byte(v >> 32)
b[2] = byte(v >> 24)
b[3] = byte(v >> 16)
b[4] = byte(v >> 8)
b[5] = byte(v)
}
func PutU64BE(b []byte, v uint64) {
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func PutI64BE(b []byte, v int64) {
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}

24
utils/pool/pool.go

@ -0,0 +1,24 @@ @@ -0,0 +1,24 @@
package pool
type Pool struct {
pos int
buf []byte
}
const maxpoolsize = 500 * 1024
func (self *Pool) Get(size int) []byte {
if maxpoolsize-self.pos < size {
self.pos = 0
self.buf = make([]byte, maxpoolsize)
}
b := self.buf[self.pos: self.pos+size]
self.pos += size
return b
}
func NewPool() *Pool {
return &Pool{
buf: make([]byte, maxpoolsize),
}
}

72
utils/queue/queue.go

@ -0,0 +1,72 @@ @@ -0,0 +1,72 @@
package queue
import (
"sync"
"github.com/gwuhaolin/livego/av"
)
// Queue is a basic FIFO queue for Messages.
type Queue struct {
maxSize int
list []*av.Packet
mutex sync.Mutex
}
// NewQueue returns a new Queue. If maxSize is greater than zero the queue will
// not grow more than the defined size.
func NewQueue(maxSize int) *Queue {
return &Queue{
maxSize: maxSize,
}
}
// Push adds a message to the queue.
func (q *Queue) Push(msg *av.Packet) {
q.mutex.Lock()
defer q.mutex.Unlock()
if len(q.list) == q.maxSize {
q.pop()
}
q.list = append(q.list, msg)
}
// Pop removes and returns a message from the queue in first to last order.
func (q *Queue) Pop() *av.Packet {
q.mutex.Lock()
defer q.mutex.Unlock()
if len(q.list) == 0 {
return nil
}
return q.pop()
}
func (q *Queue) pop() *av.Packet {
x := len(q.list) - 1
msg := q.list[x]
q.list = q.list[:x]
return msg
}
// Len returns the length of the queue.
func (q *Queue) Len() int {
q.mutex.Lock()
defer q.mutex.Unlock()
return len(q.list)
}
// All returns and removes all messages from the queue.
func (q *Queue) All() []*av.Packet {
q.mutex.Lock()
defer q.mutex.Unlock()
cache := q.list
q.list = nil
return cache
}

450
utils/uid/uuid.go

@ -0,0 +1,450 @@ @@ -0,0 +1,450 @@
package uid
// Copyright (C) 2013-2015 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Package uuid provides implementation of Universally Unique Identifier (UUID).
// Supported versions are 1, 3, 4 and 5 (as specified in RFC 4122) and
// version 2 (as specified in DCE 1.1).
import (
"bytes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"database/sql/driver"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"fmt"
"hash"
"net"
"os"
"sync"
"time"
)
func NEWID() string {
id := NewV4()
b64 := base64.URLEncoding.EncodeToString(id.Bytes()[:12])
return b64
}
// UUID layout variants.
const (
VariantNCS = iota
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// UUID DCE domains.
const (
DomainPerson = iota
DomainGroup
DomainOrg
)
// Difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970).
const epochStart = 122192928000000000
// Used in string method conversion
const dash byte = '-'
// UUID v1/v2 storage.
var (
storageMutex sync.Mutex
storageOnce sync.Once
epochFunc = unixTimeFunc
clockSequence uint16
lastTime uint64
hardwareAddr [6]byte
posixUID = uint32(os.Getuid())
posixGID = uint32(os.Getgid())
)
// String parse helpers.
var (
urnPrefix = []byte("urn:uuid:")
byteGroups = []int{8, 4, 4, 4, 12}
)
func initClockSequence() {
buf := make([]byte, 2)
safeRandom(buf)
clockSequence = binary.BigEndian.Uint16(buf)
}
func initHardwareAddr() {
interfaces, err := net.Interfaces()
if err == nil {
for _, iface := range interfaces {
if len(iface.HardwareAddr) >= 6 {
copy(hardwareAddr[:], iface.HardwareAddr)
return
}
}
}
// Initialize hardwareAddr randomly in case
// of real network interfaces absence
safeRandom(hardwareAddr[:])
// Set multicast bit as recommended in RFC 4122
hardwareAddr[0] |= 0x01
}
func initStorage() {
initClockSequence()
initHardwareAddr()
}
func safeRandom(dest []byte) {
if _, err := rand.Read(dest); err != nil {
panic(err)
}
}
// Returns difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and current time.
// This is default epoch calculation function.
func unixTimeFunc() uint64 {
return epochStart + uint64(time.Now().UnixNano()/100)
}
// UUID representation compliant with specification
// described in RFC 4122.
type UUID [16]byte
// The nil UUID is special form of UUID that is specified to have all
// 128 bits set to zero.
var Nil = UUID{}
// Predefined namespace UUIDs.
var (
NamespaceDNS, _ = FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
NamespaceURL, _ = FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
NamespaceOID, _ = FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8")
NamespaceX500, _ = FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8")
)
// And returns result of binary AND of two UUIDs.
func And(u1 UUID, u2 UUID) UUID {
u := UUID{}
for i := 0; i < 16; i++ {
u[i] = u1[i] & u2[i]
}
return u
}
// Or returns result of binary OR of two UUIDs.
func Or(u1 UUID, u2 UUID) UUID {
u := UUID{}
for i := 0; i < 16; i++ {
u[i] = u1[i] | u2[i]
}
return u
}
// Equal returns true if u1 and u2 equals, otherwise returns false.
func Equal(u1 UUID, u2 UUID) bool {
return bytes.Equal(u1[:], u2[:])
}
// Version returns algorithm version used to generate UUID.
func (u UUID) Version() uint {
return uint(u[6] >> 4)
}
// Variant returns UUID layout variant.
func (u UUID) Variant() uint {
switch {
case (u[8] & 0x80) == 0x00:
return VariantNCS
case (u[8]&0xc0)|0x80 == 0x80:
return VariantRFC4122
case (u[8]&0xe0)|0xc0 == 0xc0:
return VariantMicrosoft
}
return VariantFuture
}
// Bytes returns bytes slice representation of UUID.
func (u UUID) Bytes() []byte {
return u[:]
}
// Returns canonical string representation of UUID:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
func (u UUID) String() string {
buf := make([]byte, 36)
hex.Encode(buf[0:8], u[0:4])
buf[8] = dash
hex.Encode(buf[9:13], u[4:6])
buf[13] = dash
hex.Encode(buf[14:18], u[6:8])
buf[18] = dash
hex.Encode(buf[19:23], u[8:10])
buf[23] = dash
hex.Encode(buf[24:], u[10:])
return string(buf)
}
// SetVersion sets version bits.
func (u *UUID) SetVersion(v byte) {
u[6] = (u[6] & 0x0f) | (v << 4)
}
// SetVariant sets variant bits as described in RFC 4122.
func (u *UUID) SetVariant() {
u[8] = (u[8] & 0xbf) | 0x80
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String.
func (u UUID) MarshalText() (text []byte, err error) {
text = []byte(u.String())
return
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// Following formats are supported:
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}",
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8"
func (u *UUID) UnmarshalText(text []byte) (err error) {
if len(text) < 32 {
err = fmt.Errorf("uuid: UUID string too short: %s", text)
return
}
t := text[:]
if bytes.Equal(t[:9], urnPrefix) {
t = t[9:]
} else if t[0] == '{' {
t = t[1:]
}
b := u[:]
for _, byteGroup := range byteGroups {
if t[0] == '-' {
t = t[1:]
}
if len(t) < byteGroup {
err = fmt.Errorf("uuid: UUID string too short: %s", text)
return
}
_, err = hex.Decode(b[:byteGroup/2], t[:byteGroup])
if err != nil {
return
}
t = t[byteGroup:]
b = b[byteGroup/2:]
}
return
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (u UUID) MarshalBinary() (data []byte, err error) {
data = u.Bytes()
return
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It will return error if the slice isn't 16 bytes long.
func (u *UUID) UnmarshalBinary(data []byte) (err error) {
if len(data) != 16 {
err = fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data))
return
}
copy(u[:], data)
return
}
// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
return u.String(), nil
}
// Scan implements the sql.Scanner interface.
// A 16-byte slice is handled by UnmarshalBinary, while
// a longer byte slice or a string is handled by UnmarshalText.
func (u *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
if len(src) == 16 {
return u.UnmarshalBinary(src)
}
return u.UnmarshalText(src)
case string:
return u.UnmarshalText([]byte(src))
}
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
}
// FromBytes returns UUID converted from raw byte slice input.
// It will return error if the slice isn't 16 bytes long.
func FromBytes(input []byte) (u UUID, err error) {
err = u.UnmarshalBinary(input)
return
}
// FromBytesOrNil returns UUID converted from raw byte slice input.
// Same behavior as FromBytes, but returns a Nil UUID on error.
func FromBytesOrNil(input []byte) UUID {
uuid, err := FromBytes(input)
if err != nil {
return Nil
}
return uuid
}
// FromString returns UUID parsed from string input.
// Input is expected in a form accepted by UnmarshalText.
func FromString(input string) (u UUID, err error) {
err = u.UnmarshalText([]byte(input))
return
}
// FromStringOrNil returns UUID parsed from string input.
// Same behavior as FromString, but returns a Nil UUID on error.
func FromStringOrNil(input string) UUID {
uuid, err := FromString(input)
if err != nil {
return Nil
}
return uuid
}
// Returns UUID v1/v2 storage state.
// Returns epoch timestamp, clock sequence, and hardware address.
func getStorage() (uint64, uint16, []byte) {
storageOnce.Do(initStorage)
storageMutex.Lock()
defer storageMutex.Unlock()
timeNow := epochFunc()
// Clock changed backwards since last UUID generation.
// Should increase clock sequence.
if timeNow <= lastTime {
clockSequence++
}
lastTime = timeNow
return timeNow, clockSequence, hardwareAddr[:]
}
// NewV1 returns UUID based on current timestamp and MAC address.
func NewV1() UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := getStorage()
binary.BigEndian.PutUint32(u[0:], uint32(timeNow))
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
copy(u[10:], hardwareAddr)
u.SetVersion(1)
u.SetVariant()
return u
}
// NewV2 returns DCE Security UUID based on POSIX UID/GID.
func NewV2(domain byte) UUID {
u := UUID{}
timeNow, clockSeq, hardwareAddr := getStorage()
switch domain {
case DomainPerson:
binary.BigEndian.PutUint32(u[0:], posixUID)
case DomainGroup:
binary.BigEndian.PutUint32(u[0:], posixGID)
}
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
u[9] = domain
copy(u[10:], hardwareAddr)
u.SetVersion(2)
u.SetVariant()
return u
}
// NewV3 returns UUID based on MD5 hash of namespace UUID and name.
func NewV3(ns UUID, name string) UUID {
u := newFromHash(md5.New(), ns, name)
u.SetVersion(3)
u.SetVariant()
return u
}
// NewV4 returns random generated UUID.
func NewV4() UUID {
u := UUID{}
safeRandom(u[:])
u.SetVersion(4)
u.SetVariant()
return u
}
// NewV5 returns UUID based on SHA-1 hash of namespace UUID and name.
func NewV5(ns UUID, name string) UUID {
u := newFromHash(sha1.New(), ns, name)
u.SetVersion(5)
u.SetVariant()
return u
}
// Returns UUID based on hashing of namespace UUID and name.
func newFromHash(h hash.Hash, ns UUID, name string) UUID {
u := UUID{}
h.Write(ns[:])
h.Write([]byte(name))
copy(u[:], h.Sum(nil))
return u
}
Loading…
Cancel
Save