Fixup frame timestamp in MP4 file without ffmpeg

This commit is contained in:
Lesmiscore 2022-06-11 16:08:51 +09:00
parent 56ba69e4c9
commit d4c52a28af
No known key found for this signature in database
GPG key ID: 0EC2B52CF86236FF
8 changed files with 395 additions and 20 deletions

51
test/test_mp4parser.py Normal file
View file

@ -0,0 +1,51 @@
#!/usr/bin/env python
# Allow direct execution
import os
import sys
import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import io
from yt_dlp.mp4_parser import (
parse_mp4_boxes,
write_mp4_boxes,
)
# Reference box sequence used by the tests below: it is handed to
# write_mp4_boxes and compared against the output of parse_mp4_boxes.
# ('type', bytes) entries are boxes; (None, 'type') entries close the
# box of that type opened earlier in the list.
TEST_SEQUENCE = [
    ('test', b'123456'),
    ('trak', b''),
    ('helo', b'abcdef'),
    ('1984', b'1q84'),
    ('moov', b''),
    ('keys', b'2022'),
    (None, 'moov'),
    ('topp', b'1991'),
    (None, 'trak'),
]
# on-file representation of the above sequence
TEST_BYTES = b'\x00\x00\x00\x0etest123456\x00\x00\x00Btrak\x00\x00\x00\x0eheloabcdef\x00\x00\x00\x0c19841q84\x00\x00\x00\x14moov\x00\x00\x00\x0ckeys2022\x00\x00\x00\x0ctopp1991'
class TestMP4Parser(unittest.TestCase):
    """ Checks parse_mp4_boxes / write_mp4_boxes against the reference fixtures """

    def test_write_sequence(self):
        # Serializing the reference sequence must give the reference bytes
        with io.BytesIO() as sink:
            write_mp4_boxes(sink, TEST_SEQUENCE)
            written = sink.getvalue()
        self.assertEqual(TEST_BYTES, written)

    def test_read_bytes(self):
        # Parsing the reference bytes must give the reference sequence
        with io.BytesIO(TEST_BYTES) as source:
            parsed = list(parse_mp4_boxes(source))
        self.assertListEqual(TEST_SEQUENCE, parsed)

    def test_mismatched_box_end(self):
        # Container boxes closed in the wrong order must be rejected
        crossed_sequence = [
            ('moov', b''),
            ('trak', b''),
            (None, 'moov'),
            (None, 'trak'),
        ]
        with io.BytesIO() as sink:
            with self.assertRaises(AssertionError):
                write_mp4_boxes(sink, crossed_sequence)

View file

@ -55,6 +55,7 @@
FFmpegMergerPP, FFmpegMergerPP,
FFmpegPostProcessor, FFmpegPostProcessor,
MoveFilesAfterDownloadPP, MoveFilesAfterDownloadPP,
MP4FixupTimestampPP,
get_postprocessor, get_postprocessor,
) )
from .update import detect_variant from .update import detect_variant
@ -3256,8 +3257,11 @@ def ffmpeg_fixup(cndn, msg, cls):
ffmpeg_fixup(info_dict.get('is_live') and downloader == 'DashSegmentsFD', ffmpeg_fixup(info_dict.get('is_live') and downloader == 'DashSegmentsFD',
'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP) 'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)
is_fmp4 = info_dict.get('protocol') == 'websocket_frag' and info_dict.get('container') == 'fmp4'
ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed timestamps detected', FFmpegFixupTimestampPP) ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed timestamps detected', FFmpegFixupTimestampPP)
ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed duration detected', FFmpegFixupDurationPP) ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed duration detected', FFmpegFixupDurationPP)
ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Malformed timestamps detected', MP4FixupTimestampPP)
ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)
fixup() fixup()
try: try:

View file

@ -33,7 +33,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
from .niconico import NiconicoDmcFD from .niconico import NiconicoDmcFD
from .rtmp import RtmpFD from .rtmp import RtmpFD
from .rtsp import RtspFD from .rtsp import RtspFD
from .websocket import WebSocketFragmentFD from .websocket import WebSocketFragmentFD, WebSocketToFileFD
from .youtube_live_chat import YoutubeLiveChatFD from .youtube_live_chat import YoutubeLiveChatFD
PROTOCOL_MAP = { PROTOCOL_MAP = {
@ -118,6 +118,9 @@ def _get_suitable_downloader(info_dict, protocol, params, default):
elif params.get('hls_prefer_native') is False: elif params.get('hls_prefer_native') is False:
return FFmpegFD return FFmpegFD
if protocol == 'websocket_frag' and info_dict.get('container') == 'fmp4' and external_downloader != 'ffmpeg':
return WebSocketToFileFD
return PROTOCOL_MAP.get(protocol, default) return PROTOCOL_MAP.get(protocol, default)

View file

@ -1,7 +1,6 @@
import contextlib import contextlib
import os
import signal
import threading import threading
import time
from .common import FileDownloader from .common import FileDownloader
from .external import FFmpegFD from .external import FFmpegFD
@ -9,14 +8,8 @@
from ..dependencies import websockets from ..dependencies import websockets
class FFmpegSinkFD(FileDownloader): class AsyncSinkFD(FileDownloader):
""" A sink to ffmpeg for downloading fragments in any form """ async def connect(self, stdin, info_dict):
def real_download(self, filename, info_dict):
info_copy = info_dict.copy()
info_copy['url'] = '-'
async def call_conn(proc, stdin):
try: try:
await self.real_connection(stdin, info_dict) await self.real_connection(stdin, info_dict)
except OSError: except OSError:
@ -25,7 +18,19 @@ async def call_conn(proc, stdin):
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
stdin.flush() stdin.flush()
stdin.close() stdin.close()
os.kill(os.getpid(), signal.SIGINT)
async def real_connection(self, sink, info_dict):
""" Override this in subclasses """
raise NotImplementedError('This method must be implemented by subclasses')
class FFmpegSinkFD(AsyncSinkFD):
""" A sink to ffmpeg for downloading fragments in any form """
def real_download(self, filename, info_dict):
info_copy = info_dict.copy()
info_copy['url'] = '-'
connect = self.connect
class FFmpegStdinFD(FFmpegFD): class FFmpegStdinFD(FFmpegFD):
@classmethod @classmethod
@ -33,17 +38,57 @@ def get_basename(cls):
return FFmpegFD.get_basename() return FFmpegFD.get_basename()
def on_process_started(self, proc, stdin): def on_process_started(self, proc, stdin):
thread = threading.Thread(target=asyncio.run, daemon=True, args=(call_conn(proc, stdin), )) thread = threading.Thread(target=asyncio.run, daemon=True, args=(connect(stdin, info_dict), ))
thread.start() thread.start()
return FFmpegStdinFD(self.ydl, self.params or {}).download(filename, info_copy) return FFmpegStdinFD(self.ydl, self.params or {}).download(filename, info_copy)
async def real_connection(self, sink, info_dict):
""" Override this in subclasses """ class FileSinkFD(AsyncSinkFD):
raise NotImplementedError('This method must be implemented by subclasses') """ A sink to a file for downloading fragments in any form """
def real_download(self, filename, info_dict):
tempname = self.temp_name(filename)
try:
with open(tempname, 'wb') as w:
started = time.time()
status = {
'filename': info_dict.get('_filename'),
'status': 'downloading',
'elapsed': 0,
'downloaded_bytes': 0,
}
self._hook_progress(status, info_dict)
thread = threading.Thread(target=asyncio.run, daemon=True, args=(self.connect(w, info_dict), ))
thread.start()
time_and_size, avg_len = [], 10
while thread.is_alive():
time.sleep(0.1)
downloaded, curr = w.tell(), time.time()
# taken from ffmpeg attachment
time_and_size.append((downloaded, curr))
time_and_size = time_and_size[-avg_len:]
if len(time_and_size) > 1:
last, early = time_and_size[0], time_and_size[-1]
average_speed = (early[0] - last[0]) / (early[1] - last[1])
else:
average_speed = None
status.update({
'downloaded_bytes': downloaded,
'speed': average_speed,
'elapsed': curr - started,
})
self._hook_progress(status, info_dict)
except KeyboardInterrupt:
pass
finally:
self.ydl.replace(tempname, filename)
return True
class WebSocketFragmentFD(FFmpegSinkFD): class _WebSocketFD(AsyncSinkFD):
async def real_connection(self, sink, info_dict): async def real_connection(self, sink, info_dict):
async with websockets.connect(info_dict['url'], extra_headers=info_dict.get('http_headers', {})) as ws: async with websockets.connect(info_dict['url'], extra_headers=info_dict.get('http_headers', {})) as ws:
while True: while True:
@ -51,3 +96,11 @@ async def real_connection(self, sink, info_dict):
if isinstance(recv, str): if isinstance(recv, str):
recv = recv.encode('utf8') recv = recv.encode('utf8')
sink.write(recv) sink.write(recv)
class WebSocketFragmentFD(_WebSocketFD, FFmpegSinkFD):
pass
class WebSocketToFileFD(_WebSocketFD, FileSinkFD):
pass

View file

@ -173,6 +173,7 @@ def find_dmu(x):
'source_preference': -10, 'source_preference': -10,
# TwitCasting simply sends moof atom directly over WS # TwitCasting simply sends moof atom directly over WS
'protocol': 'websocket_frag', 'protocol': 'websocket_frag',
'container': 'fmp4',
}) })
self._sort_formats(formats, ('source',)) self._sort_formats(formats, ('source',))

136
yt_dlp/mp4_parser.py Normal file
View file

@ -0,0 +1,136 @@
import struct
from typing import Tuple
from io import BytesIO, RawIOBase
class LengthLimiter(RawIOBase):
    """
    Read-only window over another stream that hands out at most `size` bytes.

    Every successful read is subtracted from `remaining`; once it reaches
    zero (or if the window was created with a non-positive size) all reads
    return b'' without touching the underlying stream.
    """

    def __init__(self, r: RawIOBase, size: int):
        super().__init__()
        self.r = r
        self.remaining = size

    def read(self, sz: int = None) -> bytes:
        """ Reads up to `sz` bytes (None or a negative value: everything left in the window) """
        # "<= 0" rather than "== 0": a window created with a negative size
        # must behave as empty, not as an unbounded read of the parent
        if self.remaining <= 0:
            return b''
        if sz is None or sz < 0:
            sz = self.remaining
        sz = min(sz, self.remaining)
        ret = self.r.read(sz)
        if ret:
            self.remaining -= len(ret)
        return ret

    def readall(self) -> bytes:
        """ Reads everything left in the window """
        # read() already updates `remaining`; subtracting here as well would
        # drive it negative and let later reads run past the window
        return self.read(self.remaining)

    def readable(self) -> bool:
        """ Tells whether any bytes are left in the window """
        return self.remaining > 0
def read_harder(r, size):
    """
    Reads from `r` until `size` bytes are collected.

    Gives up after 3 consecutive empty reads and returns whatever was
    collected so far (possibly fewer than `size` bytes, possibly b'').
    """
    collected = b''
    empty_reads = 0
    while empty_reads < 3 and len(collected) < size:
        chunk = r.read(size - len(collected))
        if chunk:
            collected += chunk
            # only *consecutive* empty reads count towards giving up
            empty_reads = 0
        else:
            empty_reads += 1
    return collected
def pack_be32(value: int) -> bytes:
    """ Encodes `value` as a 4-byte big-endian unsigned integer """
    be32 = struct.Struct('>I')
    return be32.pack(value)
def pack_be64(value: int) -> bytes:
    """ Encodes `value` as an 8-byte big-endian unsigned integer """
    # With an explicit byte order, 'L' is only 4 bytes wide; 'Q' is the
    # 8-byte unsigned format
    return struct.pack('>Q', value)
def unpack_be32(value: bytes) -> int:
    """ Decodes a 4-byte big-endian unsigned integer """
    (number,) = struct.unpack('>I', value)
    return number
def unpack_ver_flags(value: bytes) -> Tuple[int, int]:
    """ Splits a 4-byte field into (version, flags): top byte and lower 24 bits """
    (word,) = struct.unpack('>I', value)
    version = word >> 24
    flags = word & 0xFFFFFF
    return version, flags
def unpack_be64(value: bytes) -> int:
    """ Decodes an 8-byte big-endian unsigned integer """
    # With an explicit byte order, 'L' is only 4 bytes wide and rejects
    # 8-byte input; 'Q' is the 8-byte unsigned format
    return struct.unpack('>Q', value)[0]
# Box types treated as containers: parse_mp4_boxes descends into them instead
# of returning their payload, and write_mp4_boxes buffers their children until
# the matching end marker arrives.
# https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L13-L40
MP4_CONTAINER_BOXES = ('moov', 'trak', 'edts', 'mdia', 'minf', 'dinf', 'stbl', 'mvex', 'moof', 'traf', 'vttc', 'tref', 'iref', 'mfra', 'meco', 'hnti', 'hinf', 'strk', 'strd', 'sinf', 'rinf', 'schi', 'trgr', 'udta', 'iprp', 'ipco')
def parse_mp4_boxes(r: RawIOBase):
    """
    Walks an ISO BMFF stream (the format MP4 is built on) and yields its
    boxes one after another, flattening the box tree. Box payloads are
    returned as raw bytes and are never interpreted.

    Yielded items:
        ('atom', b'blablabla'): A non-container box together with its payload
        ('atom', b''): Start of a container box (if 'atom' is in MP4_CONTAINER_BOXES),
                       otherwise a box that really has no payload
        (None, 'atom'): End of the container box 'atom'

    Example:                    Path:
        ('test', b'123456')         /test
        ('box1', b'')               /box1             (start of container box)
        ('helo', b'abcdef')         /box1/helo
        ('1984', b'1q84')           /box1/1984
        ('http', b'')               /box1/http        (start of container box)
        ('keys', b'2022')           /box1/http/keys
        (None  , 'http')            /box1/http        (end of container box)
        ('topp', b'1991')           /box1/topp
        (None  , 'box1')            /box1             (end of container box)
    """
    while True:
        raw_size = read_harder(r, 4)
        if not raw_size:
            # nothing more to read at this level
            return
        raw_type = r.read(4)
        # box header: 32-bit big-endian size (counting the header itself)
        # followed by the 4-character type
        box_size = unpack_be32(raw_size)
        box_type = raw_type.decode()
        # the 8 header bytes have been consumed already
        payload_size = box_size - 8

        if box_type not in MP4_CONTAINER_BOXES:
            yield (box_type, read_harder(r, payload_size))
        else:
            yield (box_type, b'')
            # children are parsed from a window limited to this box's payload
            yield from parse_mp4_boxes(LengthLimiter(r, payload_size))
            yield (None, box_type)
def write_mp4_boxes(w: RawIOBase, box_iter):
    """
    Serializes a box sequence into ISO BMFF and writes it to `w`.
    `box_iter` must yield items in the form produced by parse_mp4_boxes.

    Children of a container box are collected in memory and written to the
    parent only when the matching (None, 'type') end marker arrives.
    Raises AssertionError if an end marker does not match the innermost
    open container.
    """
    # innermost open container is last; the bottom entry is the real output
    open_boxes = [(None, w)]

    for box_type, payload in box_iter:
        if box_type in MP4_CONTAINER_BOXES:
            # start of a container: buffer its children separately
            open_boxes.append((box_type, BytesIO()))
            continue
        if box_type is None:
            # end of a container: its buffered children become its payload
            assert open_boxes[-1][0] == payload
            box_type, buffered = open_boxes.pop()
            payload = buffered.getvalue()

        _, target = open_boxes[-1]
        # header: total size (payload + 8 header bytes) followed by the type
        target.write(pack_be32(len(payload) + 8))
        target.write(box_type.encode()[:4])
        target.write(payload)

View file

@ -30,6 +30,7 @@
) )
from .modify_chapters import ModifyChaptersPP from .modify_chapters import ModifyChaptersPP
from .movefilesafterdownload import MoveFilesAfterDownloadPP from .movefilesafterdownload import MoveFilesAfterDownloadPP
from .mp4direct import MP4FixupTimestampPP
from .sponskrub import SponSkrubPP from .sponskrub import SponSkrubPP
from .sponsorblock import SponsorBlockPP from .sponsorblock import SponsorBlockPP
from .xattrpp import XAttrMetadataPP from .xattrpp import XAttrMetadataPP

View file

@ -0,0 +1,126 @@
from .common import PostProcessor
from ..utils import prepend_extension
from ..mp4_parser import (
write_mp4_boxes,
parse_mp4_boxes,
pack_be32,
pack_be64,
unpack_ver_flags,
unpack_be32,
unpack_be64,
)
class MP4FixupTimestampPP(PostProcessor):
    """
    Fixes timestamps of a fragmented MP4 file without ffmpeg.

    The file is read twice: analyze_mp4 collects statistics from the
    'tfdt' and 'tfhd' boxes, then modify_mp4 writes a copy in which
    baseMediaDecodeTime is shifted down and outlying sample durations
    are zeroed.
    """

    @property
    def available(self):
        return True

    @staticmethod
    def _read_bmdt(content):
        """ Returns (version, baseMediaDecodeTime) read from a 'tfdt' box body """
        version, _ = unpack_ver_flags(content[0:4])
        # baseMediaDecodeTime always comes to the first
        if version == 0:
            return version, unpack_be32(content[4:8])
        return version, unpack_be64(content[4:12])

    @staticmethod
    def _sample_duration_offset(content):
        """
        Returns the offset of "sample duration" within a 'tfhd' box body,
        or None if the box does not contain it
        """
        _, flags = unpack_ver_flags(content[0:4])
        if not flags & 0x08:
            # this box does not contain "sample duration"
            return None
        # https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L203-L209
        # https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/parsing/tfhd.js
        offset = 8  # header + track id
        if flags & 0x01:
            offset += 8
        if flags & 0x02:
            offset += 4
        return offset

    def analyze_mp4(self, filepath):
        """
        returns (baseMediaDecodeTime offset, sample duration cutoff)

        Either value is float('inf') when it could not be determined:
        the offset when no non-zero baseMediaDecodeTime was seen, the cutoff
        when no sample duration was seen or no multiplier isolates fewer
        than 3 distinct large durations.
        """
        smallest_bmdt, known_sdur = float('inf'), set()
        with open(filepath, 'rb') as r:
            for btype, content in parse_mp4_boxes(r):
                if btype == 'tfdt':
                    _, bmdt = self._read_bmdt(content)
                    if bmdt == 0:
                        continue
                    smallest_bmdt = min(bmdt, smallest_bmdt)
                elif btype == 'tfhd':
                    sdur_start = self._sample_duration_offset(content)
                    if sdur_start is None:
                        continue
                    # the next 4 bytes are "sample duration"
                    known_sdur.add(unpack_be32(content[sdur_start:sdur_start + 4]))

        if not known_sdur:
            # max() of an empty set raises ValueError; report "unknown"
            # through the same infinity sentinel that run() checks for
            return smallest_bmdt, float('inf')

        maximum_sdur = max(known_sdur)
        for multiplier in (0.7, 0.8, 0.9, 0.95):
            sdur_cutoff = maximum_sdur * multiplier
            # known_sdur is a set, so this counts distinct durations
            if sum(1 for x in known_sdur if x > sdur_cutoff) < 3:
                break
        else:
            sdur_cutoff = float('inf')

        return smallest_bmdt, sdur_cutoff

    def modify_mp4(self, src, dst, bmdt_offset, sdur_cutoff):
        """
        Copies `src` to `dst`, subtracting `bmdt_offset` from every non-zero
        baseMediaDecodeTime (clamped at 0) and replacing every sample
        duration greater than `sdur_cutoff` with 0
        """
        with open(src, 'rb') as r, open(dst, 'wb') as w:
            def converter():
                for btype, content in parse_mp4_boxes(r):
                    if btype == 'tfdt':
                        version, bmdt = self._read_bmdt(content)
                        if bmdt != 0:
                            # calculate new baseMediaDecodeTime
                            bmdt = max(0, bmdt - bmdt_offset)
                            # pack everything again and insert as a new box
                            bmdt_b = pack_be32(bmdt) if version == 0 else pack_be64(bmdt)
                            # skip exactly as many old bytes as were decoded above
                            content = content[0:4] + bmdt_b + content[4 + len(bmdt_b):]
                    elif btype == 'tfhd':
                        sdur_start = self._sample_duration_offset(content)
                        if sdur_start is not None:
                            sample_dur = unpack_be32(content[sdur_start:sdur_start + 4])
                            if sample_dur > sdur_cutoff:
                                sample_dur = 0
                            content = content[:sdur_start] + pack_be32(sample_dur) + content[sdur_start + 4:]
                    yield (btype, content)

            write_mp4_boxes(w, converter())

    def run(self, information):
        """ Rewrites information['filepath'] in place via a temporary file """
        filename = information['filepath']
        temp_filename = prepend_extension(filename, 'temp')

        self.write_debug('Analyzing MP4')
        bmdt_offset, sdur_cutoff = self.analyze_mp4(filename)
        working = float('inf') not in (bmdt_offset, sdur_cutoff)
        # if any of them are Infinity, there's something wrong
        # baseMediaDecodeTime = to shift PTS
        # sample duration = to define duration in each segment
        self.write_debug(f'baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff}')
        if bmdt_offset == float('inf'):
            # safeguard
            bmdt_offset = 0
        self.modify_mp4(filename, temp_filename, bmdt_offset, sdur_cutoff)
        if working:
            self.to_screen('Duration of the file has been fixed')
        else:
            self.report_warning(f'Failed to fix duration of the file. (baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff})')

        self._downloader.replace(temp_filename, filename)

        return [], information