[utils] Multiple changes to base_n()

1. Renamed to encode_base_n()
2. Allow tables longer than 62 characters
3. Raise ValueError instead of AssertionError for invalid input data
4. Return the first character in the table instead of '0' for number 0
5. Add tests
This commit is contained in:
Yen Chi Hsuan 2016-02-27 03:19:50 +08:00
parent 5633b4d39d
commit 5eb6bdced4
2 changed files with 20 additions and 6 deletions

View file

@ -18,6 +18,7 @@
from youtube_dl.utils import ( from youtube_dl.utils import (
age_restricted, age_restricted,
args_to_str, args_to_str,
encode_base_n,
clean_html, clean_html,
DateRange, DateRange,
detect_exe_version, detect_exe_version,
@ -802,5 +803,16 @@ def test_ohdave_rsa_encrypt(self):
ohdave_rsa_encrypt(b'aa111222', e, N), ohdave_rsa_encrypt(b'aa111222', e, N),
'726664bd9a23fd0c70f9f1b84aab5e3905ce1e45a584e9cbcf9bcc7510338fc1986d6c599ff990d923aa43c51c0d9013cd572e13bc58f4ae48f2ed8c0b0ba881') '726664bd9a23fd0c70f9f1b84aab5e3905ce1e45a584e9cbcf9bcc7510338fc1986d6c599ff990d923aa43c51c0d9013cd572e13bc58f4ae48f2ed8c0b0ba881')
def test_encode_base_n(self):
self.assertEqual(encode_base_n(0, 30), '0')
self.assertEqual(encode_base_n(80, 30), '2k')
custom_table = '9876543210ZYXWVUTSRQPONMLKJIHGFEDCBA'
self.assertEqual(encode_base_n(0, 30, custom_table), '9')
self.assertEqual(encode_base_n(80, 30, custom_table), '7P')
self.assertRaises(ValueError, encode_base_n, 0, 70)
self.assertRaises(ValueError, encode_base_n, 0, 60, custom_table)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -2621,15 +2621,17 @@ def ohdave_rsa_encrypt(data, exponent, modulus):
return '%x' % encrypted return '%x' % encrypted
def base_n(num, n, table=None): def encode_base_n(num, n, table=None):
if num == 0:
return '0'
FULL_TABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' FULL_TABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
assert n <= len(FULL_TABLE)
if not table: if not table:
table = FULL_TABLE[:n] table = FULL_TABLE[:n]
if n > len(table):
raise ValueError('base %d exceeds table length %d' % (n, len(table)))
if num == 0:
return table[0]
ret = '' ret = ''
while num: while num:
ret = table[num % n] + ret ret = table[num % n] + ret
@ -2649,7 +2651,7 @@ def decode_packed_codes(code):
while count: while count:
count -= 1 count -= 1
base_n_count = base_n(count, base) base_n_count = encode_base_n(count, base)
symbol_table[base_n_count] = symbols[count] or base_n_count symbol_table[base_n_count] = symbols[count] or base_n_count
return re.sub( return re.sub(