package fmp4

import (
	"bytes"
	"fmt"

	gomp4 "github.com/abema/go-mp4"
)

const (
	// trun "data-offset-present" flag. The identifier keeps its historical
	// spelling ("Preset") since it may be referenced elsewhere in the package.
	trunFlagDataOffsetPreset                       = 0x01
	trunFlagSampleDurationPresent                  = 0x100
	trunFlagSampleSizePresent                      = 0x200
	trunFlagSampleFlagsPresent                     = 0x400
	trunFlagSampleCompositionTimeOffsetPresentOrV1 = 0x800

	// sample_is_non_sync_sample bit of a sample-flags word.
	sampleFlagIsNonSyncSample = 1 << 16
)

// Part is a FMP4 part file.
type Part struct {
	Tracks []*PartTrack
}

// Parts is a sequence of FMP4 parts.
type Parts []*Part

// Unmarshal decodes one or more FMP4 parts from byts and appends them to ps.
//
// The input must be a sequence of moof boxes, each followed by an mdat box.
// Inside a moof, every traf must contain a tfhd, a tfdt (version 0 or 1) and
// at least one trun. A tfhd must precede its truns. Only truns carrying an
// explicit data offset are supported, and that offset is interpreted relative
// to the start of the enclosing moof.
//
// Sample payloads are sub-slices of byts (no copy is made). Their capacity is
// clipped so that appending to one payload cannot overwrite the next one.
// Sample data that falls outside byts is reported as an error.
//
// On error, parts decoded before the failure remain appended to ps.
func (ps *Parts) Unmarshal(byts []byte) error {
	// first-sample-flags-present: the trun carries a flags word that applies
	// to its first sample only.
	const trunFlagFirstSampleFlagsPresent = 0x04

	type readState int

	const (
		waitingMoof readState = iota
		waitingTraf
		waitingTfdtTfhdTrun
	)

	state := waitingMoof
	var curPart *Part
	var moofOffset uint64
	var curTrack *PartTrack
	var tfdt *gomp4.Tfdt
	var tfhd *gomp4.Tfhd

	// checkTrack verifies that the track being closed received its tfdt, its
	// tfhd and at least one trun.
	checkTrack := func() error {
		if curTrack != nil && (tfdt == nil || tfhd == nil || curTrack.Samples == nil) {
			return fmt.Errorf("parse error")
		}
		return nil
	}

	_, err := gomp4.ReadBoxStructure(bytes.NewReader(byts), func(h *gomp4.ReadHandle) (interface{}, error) {
		switch h.BoxInfo.Type.String() {
		case "moof":
			if state != waitingMoof {
				return nil, fmt.Errorf("unexpected moof")
			}

			curPart = &Part{}
			*ps = append(*ps, curPart)
			moofOffset = h.BoxInfo.Offset

			// The previous part was already validated when its mdat was reached.
			curTrack = nil
			tfdt = nil
			tfhd = nil
			state = waitingTraf

		case "traf":
			if state != waitingTraf && state != waitingTfdtTfhdTrun {
				return nil, fmt.Errorf("unexpected traf")
			}

			if err := checkTrack(); err != nil {
				return nil, err
			}

			curTrack = &PartTrack{}
			curPart.Tracks = append(curPart.Tracks, curTrack)
			tfdt = nil
			tfhd = nil
			state = waitingTfdtTfhdTrun

		case "tfhd":
			if state != waitingTfdtTfhdTrun || tfhd != nil {
				return nil, fmt.Errorf("unexpected tfhd")
			}

			box, _, err := h.ReadPayload()
			if err != nil {
				return nil, err
			}
			tfhd = box.(*gomp4.Tfhd)

			curTrack.ID = int(tfhd.TrackID)

		case "tfdt":
			if state != waitingTfdtTfhdTrun || tfdt != nil {
				return nil, fmt.Errorf("unexpected tfdt")
			}

			box, _, err := h.ReadPayload()
			if err != nil {
				return nil, err
			}
			tfdt = box.(*gomp4.Tfdt)

			// Version 0 stores a 32-bit decode time, version 1 a 64-bit one.
			switch tfdt.FullBox.Version {
			case 0:
				curTrack.BaseTime = uint64(tfdt.BaseMediaDecodeTimeV0)
			case 1:
				curTrack.BaseTime = tfdt.BaseMediaDecodeTimeV1
			default:
				return nil, fmt.Errorf("unsupported tfdt version")
			}

		case "trun":
			if state != waitingTfdtTfhdTrun || tfhd == nil {
				return nil, fmt.Errorf("unexpected trun")
			}

			box, _, err := h.ReadPayload()
			if err != nil {
				return nil, err
			}
			trun := box.(*gomp4.Trun)

			trunFlags := uint16(trun.Flags[1])<<8 | uint16(trun.Flags[2])
			if (trunFlags & trunFlagDataOffsetPreset) == 0 {
				return nil, fmt.Errorf("unsupported flags")
			}

			// DataOffset is signed and relative to the start of the moof box.
			dataStart := int64(moofOffset) + int64(trun.DataOffset)
			if dataStart < 0 || dataStart > int64(len(byts)) {
				return nil, fmt.Errorf("invalid data offset")
			}
			ptr := byts[dataStart:]

			// Grow Samples so that it is non-nil even for a trun with no
			// entries; checkTrack relies on this to detect a missing trun.
			existing := len(curTrack.Samples)
			tmp := make([]*PartSample, existing+len(trun.Entries))
			copy(tmp, curTrack.Samples)
			curTrack.Samples = tmp

			for i, e := range trun.Entries {
				s := &PartSample{}

				if (trunFlags & trunFlagSampleDurationPresent) != 0 {
					s.Duration = e.SampleDuration
				} else {
					s.Duration = tfhd.DefaultSampleDuration
				}

				// go-mp4 stores the offset in the V0 field (unsigned) for version-0
				// truns and in the V1 field (signed) otherwise.
				if trun.FullBox.Version == 0 {
					s.PTSOffset = int32(e.SampleCompositionTimeOffsetV0)
				} else {
					s.PTSOffset = e.SampleCompositionTimeOffsetV1
				}

				var sampleFlags uint32
				switch {
				case (trunFlags & trunFlagSampleFlagsPresent) != 0:
					sampleFlags = e.SampleFlags
				case i == 0 && (trunFlags&trunFlagFirstSampleFlagsPresent) != 0:
					sampleFlags = trun.FirstSampleFlags
				default:
					sampleFlags = tfhd.DefaultSampleFlags
				}
				s.IsNonSyncSample = ((sampleFlags & sampleFlagIsNonSyncSample) != 0)

				var size uint32
				if (trunFlags & trunFlagSampleSizePresent) != 0 {
					size = e.SampleSize
				} else {
					size = tfhd.DefaultSampleSize
				}

				if uint64(size) > uint64(len(ptr)) {
					return nil, fmt.Errorf("sample size exceeds available data")
				}
				// Full slice expression: appending to Payload must not overwrite the next sample.
				s.Payload = ptr[:size:size]
				ptr = ptr[size:]

				curTrack.Samples[existing+i] = s
			}

		case "mdat":
			if state != waitingTraf && state != waitingTfdtTfhdTrun {
				return nil, fmt.Errorf("unexpected mdat")
			}

			if err := checkTrack(); err != nil {
				return nil, err
			}

			state = waitingMoof

			// Sample data was already sliced out of byts; do not descend into mdat.
			return nil, nil
		}

		return h.Expand()
	})
	if err != nil {
		return err
	}

	if state != waitingMoof {
		return fmt.Errorf("decode error")
	}

	return nil
}
func (p *Part) Marshal() ([]byte, error) { /* moof - mfhd - traf (video) - traf (audio) mdat */ w := newMP4Writer() moofOffset, err := w.writeBoxStart(&gomp4.Moof{}) // if err != nil { return nil, err } _, err = w.WriteBox(&gomp4.Mfhd{ // SequenceNumber: 0, }) if err != nil { return nil, err } trackLen := len(p.Tracks) truns := make([]*gomp4.Trun, trackLen) trunOffsets := make([]int, trackLen) dataOffsets := make([]int, trackLen) dataSize := 0 for i, track := range p.Tracks { trun, trunOffset, err := track.marshal(w) if err != nil { return nil, err } dataOffsets[i] = dataSize for _, sample := range track.Samples { dataSize += len(sample.Payload) } truns[i] = trun trunOffsets[i] = trunOffset } err = w.writeBoxEnd() // if err != nil { return nil, err } mdat := &gomp4.Mdat{} // mdat.Data = make([]byte, dataSize) pos := 0 for _, track := range p.Tracks { for _, sample := range track.Samples { pos += copy(mdat.Data[pos:], sample.Payload) } } mdatOffset, err := w.WriteBox(mdat) if err != nil { return nil, err } for i := range p.Tracks { truns[i].DataOffset = int32(dataOffsets[i] + mdatOffset - moofOffset + 8) err = w.rewriteBox(trunOffsets[i], truns[i]) if err != nil { return nil, err } } return w.bytes(), nil }