From b103aca24d35b72b405c340357dc01a0ed534281 Mon Sep 17 00:00:00 2001 From: bashonly <88596187+bashonly@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:19:45 +0000 Subject: [PATCH] [utils] Fix and improve `find_element` and `find_elements` (#11443) Fix d710a6ca7c622705c0c8c8a3615916f531137d5d Authored by: bashonly, Grub4K Co-authored-by: Simon Sawicki --- test/test_traversal.py | 54 +++++++++++++++++++++++++++++++++++++++ yt_dlp/utils/traversal.py | 23 +++++++++-------- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/test/test_traversal.py b/test/test_traversal.py index 1c0cc5362c..cc0228d270 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -13,6 +13,8 @@ str_or_none, ) from yt_dlp.utils.traversal import ( + find_element, + find_elements, require, subs_list_to_dict, traverse_obj, @@ -37,6 +39,14 @@ 'dict': {}, } +_TEST_HTML = ''' +
1
+
2
+
3
+

4

+

5

+''' + class TestTraversal: def test_traversal_base(self): @@ -521,6 +531,50 @@ def test_unpack(self): with pytest.raises(TypeError): unpack() + def test_find_element(self): + for improper_kwargs in [ + dict(attr='data-id'), + dict(value='y'), + dict(attr='data-id', value='y', cls='a'), + dict(attr='data-id', value='y', id='x'), + dict(cls='a', id='x'), + dict(cls='a', tag='p'), + dict(cls='[ab]', regex=True), + ]: + with pytest.raises(AssertionError): + find_element(**improper_kwargs)(_TEST_HTML) + + assert find_element(cls='a')(_TEST_HTML) == '1' + assert find_element(cls='a', html=True)(_TEST_HTML) == '
1
' + assert find_element(id='x')(_TEST_HTML) == '2' + assert find_element(id='[ex]')(_TEST_HTML) is None + assert find_element(id='[ex]', regex=True)(_TEST_HTML) == '2' + assert find_element(id='x', html=True)(_TEST_HTML) == '
2
' + assert find_element(attr='data-id', value='y')(_TEST_HTML) == '3' + assert find_element(attr='data-id', value='y(?:es)?')(_TEST_HTML) is None + assert find_element(attr='data-id', value='y(?:es)?', regex=True)(_TEST_HTML) == '3' + assert find_element( + attr='data-id', value='y', html=True)(_TEST_HTML) == '
3
' + + def test_find_elements(self): + for improper_kwargs in [ + dict(tag='p'), + dict(attr='data-id'), + dict(value='y'), + dict(attr='data-id', value='y', cls='a'), + dict(cls='a', tag='div'), + dict(cls='[ab]', regex=True), + ]: + with pytest.raises(AssertionError): + find_elements(**improper_kwargs)(_TEST_HTML) + + assert find_elements(cls='a')(_TEST_HTML) == ['1', '2', '4'] + assert find_elements(cls='a', html=True)(_TEST_HTML) == [ + '
1
', '
2
', '

4

'] + assert find_elements(attr='custom', value='z')(_TEST_HTML) == ['2', '3'] + assert find_elements(attr='custom', value='[ez]')(_TEST_HTML) == [] + assert find_elements(attr='custom', value='[ez]', regex=True)(_TEST_HTML) == ['2', '3', '5'] + class TestDictGet: def test_dict_get(self): diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index bc313d5c42..361f239ba6 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -20,6 +20,7 @@ get_elements_html_by_class, get_elements_html_by_attribute, get_elements_by_attribute, + get_element_by_class, get_element_html_by_attribute, get_element_by_attribute, get_element_html_by_id, @@ -373,7 +374,7 @@ def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None): @typing.overload -def find_element(*, attr: str, value: str, tag: str | None = None, html=False): ... +def find_element(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ... @typing.overload @@ -381,14 +382,14 @@ def find_element(*, cls: str, html=False): ... @typing.overload -def find_element(*, id: str, tag: str | None = None, html=False): ... +def find_element(*, id: str, tag: str | None = None, html=False, regex=False): ... @typing.overload -def find_element(*, tag: str, html=False): ... +def find_element(*, tag: str, html=False, regex=False): ... -def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False): +def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False, regex=False): # deliberately using `id=` and `cls=` for ease of readability assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required' ANY_TAG = r'[\w:.-]+' @@ -397,17 +398,18 @@ def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=Fal assert not cls, 'Cannot match both attr and cls' assert not id, 'Cannot match both attr and id' func = get_element_html_by_attribute if html else get_element_by_attribute - return functools.partial(func, attr, value, tag=tag or ANY_TAG) + return functools.partial(func, attr, value, tag=tag or ANY_TAG, escape_value=not regex) elif cls: assert not id, 'Cannot match both cls and id' assert tag is None, 'Cannot match both cls and tag' - func = get_element_html_by_class if html else get_elements_by_class + assert not regex, 'Cannot use regex with cls' + func = get_element_html_by_class if html else get_element_by_class return functools.partial(func, cls) elif id: func = get_element_html_by_id if html else get_element_by_id - return functools.partial(func, id, tag=tag or ANY_TAG) + return functools.partial(func, id, tag=tag or ANY_TAG, escape_value=not regex) index = int(bool(html)) return lambda html: get_element_text_and_html_by_tag(tag, html)[index] @@ -418,19 +420,20 @@ def find_elements(*, cls: str, html=False): ... @typing.overload -def find_elements(*, attr: str, value: str, tag: str | None = None, html=False): ... +def find_elements(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ... -def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False): +def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False, regex=False): # deliberately using `cls=` for ease of readability assert cls or (attr and value), 'One of cls or (attr AND value) is required' if attr and value: assert not cls, 'Cannot match both attr and cls' func = get_elements_html_by_attribute if html else get_elements_by_attribute - return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+') + return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+', escape_value=not regex) assert not tag, 'Cannot match both cls and tag' + assert not regex, 'Cannot use regex with cls' func = get_elements_html_by_class if html else get_elements_by_class return functools.partial(func, cls)