from __future__ import annotations import io import logging import ssl import sys from ._helper import ( create_connection, create_socks_proxy_socket, make_socks_proxy_opts, select_proxy, ) from .common import Features, Response, register_rh from .exceptions import ( CertificateVerifyError, HTTPError, ProxyError, RequestError, SSLError, TransportError, ) from .websocket import WebSocketRequestHandler, WebSocketResponse from ..compat import functools from ..dependencies import websockets from ..socks import ProxyError as SocksProxyError from ..utils import int_or_none if not websockets: raise ImportError('websockets is not installed') import websockets.version websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'))) if websockets_version < (12, 0): raise ImportError('Only websockets>=12.0 is supported') import websockets.sync.client from websockets.uri import parse_uri class WebsocketsResponseAdapter(WebSocketResponse): def __init__(self, wsw: websockets.sync.client.ClientConnection, url): super().__init__( fp=io.BytesIO(wsw.response.body or b''), url=url, headers=wsw.response.headers, status=wsw.response.status_code, reason=wsw.response.reason_phrase, ) self.wsw = wsw def close(self): self.wsw.close() super().close() def send(self, message): # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send try: return self.wsw.send(message) except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: raise TransportError(cause=e) from e except SocksProxyError as e: raise ProxyError(cause=e) from e except TypeError as e: raise RequestError(cause=e) from e def recv(self): # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv try: return self.wsw.recv() except SocksProxyError as e: raise ProxyError(cause=e) from e except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: raise TransportError(cause=e) from e @register_rh class WebsocketsRH(WebSocketRequestHandler): """ Websockets request handler https://websockets.readthedocs.io https://github.com/python-websockets/websockets """ _SUPPORTED_URL_SCHEMES = ('wss', 'ws') _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) RH_NAME = 'websockets' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__logging_handlers = {} for name in ('websockets.client', 'websockets.server'): logger = logging.getLogger(name) handler = logging.StreamHandler(stream=sys.stdout) handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) self.__logging_handlers[name] = handler logger.addHandler(handler) if self.verbose: logger.setLevel(logging.DEBUG) def _check_extensions(self, extensions): super()._check_extensions(extensions) extensions.pop('timeout', None) extensions.pop('cookiejar', None) def close(self): # Remove the logging handler that contains a reference to our logger # See: https://github.com/yt-dlp/yt-dlp/issues/8922 for name, handler in self.__logging_handlers.items(): logging.getLogger(name).removeHandler(handler) def _send(self, request): timeout = self._calculate_timeout(request) headers = self._merge_headers(request.headers) if 'cookie' not in headers: cookiejar = self._get_cookiejar(request) cookie_header = cookiejar.get_cookie_header(request.url) if cookie_header: headers['cookie'] = cookie_header wsuri = parse_uri(request.url) create_conn_kwargs = { 'source_address': (self.source_address, 0) if self.source_address else None, 'timeout': timeout } proxy = select_proxy(request.url, self._get_proxies(request)) try: if proxy: socks_proxy_options = make_socks_proxy_opts(proxy) sock = create_connection( address=(socks_proxy_options['addr'], socks_proxy_options['port']), _create_socket_func=functools.partial( create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), **create_conn_kwargs ) else: sock = create_connection( address=(wsuri.host, wsuri.port), **create_conn_kwargs ) conn = websockets.sync.client.connect( sock=sock, uri=request.url, additional_headers=headers, open_timeout=timeout, user_agent_header=None, ssl_context=self._make_sslcontext() if wsuri.secure else None, close_timeout=0, # not ideal, but prevents yt-dlp hanging ) return WebsocketsResponseAdapter(conn, url=request.url) # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html except SocksProxyError as e: raise ProxyError(cause=e) from e except websockets.exceptions.InvalidURI as e: raise RequestError(cause=e) from e except ssl.SSLCertVerificationError as e: raise CertificateVerifyError(cause=e) from e except ssl.SSLError as e: raise SSLError(cause=e) from e except websockets.exceptions.InvalidStatus as e: raise HTTPError( Response( fp=io.BytesIO(e.response.body), url=request.url, headers=e.response.headers, status=e.response.status_code, reason=e.response.reason_phrase), ) from e except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: raise TransportError(cause=e) from e