webshell/.env/lib/python3.12/site-packages/uvloop/_testbase.py

551 lines
15 KiB
Python

"""Test utilities. Don't use outside of the uvloop project."""
import asyncio
import asyncio.events
import collections
import contextlib
import gc
import logging
import os
import pprint
import re
import select
import socket
import ssl
import sys
import tempfile
import threading
import time
import unittest
import uvloop
class MockPattern(str):
    """A string that compares equal to any text its regex pattern matches.

    Matching uses ``re.search`` with ``re.S``, so the pattern may match
    anywhere in the other string and ``.`` also matches newlines.
    """

    def __eq__(self, other):
        match = re.search(str(self), other, re.S)
        return match is not None
class TestCaseDict(collections.UserDict):
    """Mapping used as a class namespace that forbids redefining a name.

    Assigning to a key that is already present raises RuntimeError that
    names both the class (*name*) and the duplicated key.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __setitem__(self, key, value):
        already_defined = key in self.data
        if already_defined:
            message = 'duplicate test {}.{}'.format(self.name, key)
            raise RuntimeError(message)
        super().__setitem__(key, value)
class BaseTestCaseMeta(type):
@classmethod
def __prepare__(mcls, name, bases):
return TestCaseDict(name)
def __new__(mcls, name, bases, dct):
for test_name in dct:
if not test_name.startswith('test_'):
continue
for base in bases:
if hasattr(base, test_name):
raise RuntimeError(
'duplicate test {}.{} (also defined in {} '
'parent class)'.format(
name, test_name, base.__name__))
return super().__new__(mcls, name, bases, dict(dct))
class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta):
    """Shared base class for event-loop test cases.

    Subclasses provide ``new_loop()`` and ``new_policy()``.

    * ``setUp`` installs a fresh loop and policy as the current ones.
    * ``tearDown`` closes the loop and fails the test if the loop's
      exception handler was invoked.
    * If the loop exposes a truthy ``_debug_cc`` attribute, ``tearDown``
      also checks the loop's debug counters for leaked handles.
    * The global asyncio loop and policy are always reset afterwards.
    """

    def new_loop(self):
        """Return a new event loop for the test (implemented by subclasses)."""
        raise NotImplementedError

    def new_policy(self):
        """Return a new event loop policy (implemented by subclasses)."""
        raise NotImplementedError

    def mock_pattern(self, str):
        """Return a MockPattern equal to any string matching regex *str*."""
        # NOTE: the parameter name shadows the builtin ``str``; it is kept
        # so existing keyword callers keep working.
        return MockPattern(str)

    async def wait_closed(self, obj):
        """Await ``obj.wait_closed()`` when *obj* is an asyncio StreamWriter.

        BrokenPipeError / ConnectionError raised while waiting are ignored;
        any other *obj* is a no-op.
        """
        if not isinstance(obj, asyncio.StreamWriter):
            return
        try:
            await obj.wait_closed()
        except (BrokenPipeError, ConnectionError):
            pass

    def is_asyncio_loop(self):
        """Return True if the test loop's class lives in an ``asyncio.*`` module."""
        return type(self.loop).__module__.startswith('asyncio.')

    def run_loop_briefly(self, *, delay=0.01):
        """Run the test loop for *delay* seconds."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* for tearDown, then delegate to the default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        # Create the bookkeeping first, so it exists before the handler
        # that appends to it is installed.
        self.__unhandled_exceptions = []
        self._check_unclosed_resources_in_debug = True
        self.loop = self.new_loop()
        asyncio.set_event_loop_policy(self.new_policy())
        asyncio.set_event_loop(self.loop)
        self.loop.set_exception_handler(self.loop_exception_handler)

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

            if not self._check_unclosed_resources_in_debug:
                return

            # GC to show any resource warnings as the test completes
            gc.collect()
            gc.collect()
            gc.collect()

            if getattr(self.loop, '_debug_cc', False):
                self._assert_debug_resources_freed()
        finally:
            # Always restore the global asyncio state, even on failure or
            # early return, so the next test does not inherit this
            # loop/policy.
            asyncio.set_event_loop(None)
            asyncio.set_event_loop_policy(None)
            self.loop = None

    def _assert_debug_resources_freed(self):
        """Assert that the loop's ``_debug_*`` counters report no leaks."""
        # Extra GC passes before reading the counters, presumably so that
        # cycle-held objects are finalised first.
        gc.collect()
        gc.collect()
        gc.collect()

        self.assertEqual(
            self.loop._debug_uv_handles_total,
            self.loop._debug_uv_handles_freed,
            'not all uv_handle_t handles were freed')

        self.assertEqual(
            self.loop._debug_cb_handles_count, 0,
            'not all callbacks (call_soon) are GCed')

        self.assertEqual(
            self.loop._debug_cb_timer_handles_count, 0,
            'not all timer callbacks (call_later) are GCed')

        self.assertEqual(
            self.loop._debug_stream_write_ctx_cnt, 0,
            'not all stream write contexts are GCed')

        for h_name, h_cnt in self.loop._debug_handles_current.items():
            with self.subTest('Alive handle after test',
                              handle_name=h_name):
                self.assertEqual(
                    h_cnt, 0,
                    'alive {} after test'.format(h_name))

        for h_name, h_cnt in self.loop._debug_handles_total.items():
            with self.subTest('Total/closed handles',
                              handle_name=h_name):
                self.assertEqual(
                    h_cnt, self.loop._debug_handles_closed[h_name],
                    'total != closed for {}'.format(h_name))

    def skip_unclosed_handles_check(self):
        """Disable the GC / debug-counter checks in tearDown for this test."""
        self._check_unclosed_resources_in_debug = False

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Bind a listening socket and wrap it in a TestThreadedServer.

        If *addr* is None, the socket is bound as follows:

        * AF_UNIX: to a fresh temporary filename (the temporary file is
          deleted before binding).
        * Other families: to ``('127.0.0.1', 0)``.

        Raises RuntimeError if *timeout* is None or not positive. Bind and
        listen errors are re-raised after the socket is closed.
        """
        # Validate first so an invalid timeout cannot leak a socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        if addr is None:
            if family == socket.AF_UNIX:
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        try:
            sock.bind(addr)
            sock.listen(backlog)
        except OSError:
            sock.close()
            raise

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create a blocking socket and wrap it in a TestThreadedClient.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate first so an invalid timeout cannot leak a socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """``tcp_server`` with ``family=socket.AF_UNIX``."""
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """``tcp_client`` with ``family=socket.AF_UNIX``."""
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a temporary directory.

        Any file created at that path is removed on exit.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    pass

    def _abort_socket_test(self, ex):
        """Stop the test loop and fail the test with *ex*.

        This is called from the client/server helper threads, so the stop
        request goes through ``call_soon_threadsafe``. ``loop.stop()`` is
        not thread-safe and may not wake a loop that is blocked waiting
        for I/O.
        """
        try:
            self.loop.call_soon_threadsafe(self.loop.stop)
        finally:
            self.fail(ex)
def _cert_fullname(test_file_name, cert_file_name):
fullname = os.path.abspath(os.path.join(
os.path.dirname(test_file_name), 'certs', cert_file_name))
assert os.path.isfile(fullname)
return fullname
@contextlib.contextmanager
def silence_long_exec_warning():
    """Temporarily drop "Executing ... seconds" records on the asyncio logger.

    While the context is active, records on the ``asyncio`` logger are
    dropped if their (unformatted) message is a string that starts with
    ``'Executing'`` and ends with ``'seconds'``. All other records,
    including ones whose message is not a string, pass through.
    """
    class _LongExecFilter(logging.Filter):
        def filter(self, record):
            msg = record.msg
            # record.msg can be any object (e.g. logger.warning(42)).
            if not isinstance(msg, str):
                return True
            return not (msg.startswith('Executing') and
                        msg.endswith('seconds'))

    logger = logging.getLogger('asyncio')
    long_exec_filter = _LongExecFilter()
    logger.addFilter(long_exec_filter)
    try:
        yield
    finally:
        logger.removeFilter(long_exec_filter)
def find_free_port(start_from=50000):
    """Return the first port in ``[start_from, start_from + 500)`` that binds.

    Each candidate is bound on all interfaces and released immediately.
    Candidates above 65535 are skipped rather than passed to ``bind()``,
    which would raise OverflowError.

    Raises RuntimeError if no candidate port can be bound.
    """
    stop = min(start_from + 500, 65536)
    for port in range(start_from, stop):
        with socket.socket() as sock:
            try:
                sock.bind(('', port))
            except OSError:
                continue
            return port
    raise RuntimeError('could not find a free port')
class SSLTestCase:
    """Mixin providing SSL context and logging helpers for TLS tests."""

    def _create_server_ssl_context(self, certfile, keyfile=None):
        """Return an SSLContext with *certfile* / *keyfile* loaded."""
        if hasattr(ssl, 'PROTOCOL_TLS'):
            sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS)
        else:
            # Fallback for Pythons that lack PROTOCOL_TLS.
            sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        sslcontext.options |= ssl.OP_NO_SSLv2
        sslcontext.load_cert_chain(certfile, keyfile)
        return sslcontext

    def _create_client_ssl_context(self, *, disable_verify=True):
        """Return a default client SSLContext with hostname checks disabled.

        When *disable_verify* is true (the default), certificate
        verification is turned off as well.
        """
        sslcontext = ssl.create_default_context()
        # check_hostname must be False before verify_mode can be CERT_NONE.
        sslcontext.check_hostname = False
        if disable_verify:
            sslcontext.verify_mode = ssl.CERT_NONE
        return sslcontext

    @contextlib.contextmanager
    def _silence_eof_received_warning(self):
        """Suppress records containing "has no effect when using ssl".

        Only records on the ``asyncio`` logger whose message contains that
        text are dropped. Other records pass through unchanged.
        """
        # TODO This warning has to be fixed in asyncio.
        # logging.Filter(name) matches on the *logger name*, not the
        # message. Passing the text to it would reject every record on the
        # 'asyncio' logger, so match on the message text explicitly.
        marker = 'has no effect when using ssl'

        class _EofReceivedFilter(logging.Filter):
            def filter(self, record):
                return marker not in str(record.msg)

        logger = logging.getLogger('asyncio')
        eof_filter = _EofReceivedFilter()
        logger.addFilter(eof_filter)
        try:
            yield
        finally:
            logger.removeFilter(eof_filter)
class UVTestCase(BaseTestCase):
    """Run the test case on a uvloop event loop and policy."""

    implementation = 'uvloop'

    def new_policy(self):
        return uvloop.EventLoopPolicy()

    def new_loop(self):
        return uvloop.new_event_loop()
class AIOTestCase(BaseTestCase):
    """Run the test case on a stock asyncio event loop and policy."""

    implementation = 'asyncio'

    def new_loop(self):
        return asyncio.new_event_loop()

    def new_policy(self):
        return asyncio.DefaultEventLoopPolicy()

    def setUp(self):
        super().setUp()
        if sys.version_info >= (3, 12):
            # No child watcher is configured on 3.12+.
            return
        # Before 3.12, attach a SafeChildWatcher to this test's loop.
        child_watcher = asyncio.SafeChildWatcher()
        child_watcher.attach_loop(self.loop)
        asyncio.set_child_watcher(child_watcher)

    def tearDown(self):
        needs_watcher_reset = sys.version_info < (3, 12)
        if needs_watcher_reset:
            asyncio.set_child_watcher(None)
        super().tearDown()
def has_IPv6():
    """Return True if an IPv6 socket can be bound to ``::1`` on this host.

    Returns False if Python was built without IPv6 support, if an AF_INET6
    socket cannot be created, or if binding to the loopback address fails.
    """
    if not socket.has_ipv6:
        return False
    try:
        server_sock = socket.socket(socket.AF_INET6)
    except OSError:
        # e.g. the kernel has IPv6 disabled (EAFNOSUPPORT); don't crash
        # at import time.
        return False
    with server_sock:
        try:
            server_sock.bind(('::1', 0))
        except OSError:
            return False
    return True


# The probe runs once at import; the module-level name then holds the bool.
has_IPv6 = has_IPv6()
###############################################################################
# Socket Testing Utilities
###############################################################################
class TestSocketWrapper:
    """Thin proxy around a blocking socket, used by the threaded helpers.

    Attributes not defined here are forwarded to the wrapped socket.
    ``starttls`` replaces the wrapped socket with an SSL socket in place.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive and return exactly *n* bytes.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have arrived.
        """
        # A bytearray avoids the quadratic cost of repeated bytes +=.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the current socket with *ssl_context* and use the result.

        On the server side the handshake is performed explicitly after
        wrapping.
        """
        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        if server_side:
            ssl_sock.do_handshake()

        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only called when normal lookup fails. If the private socket
        # attribute itself is missing (e.g. a half-initialised instance),
        # raise AttributeError instead of recursing forever.
        if name == '_TestSocketWrapper__sock':
            raise AttributeError(name)
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
class SocketThread(threading.Thread):
    """Thread usable as a context manager: starts on enter, stops on exit."""

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()

    def stop(self):
        """Clear the ``_active`` flag, then wait for the thread to finish."""
        self._active = False
        self.join()
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program over *sock*.

    *prog* is called once, with the socket wrapped in TestSocketWrapper.
    Exceptions other than KeyboardInterrupt / SystemExit are passed to
    ``test._abort_socket_test``. The thread does not close *sock* itself.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            wrapped = TestSocketWrapper(self._sock)
            self._prog(wrapped)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as error:
            self._test._abort_socket_test(error)
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections on a listening socket.

    Each accepted connection is handed to *prog* (wrapped in
    TestSocketWrapper), one client at a time on this thread. The accept
    loop ends when any of the following happens:

    * *max_clients* connections have been accepted;
    * ``stop()`` signals the internal socketpair;
    * ``_active`` becomes false;
    * a client handler raises.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        # NOTE(review): _finished_clients is not read anywhere in the code
        # visible here.
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Wake-up channel: stop() writes to _s2, and _run() selects on _s1.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the accept loop via the socketpair, then join the thread."""
        try:
            # _s2 may already be closed by run()'s cleanup (fileno() == -1).
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    pass
        finally:
            super().stop()

    def run(self):
        # The listening socket is closed when the loop exits; the
        # socketpair is closed in every case.
        try:
            with self._sock:
                self._sock.setblocking(0)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept and serve clients until a stop condition is met."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            # Any data on _s1 means stop() was called.
            if self._s1 in r:
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    # The listening socket is non-blocking (see run()), so
                    # a spurious wake-up surfaces here; just select again.
                    continue
                except socket.timeout:
                    # NOTE(review): presumably unreachable while the listening
                    # socket is non-blocking; kept as a safety net.
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except (KeyboardInterrupt, SystemExit):
                        raise
                    except BaseException as ex:
                        self._active = False
                        # Re-raise the handler error, but always report it
                        # to the test first via the finally clause.
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to (``getsockname()``)."""
        return self._sock.getsockname()
###############################################################################
# A few helpers from asyncio/tests/testutils.py
###############################################################################
def run_briefly(loop):
    """Run *loop* just long enough to complete one no-op task.

    Callbacks that were already scheduled ahead of the task get a chance
    to run.
    """
    async def _noop():
        pass

    coro = _noop()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
def run_until(loop, pred, timeout=30):
    """Run *loop* in ~1 ms slices until ``pred()`` returns a true value.

    ``pred()`` is checked before each slice. If *timeout* seconds pass
    first, ``asyncio.TimeoutError`` is raised. ``timeout=None`` waits
    indefinitely.
    """
    # A monotonic clock keeps the deadline immune to wall-clock jumps.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            # asyncio.futures.TimeoutError no longer exists (3.8+);
            # asyncio.TimeoutError is the public name.
            raise asyncio.TimeoutError()
        loop.run_until_complete(asyncio.sleep(0.001))
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of the context.

    The level is raised above CRITICAL and then restored to its previous
    value. This is useful, for example, to hide warnings in debug mode.
    """
    asyncio_logger = asyncio.log.logger
    saved_level = asyncio_logger.level
    try:
        asyncio_logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        asyncio_logger.setLevel(saved_level)