This commit is contained in:
coletdjnz 2024-05-03 05:09:35 +00:00 committed by GitHub
commit 39fe00ad0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 821 additions and 207 deletions

View File

@ -1,4 +1,3 @@
import functools
import inspect import inspect
import pytest import pytest
@ -10,7 +9,9 @@ from yt_dlp.utils._utils import _YDLLogger as FakeLogger
@pytest.fixture @pytest.fixture
def handler(request): def handler(request):
RH_KEY = request.param RH_KEY = getattr(request, 'param', None)
if not RH_KEY:
return
if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler): if inspect.isclass(RH_KEY) and issubclass(RH_KEY, RequestHandler):
handler = RH_KEY handler = RH_KEY
elif RH_KEY in _REQUEST_HANDLERS: elif RH_KEY in _REQUEST_HANDLERS:
@ -18,9 +19,46 @@ def handler(request):
else: else:
pytest.skip(f'{RH_KEY} request handler is not available') pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger) class HandlerWrapper(handler):
RH_KEY = handler.RH_KEY
def __init__(self, *args, **kwargs):
super().__init__(logger=FakeLogger, *args, **kwargs)
return HandlerWrapper
def validate_and_send(rh, req): @pytest.fixture(autouse=True)
rh.validate(req) def skip_handler(request, handler):
return rh.send(req) """usage: pytest.mark.skip_handler('my_handler', 'reason')"""
for marker in request.node.iter_markers('skip_handler'):
if marker.args[0] == handler.RH_KEY:
pytest.skip(marker.args[1] if len(marker.args) > 1 else '')
@pytest.fixture(autouse=True)
def skip_handler_if(request, handler):
"""usage: pytest.mark.skip_handler_if('my_handler', lambda request: True, 'reason')"""
for marker in request.node.iter_markers('skip_handler_if'):
if marker.args[0] == handler.RH_KEY and marker.args[1](request):
pytest.skip(marker.args[2] if len(marker.args) > 2 else '')
@pytest.fixture(autouse=True)
def skip_handlers_if(request, handler):
"""usage: pytest.mark.skip_handlers_if(lambda request, handler: True, 'reason')"""
for marker in request.node.iter_markers('skip_handlers_if'):
if handler and marker.args[0](request, handler):
pytest.skip(marker.args[1] if len(marker.args) > 1 else '')
def pytest_configure(config):
config.addinivalue_line(
"markers", "skip_handler(handler): skip test for the given handler",
)
config.addinivalue_line(
"markers", "skip_handler_if(handler): skip test for the given handler if condition is true"
)
config.addinivalue_line(
"markers", "skip_handlers_if(handler): skip test for handlers when the condition is true"
)

View File

@ -338,3 +338,8 @@ def http_server_port(httpd):
def verify_address_availability(address): def verify_address_availability(address):
if find_available_port(address) is None: if find_available_port(address) is None:
pytest.skip(f'Unable to bind to source address {address} (address may not exist)') pytest.skip(f'Unable to bind to source address {address} (address may not exist)')
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

430
test/test_http_proxy.py Normal file
View File

@ -0,0 +1,430 @@
import abc
import base64
import contextlib
import functools
import json
import os
import random
import ssl
import threading
from http.server import BaseHTTPRequestHandler
from socketserver import BaseRequestHandler, ThreadingTCPServer
import pytest
from test.helper import http_server_port, verify_address_availability
from test.test_networking import TEST_DIR
from test.test_socks import IPv6ThreadingTCPServer
from yt_dlp.dependencies import urllib3
from yt_dlp.networking import Request
from yt_dlp.networking.exceptions import HTTPError, ProxyError, SSLError
class HTTPProxyAuthMixin:
def proxy_auth_error(self):
self.send_response(407)
self.send_header('Proxy-Authenticate', 'Basic realm="test http proxy"')
self.end_headers()
return False
def do_proxy_auth(self, username, password):
if username is None and password is None:
return True
proxy_auth_header = self.headers.get('Proxy-Authorization', None)
if proxy_auth_header is None:
return self.proxy_auth_error()
if not proxy_auth_header.startswith('Basic '):
return self.proxy_auth_error()
auth = proxy_auth_header[6:]
try:
auth_username, auth_password = base64.b64decode(auth).decode().split(':', 1)
except Exception:
return self.proxy_auth_error()
if auth_username != (username or '') or auth_password != (password or ''):
return self.proxy_auth_error()
return True
class HTTPProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
def __init__(self, *args, proxy_info=None, username=None, password=None, request_handler=None, **kwargs):
self.username = username
self.password = password
self.proxy_info = proxy_info
super().__init__(*args, **kwargs)
def do_GET(self):
if not self.do_proxy_auth(self.username, self.password):
self.server.close_request(self.request)
return
if self.path.endswith('/proxy_info'):
payload = json.dumps(self.proxy_info or {
'client_address': self.client_address,
'connect': False,
'connect_host': None,
'connect_port': None,
'headers': dict(self.headers),
'path': self.path,
'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
})
self.send_response(200)
self.send_header('Content-Type', 'application/json; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload.encode())
else:
self.send_response(404)
self.end_headers()
self.server.close_request(self.request)
if urllib3:
import urllib3.util.ssltransport
class SSLTransport(urllib3.util.ssltransport.SSLTransport):
"""
Modified version of urllib3 SSLTransport to support server side SSL
This allows us to chain multiple TLS connections.
"""
def __init__(self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True, server_side=False):
self.incoming = ssl.MemoryBIO()
self.outgoing = ssl.MemoryBIO()
self.suppress_ragged_eofs = suppress_ragged_eofs
self.socket = socket
self.sslobj = ssl_context.wrap_bio(
self.incoming,
self.outgoing,
server_hostname=server_hostname,
server_side=server_side
)
self._ssl_io_loop(self.sslobj.do_handshake)
@property
def _io_refs(self):
return self.socket._io_refs
@_io_refs.setter
def _io_refs(self, value):
self.socket._io_refs = value
def shutdown(self, *args, **kwargs):
self.socket.shutdown(*args, **kwargs)
else:
SSLTransport = None
class HTTPSProxyHandler(HTTPProxyHandler):
def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
if isinstance(request, ssl.SSLSocket):
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
super().__init__(request, *args, **kwargs)
class WebSocketProxyHandler(BaseRequestHandler):
def __init__(self, *args, proxy_info=None, **kwargs):
self.proxy_info = proxy_info
super().__init__(*args, **kwargs)
def handle(self):
import websockets.sync.server
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
for message in connection:
if message == 'proxy_info':
connection.send(json.dumps(self.proxy_info))
connection.close()
class WebSocketSecureProxyHandler(WebSocketProxyHandler):
def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
if SSLTransport:
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
super().__init__(request, *args, **kwargs)
class HTTPConnectProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
protocol_version = 'HTTP/1.1'
default_request_version = 'HTTP/1.1'
def __init__(self, *args, username=None, password=None, request_handler=None, **kwargs):
self.username = username
self.password = password
self.request_handler = request_handler
super().__init__(*args, **kwargs)
def do_CONNECT(self):
if not self.do_proxy_auth(self.username, self.password):
self.server.close_request(self.request)
return
self.send_response(200)
self.end_headers()
proxy_info = {
'client_address': self.client_address,
'connect': True,
'connect_host': self.path.split(':')[0],
'connect_port': int(self.path.split(':')[1]),
'headers': dict(self.headers),
'path': self.path,
'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
}
self.request_handler(self.request, self.client_address, self.server, proxy_info=proxy_info)
self.server.close_request(self.request)
class HTTPSConnectProxyHandler(HTTPConnectProxyHandler):
def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
request = sslctx.wrap_socket(request, server_side=True)
self._original_request = request
super().__init__(request, *args, **kwargs)
def do_CONNECT(self):
super().do_CONNECT()
self.server.close_request(self._original_request)
@contextlib.contextmanager
def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_server_kwargs):
server = server_thread = None
try:
bind_address = bind_ip or '127.0.0.1'
server_type = ThreadingTCPServer if '.' in bind_address else IPv6ThreadingTCPServer
server = server_type(
(bind_address, 0), functools.partial(proxy_server_class, request_handler=request_handler, **proxy_server_kwargs))
server_port = http_server_port(server)
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
if '.' not in bind_address:
yield f'[{bind_address}]:{server_port}'
else:
yield f'{bind_address}:{server_port}'
finally:
server.shutdown()
server.server_close()
server_thread.join(2.0)
class HTTPProxyTestContext(abc.ABC):
REQUEST_HANDLER_CLASS = None
REQUEST_PROTO = None
def http_server(self, server_class, *args, **kwargs):
return proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
@abc.abstractmethod
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
"""return a dict of proxy_info"""
class HTTPProxyHTTPTestContext(HTTPProxyTestContext):
# Standard HTTP Proxy for http requests
REQUEST_HANDLER_CLASS = HTTPProxyHandler
REQUEST_PROTO = 'http'
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'http://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
handler.validate(request)
return json.loads(handler.send(request).read().decode())
class HTTPProxyHTTPSTestContext(HTTPProxyTestContext):
# HTTP Connect proxy, for https requests
REQUEST_HANDLER_CLASS = HTTPSProxyHandler
REQUEST_PROTO = 'https'
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'https://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
handler.validate(request)
return json.loads(handler.send(request).read().decode())
class HTTPProxyWebSocketTestContext(HTTPProxyTestContext):
REQUEST_HANDLER_CLASS = WebSocketProxyHandler
REQUEST_PROTO = 'ws'
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('proxy_info')
socks_info = ws.recv()
ws.close()
return json.loads(socks_info)
class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext):
REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler
REQUEST_PROTO = 'wss'
CTX_MAP = {
'http': HTTPProxyHTTPTestContext,
'https': HTTPProxyHTTPSTestContext,
'ws': HTTPProxyWebSocketTestContext,
'wss': HTTPProxyWebSocketSecureTestContext,
}
@pytest.fixture(scope='module')
def ctx(request):
return CTX_MAP[request.param]()
@pytest.mark.parametrize(
'handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
@pytest.mark.parametrize('ctx', ['http'], indirect=True) # pure http proxy can only support http
class TestHTTPProxy:
def test_http_no_auth(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler) as server_address:
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['connect'] is False
assert 'Proxy-Authorization' not in proxy_info['headers']
def test_http_auth(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
with handler(proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert 'Proxy-Authorization' in proxy_info['headers']
def test_http_bad_auth(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
with handler(proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
with pytest.raises(HTTPError) as exc_info:
ctx.proxy_info_request(rh)
assert exc_info.value.response.status == 407
exc_info.value.response.close()
def test_http_source_address(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address)
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
source_address=source_address) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['client_address'][0] == source_address
@pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
def test_https(self, handler, ctx):
with ctx.http_server(HTTPSProxyHandler) as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['connect'] is False
assert 'Proxy-Authorization' not in proxy_info['headers']
@pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
def test_https_verify_failed(self, handler, ctx):
with ctx.http_server(HTTPSProxyHandler) as server_address:
with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
# Accept SSLError as may not be feasible to tell if it is proxy or request error.
# note: if request proto also does ssl verification, this may also be the error of the request.
# Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
with pytest.raises((ProxyError, SSLError)):
ctx.proxy_info_request(rh)
def test_http_with_idn(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler) as server_address:
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh, target_domain='中文.tw')
assert proxy_info['proxy'] == server_address
assert proxy_info['path'].startswith('http://xn--fiq228c.tw')
assert proxy_info['headers']['Host'].split(':', 1)[0] == 'xn--fiq228c.tw'
@pytest.mark.parametrize(
'handler,ctx', [
('Requests', 'https'),
('CurlCFFI', 'https'),
('Websockets', 'ws'),
('Websockets', 'wss')
], indirect=True)
class TestHTTPConnectProxy:
def test_http_connect_no_auth(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['connect'] is True
assert 'Proxy-Authorization' not in proxy_info['headers']
def test_http_connect_auth(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert 'Proxy-Authorization' in proxy_info['headers']
@pytest.mark.skip_handler(
'Requests',
'bug in urllib3 causes unclosed socket: https://github.com/urllib3/urllib3/issues/3374'
)
def test_http_connect_bad_auth(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
with pytest.raises(ProxyError):
ctx.proxy_info_request(rh)
def test_http_connect_source_address(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address)
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
source_address=source_address,
verify=False) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['client_address'][0] == source_address
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
def test_https_connect_proxy(self, handler, ctx):
with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert proxy_info['connect'] is True
assert 'Proxy-Authorization' not in proxy_info['headers']
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
def test_https_connect_verify_failed(self, handler, ctx):
with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
# Accept SSLError as may not be feasible to tell if it is proxy or request error.
# note: if request proto also does ssl verification, this may also be the error of the request.
# Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
with pytest.raises((ProxyError, SSLError)):
ctx.proxy_info_request(rh)
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
def test_https_connect_proxy_auth(self, handler, ctx):
with ctx.http_server(HTTPSConnectProxyHandler, username='test', password='test') as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://test:test@{server_address}'}) as rh:
proxy_info = ctx.proxy_info_request(rh)
assert proxy_info['proxy'] == server_address
assert 'Proxy-Authorization' in proxy_info['headers']

View File

@ -6,6 +6,8 @@ import sys
import pytest import pytest
from yt_dlp.networking.common import Features
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import gzip import gzip
@ -27,8 +29,12 @@ import zlib
from email.message import Message from email.message import Message
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from test.conftest import validate_and_send from test.helper import (
from test.helper import FakeYDL, http_server_port, verify_address_availability FakeYDL,
http_server_port,
validate_and_send,
verify_address_availability,
)
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import brotli, curl_cffi, requests, urllib3 from yt_dlp.dependencies import brotli, curl_cffi, requests, urllib3
from yt_dlp.networking import ( from yt_dlp.networking import (
@ -62,21 +68,6 @@ from yt_dlp.utils.networking import HTTPHeaderDict, std_headers
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def _build_proxy_handler(name):
class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
proxy_name = name
def log_message(self, format, *args):
pass
def do_GET(self):
self.send_response(200)
self.send_header('Content-Type', 'text/plain; charset=utf-8')
self.end_headers()
self.wfile.write(f'{self.proxy_name}: {self.path}'.encode())
return HTTPTestRequestHandler
class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1' protocol_version = 'HTTP/1.1'
default_request_version = 'HTTP/1.1' default_request_version = 'HTTP/1.1'
@ -317,8 +308,9 @@ class TestRequestHandlerBase:
cls.https_server_thread.start() cls.https_server_thread.start()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
class TestHTTPRequestHandler(TestRequestHandlerBase): class TestHTTPRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_verify_cert(self, handler): def test_verify_cert(self, handler):
with handler() as rh: with handler() as rh:
with pytest.raises(CertificateVerifyError): with pytest.raises(CertificateVerifyError):
@ -329,7 +321,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert r.status == 200 assert r.status == 200
r.close() r.close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_ssl_error(self, handler): def test_ssl_error(self, handler):
# HTTPS server with too old TLS version # HTTPS server with too old TLS version
# XXX: is there a better way to test this than to create a new server? # XXX: is there a better way to test this than to create a new server?
@ -347,7 +338,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers')) validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
assert not issubclass(exc_info.type, CertificateVerifyError) assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_percent_encode(self, handler): def test_percent_encode(self, handler):
with handler() as rh: with handler() as rh:
# Unicode characters should be encoded with uppercase percent-encoding # Unicode characters should be encoded with uppercase percent-encoding
@ -359,7 +349,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.status == 200 assert res.status == 200
res.close() res.close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
@pytest.mark.parametrize('path', [ @pytest.mark.parametrize('path', [
'/a/b/./../../headers', '/a/b/./../../headers',
'/redirect_dotsegments', '/redirect_dotsegments',
@ -375,15 +364,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.url == f'http://127.0.0.1:{self.http_port}/headers' assert res.url == f'http://127.0.0.1:{self.http_port}/headers'
res.close() res.close()
# Not supported by CurlCFFI (non-standard) @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi (non-standard)')
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
def test_unicode_path_redirection(self, handler): def test_unicode_path_redirection(self, handler):
with handler() as rh: with handler() as rh:
r = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect')) r = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/302-non-ascii-redirect'))
assert r.url == f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html' assert r.url == f'http://127.0.0.1:{self.http_port}/%E4%B8%AD%E6%96%87.html'
r.close() r.close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_raise_http_error(self, handler): def test_raise_http_error(self, handler):
with handler() as rh: with handler() as rh:
for bad_status in (400, 500, 599, 302): for bad_status in (400, 500, 599, 302):
@ -393,7 +380,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
# Should not raise an error # Should not raise an error
validate_and_send(rh, Request('http://127.0.0.1:%d/gen_200' % self.http_port)).close() validate_and_send(rh, Request('http://127.0.0.1:%d/gen_200' % self.http_port)).close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_response_url(self, handler): def test_response_url(self, handler):
with handler() as rh: with handler() as rh:
# Response url should be that of the last url in redirect chain # Response url should be that of the last url in redirect chain
@ -405,7 +391,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
res2.close() res2.close()
# Covers some basic cases we expect some level of consistency between request handlers for # Covers some basic cases we expect some level of consistency between request handlers for
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
@pytest.mark.parametrize('redirect_status,method,expected', [ @pytest.mark.parametrize('redirect_status,method,expected', [
# A 303 must either use GET or HEAD for subsequent request # A 303 must either use GET or HEAD for subsequent request
(303, 'POST', ('', 'GET', False)), (303, 'POST', ('', 'GET', False)),
@ -447,7 +432,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert expected[1] == res.headers.get('method') assert expected[1] == res.headers.get('method')
assert expected[2] == ('content-length' in headers.decode().lower()) assert expected[2] == ('content-length' in headers.decode().lower())
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_request_cookie_header(self, handler): def test_request_cookie_header(self, handler):
# We should accept a Cookie header being passed as in normal headers and handle it appropriately. # We should accept a Cookie header being passed as in normal headers and handle it appropriately.
with handler() as rh: with handler() as rh:
@ -480,19 +464,16 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert b'cookie: test=ytdlp' not in data.lower() assert b'cookie: test=ytdlp' not in data.lower()
assert b'cookie: test=test3' in data.lower() assert b'cookie: test=test3' in data.lower()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_redirect_loop(self, handler): def test_redirect_loop(self, handler):
with handler() as rh: with handler() as rh:
with pytest.raises(HTTPError, match='redirect loop'): with pytest.raises(HTTPError, match='redirect loop'):
validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_loop')) validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/redirect_loop'))
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_incompleteread(self, handler): def test_incompleteread(self, handler):
with handler(timeout=2) as rh: with handler(timeout=2) as rh:
with pytest.raises(IncompleteRead, match='13 bytes read, 234221 more expected'): with pytest.raises(IncompleteRead, match='13 bytes read, 234221 more expected'):
validate_and_send(rh, Request('http://127.0.0.1:%d/incompleteread' % self.http_port)).read() validate_and_send(rh, Request('http://127.0.0.1:%d/incompleteread' % self.http_port)).read()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_cookies(self, handler): def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar() cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie( cookiejar.set_cookie(http.cookiejar.Cookie(
@ -509,7 +490,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read() rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read()
assert b'cookie: test=ytdlp' in data.lower() assert b'cookie: test=ytdlp' in data.lower()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_headers(self, handler): def test_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
@ -525,7 +505,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert b'test2: test2' not in data assert b'test2: test2' not in data
assert b'test3: test3' in data assert b'test3: test3' in data
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_read_timeout(self, handler): def test_read_timeout(self, handler):
with handler() as rh: with handler() as rh:
# Default timeout is 20 seconds, so this should go through # Default timeout is 20 seconds, so this should go through
@ -541,7 +520,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
validate_and_send( validate_and_send(
rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_1', extensions={'timeout': 4})) rh, Request(f'http://127.0.0.1:{self.http_port}/timeout_1', extensions={'timeout': 4}))
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_connect_timeout(self, handler): def test_connect_timeout(self, handler):
# nothing should be listening on this port # nothing should be listening on this port
connect_timeout_url = 'http://10.255.255.255' connect_timeout_url = 'http://10.255.255.255'
@ -560,7 +538,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(connect_timeout_url, extensions={'timeout': 0.01})) rh, Request(connect_timeout_url, extensions={'timeout': 0.01}))
assert 0.01 <= time.time() - now < 20 assert 0.01 <= time.time() - now < 20
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_source_address(self, handler): def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
# on some systems these loopback addresses we need for testing may not be available # on some systems these loopback addresses we need for testing may not be available
@ -572,13 +549,13 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert source_address == data assert source_address == data
# Not supported by CurlCFFI # Not supported by CurlCFFI
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True) @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
def test_gzip_trailing_garbage(self, handler): def test_gzip_trailing_garbage(self, handler):
with handler() as rh: with handler() as rh:
data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode() data = validate_and_send(rh, Request(f'http://localhost:{self.http_port}/trailing_garbage')).read().decode()
assert data == '<html><video src="/vid.mp4" /></html>' assert data == '<html><video src="/vid.mp4" /></html>'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True) @pytest.mark.skip_handler('CurlCFFI', 'not applicable to curl-cffi')
@pytest.mark.skipif(not brotli, reason='brotli support is not installed') @pytest.mark.skipif(not brotli, reason='brotli support is not installed')
def test_brotli(self, handler): def test_brotli(self, handler):
with handler() as rh: with handler() as rh:
@ -589,7 +566,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'br' assert res.headers.get('Content-Encoding') == 'br'
assert res.read() == b'<html><video src="/vid.mp4" /></html>' assert res.read() == b'<html><video src="/vid.mp4" /></html>'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_deflate(self, handler): def test_deflate(self, handler):
with handler() as rh: with handler() as rh:
res = validate_and_send( res = validate_and_send(
@ -599,7 +575,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'deflate' assert res.headers.get('Content-Encoding') == 'deflate'
assert res.read() == b'<html><video src="/vid.mp4" /></html>' assert res.read() == b'<html><video src="/vid.mp4" /></html>'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_gzip(self, handler): def test_gzip(self, handler):
with handler() as rh: with handler() as rh:
res = validate_and_send( res = validate_and_send(
@ -609,7 +584,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'gzip' assert res.headers.get('Content-Encoding') == 'gzip'
assert res.read() == b'<html><video src="/vid.mp4" /></html>' assert res.read() == b'<html><video src="/vid.mp4" /></html>'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_multiple_encodings(self, handler): def test_multiple_encodings(self, handler):
with handler() as rh: with handler() as rh:
for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'): for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
@ -620,8 +594,7 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == pair assert res.headers.get('Content-Encoding') == pair
assert res.read() == b'<html><video src="/vid.mp4" /></html>' assert res.read() == b'<html><video src="/vid.mp4" /></html>'
# Not supported by curl_cffi @pytest.mark.skip_handler('CurlCFFI', 'not supported by curl-cffi')
@pytest.mark.parametrize('handler', ['Urllib', 'Requests'], indirect=True)
def test_unsupported_encoding(self, handler): def test_unsupported_encoding(self, handler):
with handler() as rh: with handler() as rh:
res = validate_and_send( res = validate_and_send(
@ -631,7 +604,6 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.headers.get('Content-Encoding') == 'unsupported' assert res.headers.get('Content-Encoding') == 'unsupported'
assert res.read() == b'raw' assert res.read() == b'raw'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_read(self, handler): def test_read(self, handler):
with handler() as rh: with handler() as rh:
res = validate_and_send( res = validate_and_send(
@ -642,83 +614,48 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
assert res.read().decode().endswith('\n\n') assert res.read().decode().endswith('\n\n')
assert res.read() == b'' assert res.read() == b''
def test_request_disable_proxy(self, handler):
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
# Given the handler is configured with a proxy
with handler(proxies={'http': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
# When a proxy is explicitly set to None for the request
res = validate_and_send(
rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'http': None}))
# Then no proxy should be used
res.close()
assert res.status == 200
class TestHTTPProxy(TestRequestHandlerBase): @pytest.mark.skip_handlers_if(
# Note: this only tests http urls over non-CONNECT proxy lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
@classmethod
def setup_class(cls):
super().setup_class()
# HTTP Proxy server
cls.proxy = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), _build_proxy_handler('normal'))
cls.proxy_port = http_server_port(cls.proxy)
cls.proxy_thread = threading.Thread(target=cls.proxy.serve_forever)
cls.proxy_thread.daemon = True
cls.proxy_thread.start()
# Geo proxy server
cls.geo_proxy = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), _build_proxy_handler('geo'))
cls.geo_port = http_server_port(cls.geo_proxy)
cls.geo_proxy_thread = threading.Thread(target=cls.geo_proxy.serve_forever)
cls.geo_proxy_thread.daemon = True
cls.geo_proxy_thread.start()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_http_proxy(self, handler):
http_proxy = f'http://127.0.0.1:{self.proxy_port}'
geo_proxy = f'http://127.0.0.1:{self.geo_port}'
# Test global http proxy
# Test per request http proxy
# Test per request http proxy disables proxy
url = 'http://foo.com/bar'
# Global HTTP proxy
with handler(proxies={'http': http_proxy}) as rh:
res = validate_and_send(rh, Request(url)).read().decode()
assert res == f'normal: {url}'
# Per request proxy overrides global
res = validate_and_send(rh, Request(url, proxies={'http': geo_proxy})).read().decode()
assert res == f'geo: {url}'
# and setting to None disables all proxies for that request
real_url = f'http://127.0.0.1:{self.http_port}/headers'
res = validate_and_send(
rh, Request(real_url, proxies={'http': None})).read().decode()
assert res != f'normal: {real_url}'
assert 'Accept' in res
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_noproxy(self, handler): def test_noproxy(self, handler):
with handler(proxies={'proxy': f'http://127.0.0.1:{self.proxy_port}'}) as rh: for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['http']:
# NO_PROXY # Given the handler is configured with a proxy
for no_proxy in (f'127.0.0.1:{self.http_port}', '127.0.0.1', 'localhost'): with handler(proxies={'http': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
nop_response = validate_and_send( for no_proxy in (f'127.0.0.1:{self.http_port}', '127.0.0.1', 'localhost'):
rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'no': no_proxy})).read().decode( # When request no proxy includes the request url host
'utf-8') nop_response = validate_and_send(
assert 'Accept' in nop_response rh, Request(f'http://127.0.0.1:{self.http_port}/headers', proxies={'no': no_proxy}))
# Then the proxy should not be used
assert nop_response.status == 200
nop_response.close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True) @pytest.mark.skip_handlers_if(
lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
def test_allproxy(self, handler): def test_allproxy(self, handler):
url = 'http://foo.com/bar' # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
with handler() as rh: # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
response = validate_and_send(rh, Request(url, proxies={'all': f'http://127.0.0.1:{self.proxy_port}'})).read().decode( with handler(proxies={'all': 'http://10.255.255.255'}, timeout=0.1) as rh:
'utf-8') with pytest.raises(TransportError):
assert response == f'normal: {url}' validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).close()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True) with handler(timeout=0.1) as rh:
def test_http_proxy_with_idn(self, handler): with pytest.raises(TransportError):
with handler(proxies={ validate_and_send(
'http': f'http://127.0.0.1:{self.proxy_port}', rh, Request(
}) as rh: f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close()
url = 'http://中文.tw/'
response = rh.send(Request(url)).read().decode()
# b'xn--fiq228c' is '中文'.encode('idna')
assert response == 'normal: http://xn--fiq228c.tw/'
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
class TestClientCertificate: class TestClientCertificate:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
@ -745,27 +682,23 @@ class TestClientCertificate:
) as rh: ) as rh:
validate_and_send(rh, Request(f'https://127.0.0.1:{self.port}/video.html')).read().decode() validate_and_send(rh, Request(f'https://127.0.0.1:{self.port}/video.html')).read().decode()
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_combined_nopass(self, handler): def test_certificate_combined_nopass(self, handler):
self._run_test(handler, client_cert={ self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'clientwithkey.crt'), 'client_certificate': os.path.join(self.certdir, 'clientwithkey.crt'),
}) })
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_nocombined_nopass(self, handler): def test_certificate_nocombined_nopass(self, handler):
self._run_test(handler, client_cert={ self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'client.crt'), 'client_certificate': os.path.join(self.certdir, 'client.crt'),
'client_certificate_key': os.path.join(self.certdir, 'client.key'), 'client_certificate_key': os.path.join(self.certdir, 'client.key'),
}) })
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_combined_pass(self, handler): def test_certificate_combined_pass(self, handler):
self._run_test(handler, client_cert={ self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'clientwithencryptedkey.crt'), 'client_certificate': os.path.join(self.certdir, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar', 'client_certificate_password': 'foobar',
}) })
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_certificate_nocombined_pass(self, handler): def test_certificate_nocombined_pass(self, handler):
self._run_test(handler, client_cert={ self._run_test(handler, client_cert={
'client_certificate': os.path.join(self.certdir, 'client.crt'), 'client_certificate': os.path.join(self.certdir, 'client.crt'),
@ -805,8 +738,8 @@ class TestRequestHandlerMisc:
assert len(logging_handlers) == before_count assert len(logging_handlers) == before_count
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
class TestUrllibRequestHandler(TestRequestHandlerBase): class TestUrllibRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_file_urls(self, handler): def test_file_urls(self, handler):
# See https://github.com/ytdl-org/youtube-dl/issues/8227 # See https://github.com/ytdl-org/youtube-dl/issues/8227
tf = tempfile.NamedTemporaryFile(delete=False) tf = tempfile.NamedTemporaryFile(delete=False)
@ -828,7 +761,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
os.unlink(tf.name) os.unlink(tf.name)
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_http_error_returns_content(self, handler): def test_http_error_returns_content(self, handler):
# urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost # urllib HTTPError will try close the underlying response if reference to the HTTPError object is lost
def get_response(): def get_response():
@ -841,7 +773,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
assert get_response().read() == b'<html></html>' assert get_response().read() == b'<html></html>'
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_verify_cert_error_text(self, handler): def test_verify_cert_error_text(self, handler):
# Check the output of the error message # Check the output of the error message
with handler() as rh: with handler() as rh:
@ -851,7 +782,6 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
): ):
validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers')) validate_and_send(rh, Request(f'https://127.0.0.1:{self.https_port}/headers'))
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
@pytest.mark.parametrize('req,match,version_check', [ @pytest.mark.parametrize('req,match,version_check', [
# https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256 # https://github.com/python/cpython/blob/987b712b4aeeece336eed24fcc87a950a756c3e2/Lib/http/client.py#L1256
# bpo-39603: Check implemented in 3.7.9+, 3.8.5+ # bpo-39603: Check implemented in 3.7.9+, 3.8.5+
@ -1183,7 +1113,7 @@ class TestRequestHandlerValidation:
] ]
PROXY_SCHEME_TESTS = [ PROXY_SCHEME_TESTS = [
# scheme, expected to fail # proxy scheme, expected to fail
('Urllib', 'http', [ ('Urllib', 'http', [
('http', False), ('http', False),
('https', UnsupportedRequest), ('https', UnsupportedRequest),
@ -1209,30 +1139,41 @@ class TestRequestHandlerValidation:
('socks5', False), ('socks5', False),
('socks5h', False), ('socks5h', False),
]), ]),
('Websockets', 'ws', [
('http', False),
('https', False),
('socks4', False),
('socks4a', False),
('socks5', False),
('socks5h', False),
]),
(NoCheckRH, 'http', [('http', False)]), (NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
('Websockets', 'ws', [('http', UnsupportedRequest)]),
(NoCheckRH, 'http', [('http', False)]), (NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
] ]
PROXY_KEY_TESTS = [ PROXY_KEY_TESTS = [
# key, expected to fail # proxy key, proxy scheme, expected to fail
('Urllib', [ ('Urllib', 'http', [
('all', False), ('all', 'http', False),
('unrelated', False), ('unrelated', 'http', False),
]), ]),
('Requests', [ ('Requests', 'http', [
('all', False), ('all', 'http', False),
('unrelated', False), ('unrelated', 'http', False),
]), ]),
('CurlCFFI', [ ('CurlCFFI', 'http', [
('all', False), ('all', 'http', False),
('unrelated', False), ('unrelated', 'http', False),
]), ]),
(NoCheckRH, [('all', False)]), ('Websockets', 'ws', [
(HTTPSupportedRH, [('all', UnsupportedRequest)]), ('all', 'socks5', False),
(HTTPSupportedRH, [('no', UnsupportedRequest)]), ('unrelated', 'socks5', False),
]),
(NoCheckRH, 'http', [('all', 'http', False)]),
(HTTPSupportedRH, 'http', [('all', 'http', UnsupportedRequest)]),
(HTTPSupportedRH, 'http', [('no', 'http', UnsupportedRequest)]),
] ]
EXTENSION_TESTS = [ EXTENSION_TESTS = [
@ -1274,28 +1215,54 @@ class TestRequestHandlerValidation:
]), ]),
] ]
@pytest.mark.parametrize('handler,fail,scheme', [
('Urllib', False, 'http'),
('Requests', False, 'http'),
('CurlCFFI', False, 'http'),
('Websockets', False, 'ws')
], indirect=['handler'])
def test_no_proxy(self, handler, fail, scheme):
run_validation(handler, fail, Request(f'{scheme}://example.com', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request(f'{scheme}://example.com'), proxies={'no': '127.0.0.1,github.com'})
@pytest.mark.parametrize('handler,scheme', [
('Urllib', 'http'),
(HTTPSupportedRH, 'http'),
('Requests', 'http'),
('CurlCFFI', 'http'),
('Websockets', 'ws')
], indirect=['handler'])
def test_empty_proxy(self, handler, scheme):
run_validation(handler, False, Request(f'{scheme}://', proxies={scheme: None}))
run_validation(handler, False, Request(f'{scheme}://'), proxies={scheme: None})
@pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1', '/a/b/c'])
@pytest.mark.parametrize('handler,scheme', [
('Urllib', 'http'),
(HTTPSupportedRH, 'http'),
('Requests', 'http'),
('CurlCFFI', 'http'),
('Websockets', 'ws')
], indirect=['handler'])
def test_invalid_proxy_url(self, handler, scheme, proxy_url):
run_validation(handler, UnsupportedRequest, Request(f'{scheme}://', proxies={scheme: proxy_url}))
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
(handler_tests[0], scheme, fail, handler_kwargs) (handler_tests[0], scheme, fail, handler_kwargs)
for handler_tests in URL_SCHEME_TESTS for handler_tests in URL_SCHEME_TESTS
for scheme, fail, handler_kwargs in handler_tests[1] for scheme, fail, handler_kwargs in handler_tests[1]
], indirect=['handler']) ], indirect=['handler'])
def test_url_scheme(self, handler, scheme, fail, handler_kwargs): def test_url_scheme(self, handler, scheme, fail, handler_kwargs):
run_validation(handler, fail, Request(f'{scheme}://'), **(handler_kwargs or {})) run_validation(handler, fail, Request(f'{scheme}://'), **(handler_kwargs or {}))
@pytest.mark.parametrize('handler,fail', [('Urllib', False), ('Requests', False), ('CurlCFFI', False)], indirect=['handler']) @pytest.mark.parametrize('handler,scheme,proxy_key,proxy_scheme,fail', [
def test_no_proxy(self, handler, fail): (handler_tests[0], handler_tests[1], proxy_key, proxy_scheme, fail)
run_validation(handler, fail, Request('http://', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request('http://'), proxies={'no': '127.0.0.1,github.com'})
@pytest.mark.parametrize('handler,proxy_key,fail', [
(handler_tests[0], proxy_key, fail)
for handler_tests in PROXY_KEY_TESTS for handler_tests in PROXY_KEY_TESTS
for proxy_key, fail in handler_tests[1] for proxy_key, proxy_scheme, fail in handler_tests[2]
], indirect=['handler']) ], indirect=['handler'])
def test_proxy_key(self, handler, proxy_key, fail): def test_proxy_key(self, handler, scheme, proxy_key, proxy_scheme, fail):
run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'})) run_validation(handler, fail, Request(f'{scheme}://', proxies={proxy_key: f'{proxy_scheme}://example.com'}))
run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'}) run_validation(handler, fail, Request(f'{scheme}://'), proxies={proxy_key: f'{proxy_scheme}://example.com'})
@pytest.mark.parametrize('handler,req_scheme,scheme,fail', [ @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
(handler_tests[0], handler_tests[1], scheme, fail) (handler_tests[0], handler_tests[1], scheme, fail)
@ -1306,16 +1273,6 @@ class TestRequestHandlerValidation:
run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'})) run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'}) run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
@pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests', 'CurlCFFI'], indirect=True)
def test_empty_proxy(self, handler):
run_validation(handler, False, Request('http://', proxies={'http': None}))
run_validation(handler, False, Request('http://'), proxies={'http': None})
@pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1', '/a/b/c'])
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
def test_invalid_proxy_url(self, handler, proxy_url):
run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
@pytest.mark.parametrize('handler,scheme,extensions,fail', [ @pytest.mark.parametrize('handler,scheme,extensions,fail', [
(handler_tests[0], handler_tests[1], extensions, fail) (handler_tests[0], handler_tests[1], extensions, fail)
for handler_tests in EXTENSION_TESTS for handler_tests in EXTENSION_TESTS

View File

@ -216,7 +216,9 @@ class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
protocol = websockets.ServerProtocol() protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake() connection.handshake()
connection.send(json.dumps(self.socks_info)) for message in connection:
if message == 'socks_info':
connection.send(json.dumps(self.socks_info))
connection.close() connection.close()

View File

@ -7,6 +7,7 @@ import sys
import pytest import pytest
from test.helper import verify_address_availability from test.helper import verify_address_availability
from yt_dlp.networking.common import Features
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -18,7 +19,7 @@ import random
import ssl import ssl
import threading import threading
from yt_dlp import socks from yt_dlp import socks, traverse_obj
from yt_dlp.cookies import YoutubeDLCookieJar from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request from yt_dlp.networking import Request
@ -114,6 +115,7 @@ def ws_validate_and_send(rh, req):
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance: class TestWebsSocketRequestHandlerConformance:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
@ -129,7 +131,6 @@ class TestWebsSocketRequestHandlerConformance:
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server() cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}' cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler): def test_basic_websockets(self, handler):
with handler() as rh: with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
@ -141,7 +142,6 @@ class TestWebsSocketRequestHandlerConformance:
# https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
@pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)]) @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_types(self, handler, msg, opcode): def test_send_types(self, handler, msg, opcode):
with handler() as rh: with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
@ -149,7 +149,6 @@ class TestWebsSocketRequestHandlerConformance:
assert int(ws.recv()) == opcode assert int(ws.recv()) == opcode
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler): def test_verify_cert(self, handler):
with handler() as rh: with handler() as rh:
with pytest.raises(CertificateVerifyError): with pytest.raises(CertificateVerifyError):
@ -160,14 +159,12 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.status == 101 assert ws.status == 101
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_ssl_error(self, handler): def test_ssl_error(self, handler):
with handler(verify=False) as rh: with handler(verify=False) as rh:
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info: with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
ws_validate_and_send(rh, Request(self.bad_wss_host)) ws_validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError) assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('path,expected', [ @pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding # Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'), ('/中文', '/%E4%B8%AD%E6%96%87'),
@ -182,7 +179,6 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.status == 101 assert ws.status == 101
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_remove_dot_segments(self, handler): def test_remove_dot_segments(self, handler):
with handler() as rh: with handler() as rh:
# This isn't a comprehensive test, # This isn't a comprehensive test,
@ -195,7 +191,6 @@ class TestWebsSocketRequestHandlerConformance:
# We are restricted to known HTTP status codes in http.HTTPStatus # We are restricted to known HTTP status codes in http.HTTPStatus
# Redirects are not supported for websockets # Redirects are not supported for websockets
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511)) @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
def test_raise_http_error(self, handler, status): def test_raise_http_error(self, handler, status):
with handler() as rh: with handler() as rh:
@ -203,7 +198,6 @@ class TestWebsSocketRequestHandlerConformance:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status assert exc_info.value.status == status
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [ @pytest.mark.parametrize('params,extensions', [
({'timeout': sys.float_info.min}, {}), ({'timeout': sys.float_info.min}, {}),
({}, {'timeout': sys.float_info.min}), ({}, {'timeout': sys.float_info.min}),
@ -213,7 +207,6 @@ class TestWebsSocketRequestHandlerConformance:
with pytest.raises(TransportError): with pytest.raises(TransportError):
ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler): def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar() cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie( cookiejar.set_cookie(http.cookiejar.Cookie(
@ -239,7 +232,6 @@ class TestWebsSocketRequestHandlerConformance:
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler): def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address) verify_address_availability(source_address)
@ -249,7 +241,6 @@ class TestWebsSocketRequestHandlerConformance:
assert source_address == ws.recv() assert source_address == ws.recv()
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_response_url(self, handler): def test_response_url(self, handler):
with handler() as rh: with handler() as rh:
url = f'{self.ws_base_url}/something' url = f'{self.ws_base_url}/something'
@ -257,7 +248,6 @@ class TestWebsSocketRequestHandlerConformance:
assert ws.url == url assert ws.url == url
ws.close() ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_request_headers(self, handler): def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers # Global Headers
@ -293,7 +283,6 @@ class TestWebsSocketRequestHandlerConformance:
'client_certificate_password': 'foobar', 'client_certificate_password': 'foobar',
} }
)) ))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert): def test_mtls(self, handler, client_cert):
with handler( with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem # Disable client-side validation of unacceptable self-signed testcert.pem
@ -303,6 +292,44 @@ class TestWebsSocketRequestHandlerConformance:
) as rh: ) as rh:
ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close() ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
def test_request_disable_proxy(self, handler):
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
# Given handler is configured with a proxy
with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
# When a proxy is explicitly set to None for the request
ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'http': None}))
# Then no proxy should be used
assert ws.status == 101
ws.close()
@pytest.mark.skip_handlers_if(
lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
def test_noproxy(self, handler):
for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
# Given the handler is configured with a proxy
with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
# When request no proxy includes the request url host
ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
# Then the proxy should not be used
assert ws.status == 101
ws.close()
@pytest.mark.skip_handlers_if(
lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
def test_allproxy(self, handler):
supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
# This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
# 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
with pytest.raises(TransportError):
ws_validate_and_send(rh, Request(self.ws_base_url)).close()
with handler(timeout=0.1) as rh:
with pytest.raises(TransportError):
ws_validate_and_send(
rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
def create_fake_ws_connection(raised): def create_fake_ws_connection(raised):
import websockets.sync.client import websockets.sync.client

View File

@ -4140,15 +4140,15 @@ class YoutubeDL:
'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue 'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue
if ( if (
'unsupported proxy type: "https"' in ue.msg.lower() 'unsupported proxy type: "https"' in ue.msg.lower()
and 'requests' not in self._request_director.handlers and 'Requests' not in self._request_director.handlers
and 'curl_cffi' not in self._request_director.handlers and 'CurlCFFI' not in self._request_director.handlers
): ):
raise RequestError( raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi') 'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi')
elif ( elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower()) re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
and 'websockets' not in self._request_director.handlers and 'Websockets' not in self._request_director.handlers
): ):
raise RequestError( raise RequestError(
'This request requires WebSocket support. ' 'This request requires WebSocket support. '

View File

@ -21,7 +21,7 @@ from .exceptions import (
TransportError, TransportError,
) )
from .impersonate import ImpersonateRequestHandler, ImpersonateTarget from .impersonate import ImpersonateRequestHandler, ImpersonateTarget
from ..dependencies import curl_cffi from ..dependencies import curl_cffi, certifi
from ..utils import int_or_none from ..utils import int_or_none
if curl_cffi is None: if curl_cffi is None:
@ -156,6 +156,13 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
# See: https://curl.se/libcurl/c/CURLOPT_HTTPPROXYTUNNEL.html # See: https://curl.se/libcurl/c/CURLOPT_HTTPPROXYTUNNEL.html
session.curl.setopt(CurlOpt.HTTPPROXYTUNNEL, 1) session.curl.setopt(CurlOpt.HTTPPROXYTUNNEL, 1)
# curl_cffi does not currently set these for proxies
session.curl.setopt(CurlOpt.PROXY_CAINFO, certifi.where())
if not self.verify:
session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYPEER, 0)
session.curl.setopt(CurlOpt.PROXY_SSL_VERIFYHOST, 0)
headers = self._get_impersonate_headers(request) headers = self._get_impersonate_headers(request)
if self._client_cert: if self._client_cert:
@ -203,7 +210,10 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
max_redirects_exceeded = True max_redirects_exceeded = True
curl_response = e.response curl_response = e.response
elif e.code == CurlECode.PROXY: elif (
e.code == CurlECode.PROXY
or (e.code == CurlECode.RECV_ERROR and 'Received HTTP code 407 from proxy after CONNECT' in str(e))
):
raise ProxyError(cause=e) from e raise ProxyError(cause=e) from e
else: else:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e

View File

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
import base64
import contextlib import contextlib
import io import io
import logging import logging
import ssl import ssl
import sys import sys
import urllib.parse
from http.client import HTTPConnection, HTTPResponse
from ._helper import ( from ._helper import (
create_connection, create_connection,
@ -20,12 +23,14 @@ from .exceptions import (
RequestError, RequestError,
SSLError, SSLError,
TransportError, TransportError,
UnsupportedRequest,
) )
from .websocket import WebSocketRequestHandler, WebSocketResponse from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools from ..compat import functools
from ..dependencies import websockets from ..dependencies import urllib3, websockets
from ..socks import ProxyError as SocksProxyError from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none from ..utils import int_or_none
from ..utils.networking import HTTPHeaderDict
if not websockets: if not websockets:
raise ImportError('websockets is not installed') raise ImportError('websockets is not installed')
@ -36,6 +41,11 @@ websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'
if websockets_version < (12, 0): if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported') raise ImportError('Only websockets>=12.0 is supported')
urllib3_supported = False
urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) if urllib3 else None
if urllib3_version and urllib3_version >= (1, 26, 17):
urllib3_supported = True
import websockets.sync.client import websockets.sync.client
from websockets.uri import parse_uri from websockets.uri import parse_uri
@ -98,7 +108,7 @@ class WebsocketsRH(WebSocketRequestHandler):
https://github.com/python-websockets/websockets https://github.com/python-websockets/websockets
""" """
_SUPPORTED_URL_SCHEMES = ('wss', 'ws') _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets' RH_NAME = 'websockets'
@ -108,12 +118,23 @@ class WebsocketsRH(WebSocketRequestHandler):
for name in ('websockets.client', 'websockets.server'): for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name) logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout) handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s'))
self.__logging_handlers[name] = handler self.__logging_handlers[name] = handler
logger.addHandler(handler) logger.addHandler(handler)
if self.verbose: if self.verbose:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
def _validate(self, request):
super()._validate(request)
proxy = select_proxy(request.url, self._get_proxies(request))
if (
proxy
and urllib.parse.urlparse(proxy).scheme.lower() == 'https'
and urllib.parse.urlparse(request.url).scheme.lower() == 'wss'
and not urllib3_supported
):
raise UnsupportedRequest('WSS over HTTPS proxies requires a supported version of urllib3')
def _check_extensions(self, extensions): def _check_extensions(self, extensions):
super()._check_extensions(extensions) super()._check_extensions(extensions)
extensions.pop('timeout', None) extensions.pop('timeout', None)
@ -125,6 +146,38 @@ class WebsocketsRH(WebSocketRequestHandler):
for name, handler in self.__logging_handlers.items(): for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler) logging.getLogger(name).removeHandler(handler)
def _make_sock(self, proxy, url, timeout):
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
parsed_url = parse_uri(url)
parsed_proxy_url = urllib.parse.urlparse(proxy)
if proxy:
if parsed_proxy_url.scheme.startswith('socks'):
socks_proxy_options = make_socks_proxy_opts(proxy)
return create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options),
**create_conn_kwargs
)
elif parsed_proxy_url.scheme in ('http', 'https'):
return create_http_connect_conn(
proxy_url=proxy,
url=url,
timeout=timeout,
ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None,
source_address=self.source_address,
username=parsed_proxy_url.username,
password=parsed_proxy_url.password,
)
return create_connection(
address=(parsed_url.host, parsed_url.port),
**create_conn_kwargs
)
def _send(self, request): def _send(self, request):
timeout = self._calculate_timeout(request) timeout = self._calculate_timeout(request)
headers = self._merge_headers(request.headers) headers = self._merge_headers(request.headers)
@ -134,33 +187,22 @@ class WebsocketsRH(WebSocketRequestHandler):
if cookie_header: if cookie_header:
headers['cookie'] = cookie_header headers['cookie'] = cookie_header
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
proxy = select_proxy(request.url, self._get_proxies(request)) proxy = select_proxy(request.url, self._get_proxies(request))
try:
if proxy: ssl_context = None
socks_proxy_options = make_socks_proxy_opts(proxy) if parse_uri(request.url).secure:
sock = create_connection( if WebsocketsSSLContext is not None:
address=(socks_proxy_options['addr'], socks_proxy_options['port']), ssl_context = WebsocketsSSLContext(self._make_sslcontext())
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs
)
else: else:
sock = create_connection( ssl_context = self._make_sslcontext()
address=(wsuri.host, wsuri.port), try:
**create_conn_kwargs
)
conn = websockets.sync.client.connect( conn = websockets.sync.client.connect(
sock=sock, sock=self._make_sock(proxy, request.url, timeout),
uri=request.url, uri=request.url,
additional_headers=headers, additional_headers=headers,
open_timeout=timeout, open_timeout=timeout,
user_agent_header=None, user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None, ssl_context=ssl_context,
close_timeout=0, # not ideal, but prevents yt-dlp hanging close_timeout=0, # not ideal, but prevents yt-dlp hanging
) )
return WebsocketsResponseAdapter(conn, url=request.url) return WebsocketsResponseAdapter(conn, url=request.url)
@ -185,3 +227,98 @@ class WebsocketsRH(WebSocketRequestHandler):
) from e ) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
class NoCloseHTTPResponse(HTTPResponse):
def begin(self):
super().begin()
# Revert the default behavior of closing the connection after reading the response
if not self._check_close() and not self.chunked and self.length is None:
self.will_close = False
if urllib3_supported:
from urllib3.util.ssltransport import SSLTransport
class WebsocketsSSLTransport(SSLTransport):
"""
Modified version of urllib3 SSLTransport to support additional operations used by websockets
"""
def setsockopt(self, *args, **kwargs):
self.socket.setsockopt(*args, **kwargs)
def shutdown(self, *args, **kwargs):
self.unwrap()
self.socket.shutdown(*args, **kwargs)
else:
WebsocketsSSLTransport = None
class WebsocketsSSLContext:
"""
Dummy SSL Context for websockets which returns a WebsocketsSSLTransport instance
for wrap socket when using TLS-in-TLS.
"""
def __init__(self, ssl_context: ssl.SSLContext):
self.ssl_context = ssl_context
def wrap_socket(self, sock, server_hostname=None):
if isinstance(sock, ssl.SSLSocket):
return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname)
return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname)
def create_http_connect_conn(
proxy_url,
url,
timeout=None,
ssl_context=None,
source_address=None,
username=None,
password=None,
):
proxy_headers = HTTPHeaderDict()
if username is not None or password is not None:
proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode(
f'{username or ""}:{password or ""}'.encode('utf-8')).decode('utf-8')
proxy_url_parsed = urllib.parse.urlparse(proxy_url)
request_url_parsed = parse_uri(url)
conn = HTTPConnection(proxy_url_parsed.hostname, port=proxy_url_parsed.port, timeout=timeout)
conn.response_class = NoCloseHTTPResponse
if hasattr(conn, '_create_connection'):
conn._create_connection = create_connection
if source_address is not None:
conn.source_address = (source_address, 0)
try:
conn.connect()
if ssl_context:
conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_url_parsed.hostname)
conn.request(
method='CONNECT',
url=f'{request_url_parsed.host}:{request_url_parsed.port}',
headers=proxy_headers)
response = conn.getresponse()
except OSError as e:
conn.close()
raise ProxyError('Unable to connect to proxy', cause=e) from e
if response.status == 200:
return conn.sock
elif response.status == 407:
conn.close()
raise ProxyError('Got HTTP Error 407 with CONNECT: Proxy Authentication Required')
else:
conn.close()
res_adapter = Response(
fp=io.BytesIO(b''),
url=proxy_url, headers=response.headers,
status=response.status,
reason=response.reason)
raise HTTPError(response=res_adapter)

View File

@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import urllib.parse
from .common import RequestHandler, Response from .common import RequestHandler, Response, register_preference
class WebSocketResponse(Response): class WebSocketResponse(Response):
@ -21,3 +22,10 @@ class WebSocketResponse(Response):
class WebSocketRequestHandler(RequestHandler, abc.ABC): class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass pass
@register_preference(WebSocketRequestHandler)
def websocket_preference(_, request):
if urllib.parse.urlparse(request.url).scheme in ('ws', 'wss'):
return 200
return 0