[fragment,aria2c] Generalize and refactor some code

This commit is contained in:
pukkandan 2021-09-22 05:27:07 +05:30
parent bd6f722de8
commit 1009f67c2a
No known key found for this signature in database
GPG key ID: 0F00D95A001F4698
2 changed files with 30 additions and 42 deletions

View file

@ -6,7 +6,7 @@
import sys import sys
import time import time
from .common import FileDownloader from .fragment import FragmentFD
from ..aes import aes_cbc_decrypt_bytes from ..aes import aes_cbc_decrypt_bytes
from ..compat import ( from ..compat import (
compat_setenv, compat_setenv,
@ -30,7 +30,7 @@
) )
class ExternalFD(FileDownloader): class ExternalFD(FragmentFD):
SUPPORTED_PROTOCOLS = ('http', 'https', 'ftp', 'ftps') SUPPORTED_PROTOCOLS = ('http', 'https', 'ftp', 'ftps')
can_download_to_stdout = False can_download_to_stdout = False
@ -142,6 +142,7 @@ def _call_downloader(self, tmpfilename, info_dict):
self.report_error('Giving up after %s fragment retries' % fragment_retries) self.report_error('Giving up after %s fragment retries' % fragment_retries)
return -1 return -1
decrypt_fragment = self.decrypter(info_dict)
dest, _ = sanitize_open(tmpfilename, 'wb') dest, _ = sanitize_open(tmpfilename, 'wb')
for frag_index, fragment in enumerate(info_dict['fragments']): for frag_index, fragment in enumerate(info_dict['fragments']):
fragment_filename = '%s-Frag%d' % (tmpfilename, frag_index) fragment_filename = '%s-Frag%d' % (tmpfilename, frag_index)
@ -153,21 +154,7 @@ def _call_downloader(self, tmpfilename, info_dict):
continue continue
self.report_error('Unable to open fragment %d' % frag_index) self.report_error('Unable to open fragment %d' % frag_index)
return -1 return -1
decrypt_info = fragment.get('decrypt_info') dest.write(decrypt_fragment(fragment, src.read()))
if decrypt_info:
if decrypt_info['METHOD'] == 'AES-128':
iv = decrypt_info.get('IV') or compat_struct_pack('>8xq', fragment['media_sequence'])
decrypt_info['KEY'] = decrypt_info.get('KEY') or self.ydl.urlopen(
self._prepare_url(info_dict, info_dict.get('_decryption_key_url') or decrypt_info['URI'])).read()
encrypted_data = src.read()
decrypted_data = aes_cbc_decrypt_bytes(encrypted_data, decrypt_info['KEY'], iv)
dest.write(decrypted_data)
else:
fragment_data = src.read()
dest.write(fragment_data)
else:
fragment_data = src.read()
dest.write(fragment_data)
src.close() src.close()
if not self.params.get('keep_fragments', False): if not self.params.get('keep_fragments', False):
os.remove(encodeFilename(fragment_filename)) os.remove(encodeFilename(fragment_filename))
@ -181,10 +168,6 @@ def _call_downloader(self, tmpfilename, info_dict):
self.to_stderr(stderr.decode('utf-8', 'replace')) self.to_stderr(stderr.decode('utf-8', 'replace'))
return p.returncode return p.returncode
def _prepare_url(self, info_dict, url):
headers = info_dict.get('http_headers')
return sanitized_Request(url, None, headers) if headers else url
class CurlFD(ExternalFD): class CurlFD(ExternalFD):
AVAILABLE_OPT = '-V' AVAILABLE_OPT = '-V'
@ -518,7 +501,7 @@ class AVconvFD(FFmpegFD):
_BY_NAME = dict( _BY_NAME = dict(
(klass.get_basename(), klass) (klass.get_basename(), klass)
for name, klass in globals().items() for name, klass in globals().items()
if name.endswith('FD') and name != 'ExternalFD' if name.endswith('FD') and name not in ('ExternalFD', 'FragmentFD')
) )

View file

@ -324,6 +324,29 @@ def _prepare_external_frag_download(self, ctx):
'fragment_index': 0, 'fragment_index': 0,
}) })
def decrypter(self, info_dict):
_key_cache = {}
def _get_key(url):
if url not in _key_cache:
_key_cache[url] = self.ydl.urlopen(self._prepare_url(info_dict, url)).read()
return _key_cache[url]
def decrypt_fragment(fragment, frag_content):
decrypt_info = fragment.get('decrypt_info')
if not decrypt_info or decrypt_info['METHOD'] != 'AES-128':
return frag_content
iv = decrypt_info.get('IV') or compat_struct_pack('>8xq', fragment['media_sequence'])
decrypt_info['KEY'] = decrypt_info.get('KEY') or _get_key(info_dict.get('_decryption_key_url') or decrypt_info['URI'])
# Don't decrypt the content in tests since the data is explicitly truncated and it's not to a valid block
# size (see https://github.com/ytdl-org/youtube-dl/pull/27660). Tests only care that the correct data downloaded,
# not what it decrypts to.
if self.params.get('test', False):
return frag_content
return aes_cbc_decrypt_bytes(frag_content, decrypt_info['KEY'], iv)
return decrypt_fragment
def download_and_append_fragments(self, ctx, fragments, info_dict, *, pack_func=None, finish_func=None): def download_and_append_fragments(self, ctx, fragments, info_dict, *, pack_func=None, finish_func=None):
fragment_retries = self.params.get('fragment_retries', 0) fragment_retries = self.params.get('fragment_retries', 0)
is_fatal = (lambda idx: idx == 0) if self.params.get('skip_unavailable_fragments', True) else (lambda _: True) is_fatal = (lambda idx: idx == 0) if self.params.get('skip_unavailable_fragments', True) else (lambda _: True)
@ -369,26 +392,6 @@ def download_fragment(fragment, ctx):
return False, frag_index return False, frag_index
return frag_content, frag_index return frag_content, frag_index
_key_cache = {}
def _get_key(url):
if url not in _key_cache:
_key_cache[url] = self.ydl.urlopen(self._prepare_url(info_dict, url)).read()
return _key_cache[url]
def decrypt_fragment(fragment, frag_content):
decrypt_info = fragment.get('decrypt_info')
if not decrypt_info or decrypt_info['METHOD'] != 'AES-128':
return frag_content
iv = decrypt_info.get('IV') or compat_struct_pack('>8xq', fragment['media_sequence'])
decrypt_info['KEY'] = decrypt_info.get('KEY') or _get_key(info_dict.get('_decryption_key_url') or decrypt_info['URI'])
# Don't decrypt the content in tests since the data is explicitly truncated and it's not to a valid block
# size (see https://github.com/ytdl-org/youtube-dl/pull/27660). Tests only care that the correct data downloaded,
# not what it decrypts to.
if self.params.get('test', False):
return frag_content
return aes_cbc_decrypt_bytes(frag_content, decrypt_info['KEY'], iv)
def append_fragment(frag_content, frag_index, ctx): def append_fragment(frag_content, frag_index, ctx):
if not frag_content: if not frag_content:
if not is_fatal(frag_index - 1): if not is_fatal(frag_index - 1):
@ -402,6 +405,8 @@ def append_fragment(frag_content, frag_index, ctx):
self._append_fragment(ctx, pack_func(frag_content, frag_index)) self._append_fragment(ctx, pack_func(frag_content, frag_index))
return True return True
decrypt_fragment = self.decrypter(info_dict)
max_workers = self.params.get('concurrent_fragment_downloads', 1) max_workers = self.params.get('concurrent_fragment_downloads', 1)
if can_threaded_download and max_workers > 1: if can_threaded_download and max_workers > 1: