Updated script that can be controled by Nodejs web app
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
"""Trio - A friendly Python library for async concurrency and I/O
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# General layout:
|
||||
#
|
||||
# trio/_core/... is the self-contained core library. It does various
|
||||
# shenanigans to export a consistent "core API", but parts of the core API are
|
||||
# too low-level to be recommended for regular use.
|
||||
#
|
||||
# trio/*.py define a set of more usable tools on top of this. They import from
|
||||
# trio._core and from each other.
|
||||
#
|
||||
# This file pulls together the friendly public API, by re-exporting the more
|
||||
# innocuous bits of the _core API + the higher-level tools from trio/*.py.
|
||||
#
|
||||
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
|
||||
#
|
||||
# must be imported early to avoid circular import
|
||||
from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split
|
||||
|
||||
# Submodules imported by default
|
||||
from . import abc, from_thread, lowlevel, socket, to_thread
|
||||
from ._channel import (
|
||||
MemoryChannelStatistics as MemoryChannelStatistics,
|
||||
MemoryReceiveChannel as MemoryReceiveChannel,
|
||||
MemorySendChannel as MemorySendChannel,
|
||||
open_memory_channel as open_memory_channel,
|
||||
)
|
||||
from ._core import (
|
||||
BrokenResourceError as BrokenResourceError,
|
||||
BusyResourceError as BusyResourceError,
|
||||
Cancelled as Cancelled,
|
||||
CancelScope as CancelScope,
|
||||
ClosedResourceError as ClosedResourceError,
|
||||
EndOfChannel as EndOfChannel,
|
||||
Nursery as Nursery,
|
||||
RunFinishedError as RunFinishedError,
|
||||
TaskStatus as TaskStatus,
|
||||
TrioInternalError as TrioInternalError,
|
||||
WouldBlock as WouldBlock,
|
||||
current_effective_deadline as current_effective_deadline,
|
||||
current_time as current_time,
|
||||
open_nursery as open_nursery,
|
||||
run as run,
|
||||
)
|
||||
from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning
|
||||
from ._dtls import (
|
||||
DTLSChannel as DTLSChannel,
|
||||
DTLSChannelStatistics as DTLSChannelStatistics,
|
||||
DTLSEndpoint as DTLSEndpoint,
|
||||
)
|
||||
from ._file_io import open_file as open_file, wrap_file as wrap_file
|
||||
from ._highlevel_generic import (
|
||||
StapledStream as StapledStream,
|
||||
aclose_forcefully as aclose_forcefully,
|
||||
)
|
||||
from ._highlevel_open_tcp_listeners import (
|
||||
open_tcp_listeners as open_tcp_listeners,
|
||||
serve_tcp as serve_tcp,
|
||||
)
|
||||
from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream
|
||||
from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket
|
||||
from ._highlevel_serve_listeners import serve_listeners as serve_listeners
|
||||
from ._highlevel_socket import (
|
||||
SocketListener as SocketListener,
|
||||
SocketStream as SocketStream,
|
||||
)
|
||||
from ._highlevel_ssl_helpers import (
|
||||
open_ssl_over_tcp_listeners as open_ssl_over_tcp_listeners,
|
||||
open_ssl_over_tcp_stream as open_ssl_over_tcp_stream,
|
||||
serve_ssl_over_tcp as serve_ssl_over_tcp,
|
||||
)
|
||||
from ._path import Path as Path, PosixPath as PosixPath, WindowsPath as WindowsPath
|
||||
from ._signals import open_signal_receiver as open_signal_receiver
|
||||
from ._ssl import (
|
||||
NeedHandshakeError as NeedHandshakeError,
|
||||
SSLListener as SSLListener,
|
||||
SSLStream as SSLStream,
|
||||
)
|
||||
from ._subprocess import Process as Process, run_process as run_process
|
||||
from ._sync import (
|
||||
CapacityLimiter as CapacityLimiter,
|
||||
CapacityLimiterStatistics as CapacityLimiterStatistics,
|
||||
Condition as Condition,
|
||||
ConditionStatistics as ConditionStatistics,
|
||||
Event as Event,
|
||||
EventStatistics as EventStatistics,
|
||||
Lock as Lock,
|
||||
LockStatistics as LockStatistics,
|
||||
Semaphore as Semaphore,
|
||||
StrictFIFOLock as StrictFIFOLock,
|
||||
)
|
||||
from ._timeouts import (
|
||||
TooSlowError as TooSlowError,
|
||||
fail_after as fail_after,
|
||||
fail_at as fail_at,
|
||||
move_on_after as move_on_after,
|
||||
move_on_at as move_on_at,
|
||||
sleep as sleep,
|
||||
sleep_forever as sleep_forever,
|
||||
sleep_until as sleep_until,
|
||||
)
|
||||
|
||||
# pyright explicitly does not care about `__version__`
|
||||
# see https://github.com/microsoft/pyright/blob/main/docs/typed-libraries.md#type-completeness
|
||||
from ._version import __version__
|
||||
|
||||
# Not imported by default, but mentioned here so static analysis tools like
|
||||
# pylint will know that it exists.
|
||||
if TYPE_CHECKING:
|
||||
from . import testing
|
||||
|
||||
from . import _deprecate as _deprecate
|
||||
|
||||
_deprecate.enable_attribute_deprecations(__name__)
|
||||
|
||||
__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = {}
|
||||
|
||||
# Having the public path in .__module__ attributes is important for:
|
||||
# - exception names in printed tracebacks
|
||||
# - sphinx :show-inheritance:
|
||||
# - deprecation warnings
|
||||
# - pickle
|
||||
# - probably other stuff
|
||||
from ._util import fixup_module_metadata
|
||||
|
||||
fixup_module_metadata(__name__, globals())
|
||||
fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__)
|
||||
fixup_module_metadata(socket.__name__, socket.__dict__)
|
||||
fixup_module_metadata(abc.__name__, abc.__dict__)
|
||||
fixup_module_metadata(from_thread.__name__, from_thread.__dict__)
|
||||
fixup_module_metadata(to_thread.__name__, to_thread.__dict__)
|
||||
del fixup_module_metadata
|
||||
del TYPE_CHECKING
|
||||
@@ -0,0 +1,3 @@
|
||||
from trio._repl import main
|
||||
|
||||
main(locals())
|
||||
@@ -0,0 +1,716 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
# both of these introduce circular imports if outside a TYPE_CHECKING guard
|
||||
from ._socket import SocketType
|
||||
from .lowlevel import Task
|
||||
|
||||
|
||||
class Clock(ABC):
|
||||
"""The interface for custom run loop clocks."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
def start_clock(self) -> None:
|
||||
"""Do any setup this clock might need.
|
||||
|
||||
Called at the beginning of the run.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def current_time(self) -> float:
|
||||
"""Return the current time, according to this clock.
|
||||
|
||||
This is used to implement functions like :func:`trio.current_time` and
|
||||
:func:`trio.move_on_after`.
|
||||
|
||||
Returns:
|
||||
float: The current time.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def deadline_to_sleep_time(self, deadline: float) -> float:
|
||||
"""Compute the real time until the given deadline.
|
||||
|
||||
This is called before we enter a system-specific wait function like
|
||||
:func:`select.select`, to get the timeout to pass.
|
||||
|
||||
For a clock using wall-time, this should be something like::
|
||||
|
||||
return deadline - self.current_time()
|
||||
|
||||
but of course it may be different if you're implementing some kind of
|
||||
virtual clock.
|
||||
|
||||
Args:
|
||||
deadline (float): The absolute time of the next deadline,
|
||||
according to this clock.
|
||||
|
||||
Returns:
|
||||
float: The number of real seconds to sleep until the given
|
||||
deadline. May be :data:`math.inf`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Instrument(ABC): # noqa: B024 # conceptually is ABC
|
||||
"""The interface for run loop instrumentation.
|
||||
|
||||
Instruments don't have to inherit from this abstract base class, and all
|
||||
of these methods are optional. This class serves mostly as documentation.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def before_run(self) -> None:
|
||||
"""Called at the beginning of :func:`trio.run`."""
|
||||
return
|
||||
|
||||
def after_run(self) -> None:
|
||||
"""Called just before :func:`trio.run` returns."""
|
||||
return
|
||||
|
||||
def task_spawned(self, task: Task) -> None:
|
||||
"""Called when the given task is created.
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): The new task.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def task_scheduled(self, task: Task) -> None:
|
||||
"""Called when the given task becomes runnable.
|
||||
|
||||
It may still be some time before it actually runs, if there are other
|
||||
runnable tasks ahead of it.
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): The task that became runnable.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def before_task_step(self, task: Task) -> None:
|
||||
"""Called immediately before we resume running the given task.
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): The task that is about to run.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def after_task_step(self, task: Task) -> None:
|
||||
"""Called when we return to the main run loop after a task has yielded.
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): The task that just ran.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def task_exited(self, task: Task) -> None:
|
||||
"""Called when the given task exits.
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): The finished task.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def before_io_wait(self, timeout: float) -> None:
|
||||
"""Called before blocking to wait for I/O readiness.
|
||||
|
||||
Args:
|
||||
timeout (float): The number of seconds we are willing to wait.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
def after_io_wait(self, timeout: float) -> None:
|
||||
"""Called after handling pending I/O.
|
||||
|
||||
Args:
|
||||
timeout (float): The number of seconds we were willing to
|
||||
wait. This much time may or may not have elapsed, depending on
|
||||
whether any I/O was ready.
|
||||
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class HostnameResolver(ABC):
|
||||
"""If you have a custom hostname resolver, then implementing
|
||||
:class:`HostnameResolver` allows you to register this to be used by Trio.
|
||||
|
||||
See :func:`trio.socket.set_custom_hostname_resolver`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
"""A custom implementation of :func:`~trio.socket.getaddrinfo`.
|
||||
|
||||
Called by :func:`trio.socket.getaddrinfo`.
|
||||
|
||||
If ``host`` is given as a numeric IP address, then
|
||||
:func:`~trio.socket.getaddrinfo` may handle the request itself rather
|
||||
than calling this method.
|
||||
|
||||
Any required IDNA encoding is handled before calling this function;
|
||||
your implementation can assume that it will never see U-labels like
|
||||
``"café.com"``, and only needs to handle A-labels like
|
||||
``b"xn--caf-dma.com"``.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
"""A custom implementation of :func:`~trio.socket.getnameinfo`.
|
||||
|
||||
Called by :func:`trio.socket.getnameinfo`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SocketFactory(ABC):
|
||||
"""If you write a custom class implementing the Trio socket interface,
|
||||
then you can use a :class:`SocketFactory` to get Trio to use it.
|
||||
|
||||
See :func:`trio.socket.set_custom_socket_factory`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
def socket(
|
||||
self,
|
||||
family: socket.AddressFamily | int = socket.AF_INET,
|
||||
type: socket.SocketKind | int = socket.SOCK_STREAM,
|
||||
proto: int = 0,
|
||||
) -> SocketType:
|
||||
"""Create and return a socket object.
|
||||
|
||||
Your socket object must inherit from :class:`trio.socket.SocketType`,
|
||||
which is an empty class whose only purpose is to "mark" which classes
|
||||
should be considered valid Trio sockets.
|
||||
|
||||
Called by :func:`trio.socket.socket`.
|
||||
|
||||
Note that unlike :func:`trio.socket.socket`, this does not take a
|
||||
``fileno=`` argument. If a ``fileno=`` is specified, then
|
||||
:func:`trio.socket.socket` returns a regular Trio socket object
|
||||
instead of calling this method.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class AsyncResource(ABC):
|
||||
"""A standard interface for resources that needs to be cleaned up, and
|
||||
where that cleanup may require blocking operations.
|
||||
|
||||
This class distinguishes between "graceful" closes, which may perform I/O
|
||||
and thus block, and a "forceful" close, which cannot. For example, cleanly
|
||||
shutting down a TLS-encrypted connection requires sending a "goodbye"
|
||||
message; but if a peer has become non-responsive, then sending this
|
||||
message might block forever, so we may want to just drop the connection
|
||||
instead. Therefore the :meth:`aclose` method is unusual in that it
|
||||
should always close the connection (or at least make its best attempt)
|
||||
*even if it fails*; failure indicates a failure to achieve grace, not a
|
||||
failure to close the connection.
|
||||
|
||||
Objects that implement this interface can be used as async context
|
||||
managers, i.e., you can write::
|
||||
|
||||
async with create_resource() as some_async_resource:
|
||||
...
|
||||
|
||||
Entering the context manager is synchronous (not a checkpoint); exiting it
|
||||
calls :meth:`aclose`. The default implementations of
|
||||
``__aenter__`` and ``__aexit__`` should be adequate for all subclasses.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def aclose(self) -> None:
|
||||
"""Close this resource, possibly blocking.
|
||||
|
||||
IMPORTANT: This method may block in order to perform a "graceful"
|
||||
shutdown. But, if this fails, then it still *must* close any
|
||||
underlying resources before returning. An error from this method
|
||||
indicates a failure to achieve grace, *not* a failure to close the
|
||||
connection.
|
||||
|
||||
For example, suppose we call :meth:`aclose` on a TLS-encrypted
|
||||
connection. This requires sending a "goodbye" message; but if the peer
|
||||
has become non-responsive, then our attempt to send this message might
|
||||
block forever, and eventually time out and be cancelled. In this case
|
||||
the :meth:`aclose` method on :class:`~trio.SSLStream` will
|
||||
immediately close the underlying transport stream using
|
||||
:func:`trio.aclose_forcefully` before raising :exc:`~trio.Cancelled`.
|
||||
|
||||
If the resource is already closed, then this method should silently
|
||||
succeed.
|
||||
|
||||
Once this method completes, any other pending or future operations on
|
||||
this resource should generally raise :exc:`~trio.ClosedResourceError`,
|
||||
unless there's a good reason to do otherwise.
|
||||
|
||||
See also: :func:`trio.aclose_forcefully`.
|
||||
|
||||
"""
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class SendStream(AsyncResource):
|
||||
"""A standard interface for sending data on a byte stream.
|
||||
|
||||
The underlying stream may be unidirectional, or bidirectional. If it's
|
||||
bidirectional, then you probably want to also implement
|
||||
:class:`ReceiveStream`, which makes your object a :class:`Stream`.
|
||||
|
||||
:class:`SendStream` objects also implement the :class:`AsyncResource`
|
||||
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
|
||||
or using an ``async with`` block.
|
||||
|
||||
If you want to send Python objects rather than raw bytes, see
|
||||
:class:`SendChannel`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Sends the given data through the stream, blocking if necessary.
|
||||
|
||||
Args:
|
||||
data (bytes, bytearray, or memoryview): The data to send.
|
||||
|
||||
Raises:
|
||||
trio.BusyResourceError: if another task is already executing a
|
||||
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
|
||||
:meth:`HalfCloseableStream.send_eof` on this stream.
|
||||
trio.BrokenResourceError: if something has gone wrong, and the stream
|
||||
is broken.
|
||||
trio.ClosedResourceError: if you previously closed this stream
|
||||
object, or if another task closes this stream object while
|
||||
:meth:`send_all` is running.
|
||||
|
||||
Most low-level operations in Trio provide a guarantee: if they raise
|
||||
:exc:`trio.Cancelled`, this means that they had no effect, so the
|
||||
system remains in a known state. This is **not true** for
|
||||
:meth:`send_all`. If this operation raises :exc:`trio.Cancelled` (or
|
||||
any other exception for that matter), then it may have sent some, all,
|
||||
or none of the requested data, and there is no way to know which.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
"""Block until it's possible that :meth:`send_all` might not block.
|
||||
|
||||
This method may return early: it's possible that after it returns,
|
||||
:meth:`send_all` will still block. (In the worst case, if no better
|
||||
implementation is available, then it might always return immediately
|
||||
without blocking. It's nice to do better than that when possible,
|
||||
though.)
|
||||
|
||||
This method **must not** return *late*: if it's possible for
|
||||
:meth:`send_all` to complete without blocking, then it must
|
||||
return. When implementing it, err on the side of returning early.
|
||||
|
||||
Raises:
|
||||
trio.BusyResourceError: if another task is already executing a
|
||||
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
|
||||
:meth:`HalfCloseableStream.send_eof` on this stream.
|
||||
trio.BrokenResourceError: if something has gone wrong, and the stream
|
||||
is broken.
|
||||
trio.ClosedResourceError: if you previously closed this stream
|
||||
object, or if another task closes this stream object while
|
||||
:meth:`wait_send_all_might_not_block` is running.
|
||||
|
||||
Note:
|
||||
|
||||
This method is intended to aid in implementing protocols that want
|
||||
to delay choosing which data to send until the last moment. E.g.,
|
||||
suppose you're working on an implementation of a remote display server
|
||||
like `VNC
|
||||
<https://en.wikipedia.org/wiki/Virtual_Network_Computing>`__, and
|
||||
the network connection is currently backed up so that if you call
|
||||
:meth:`send_all` now then it will sit for 0.5 seconds before actually
|
||||
sending anything. In this case it doesn't make sense to take a
|
||||
screenshot, then wait 0.5 seconds, and then send it, because the
|
||||
screen will keep changing while you wait; it's better to wait 0.5
|
||||
seconds, then take the screenshot, and then send it, because this
|
||||
way the data you deliver will be more
|
||||
up-to-date. Using :meth:`wait_send_all_might_not_block` makes it
|
||||
possible to implement the better strategy.
|
||||
|
||||
If you use this method, you might also want to read up on
|
||||
``TCP_NOTSENT_LOWAT``.
|
||||
|
||||
Further reading:
|
||||
|
||||
* `Prioritization Only Works When There's Pending Data to Prioritize
|
||||
<https://insouciant.org/tech/prioritization-only-works-when-theres-pending-data-to-prioritize/>`__
|
||||
|
||||
* WWDC 2015: Your App and Next Generation Networks: `slides
|
||||
<http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1>`__,
|
||||
`video and transcript
|
||||
<https://developer.apple.com/videos/play/wwdc2015/719/>`__
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ReceiveStream(AsyncResource):
|
||||
"""A standard interface for receiving data on a byte stream.
|
||||
|
||||
The underlying stream may be unidirectional, or bidirectional. If it's
|
||||
bidirectional, then you probably want to also implement
|
||||
:class:`SendStream`, which makes your object a :class:`Stream`.
|
||||
|
||||
:class:`ReceiveStream` objects also implement the :class:`AsyncResource`
|
||||
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
|
||||
or using an ``async with`` block.
|
||||
|
||||
If you want to receive Python objects rather than raw bytes, see
|
||||
:class:`ReceiveChannel`.
|
||||
|
||||
`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
|
||||
will produce an arbitrary sized chunk of bytes, like calling
|
||||
`receive_some` with no arguments. Every chunk will contain at least one
|
||||
byte, and the loop automatically exits when reaching end-of-file.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
|
||||
"""Wait until there is data available on this stream, and then return
|
||||
some of it.
|
||||
|
||||
A return value of ``b""`` (an empty bytestring) indicates that the
|
||||
stream has reached end-of-file. Implementations should be careful that
|
||||
they return ``b""`` if, and only if, the stream has reached
|
||||
end-of-file!
|
||||
|
||||
Args:
|
||||
max_bytes (int): The maximum number of bytes to return. Must be
|
||||
greater than zero. Optional; if omitted, then the stream object
|
||||
is free to pick a reasonable default.
|
||||
|
||||
Returns:
|
||||
bytes or bytearray: The data received.
|
||||
|
||||
Raises:
|
||||
trio.BusyResourceError: if two tasks attempt to call
|
||||
:meth:`receive_some` on the same stream at the same time.
|
||||
trio.BrokenResourceError: if something has gone wrong, and the stream
|
||||
is broken.
|
||||
trio.ClosedResourceError: if you previously closed this stream
|
||||
object, or if another task closes this stream object while
|
||||
:meth:`receive_some` is running.
|
||||
|
||||
"""
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes | bytearray:
|
||||
data = await self.receive_some()
|
||||
if not data:
|
||||
raise StopAsyncIteration
|
||||
return data
|
||||
|
||||
|
||||
class Stream(SendStream, ReceiveStream):
|
||||
"""A standard interface for interacting with bidirectional byte streams.
|
||||
|
||||
A :class:`Stream` is an object that implements both the
|
||||
:class:`SendStream` and :class:`ReceiveStream` interfaces.
|
||||
|
||||
If implementing this interface, you should consider whether you can go one
|
||||
step further and implement :class:`HalfCloseableStream`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class HalfCloseableStream(Stream):
|
||||
"""This interface extends :class:`Stream` to also allow closing the send
|
||||
part of the stream without closing the receive part.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def send_eof(self) -> None:
|
||||
"""Send an end-of-file indication on this stream, if possible.
|
||||
|
||||
The difference between :meth:`send_eof` and
|
||||
:meth:`~AsyncResource.aclose` is that :meth:`send_eof` is a
|
||||
*unidirectional* end-of-file indication. After you call this method,
|
||||
you shouldn't try sending any more data on this stream, and your
|
||||
remote peer should receive an end-of-file indication (eventually,
|
||||
after receiving all the data you sent before that). But, they may
|
||||
continue to send data to you, and you can continue to receive it by
|
||||
calling :meth:`~ReceiveStream.receive_some`. You can think of it as
|
||||
calling :meth:`~AsyncResource.aclose` on just the
|
||||
:class:`SendStream` "half" of the stream object (and in fact that's
|
||||
literally how :class:`trio.StapledStream` implements it).
|
||||
|
||||
Examples:
|
||||
|
||||
* On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man
|
||||
page <https://linux.die.net/man/2/shutdown>`__).
|
||||
|
||||
* The SSH protocol provides the ability to multiplex bidirectional
|
||||
"channels" on top of a single encrypted connection. A Trio
|
||||
implementation of SSH could expose these channels as
|
||||
:class:`HalfCloseableStream` objects, and calling :meth:`send_eof`
|
||||
would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3
|
||||
<https://tools.ietf.org/html/rfc4254#section-5.3>`__).
|
||||
|
||||
* On an SSL/TLS-encrypted connection, the protocol doesn't provide any
|
||||
way to do a unidirectional shutdown without closing the connection
|
||||
entirely, so :class:`~trio.SSLStream` implements
|
||||
:class:`Stream`, not :class:`HalfCloseableStream`.
|
||||
|
||||
If an EOF has already been sent, then this method should silently
|
||||
succeed.
|
||||
|
||||
Raises:
|
||||
trio.BusyResourceError: if another task is already executing a
|
||||
:meth:`~SendStream.send_all`,
|
||||
:meth:`~SendStream.wait_send_all_might_not_block`, or
|
||||
:meth:`send_eof` on this stream.
|
||||
trio.BrokenResourceError: if something has gone wrong, and the stream
|
||||
is broken.
|
||||
trio.ClosedResourceError: if you previously closed this stream
|
||||
object, or if another task closes this stream object while
|
||||
:meth:`send_eof` is running.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# A regular invariant generic type
|
||||
T = TypeVar("T")
|
||||
|
||||
# The type of object produced by a ReceiveChannel (covariant because
|
||||
# ReceiveChannel[Derived] can be passed to someone expecting
|
||||
# ReceiveChannel[Base])
|
||||
ReceiveType = TypeVar("ReceiveType", covariant=True)
|
||||
|
||||
# The type of object accepted by a SendChannel (contravariant because
|
||||
# SendChannel[Base] can be passed to someone expecting
|
||||
# SendChannel[Derived])
|
||||
SendType = TypeVar("SendType", contravariant=True)
|
||||
|
||||
# The type of object produced by a Listener (covariant plus must be
|
||||
# an AsyncResource)
|
||||
T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True)
|
||||
|
||||
|
||||
class Listener(AsyncResource, Generic[T_resource]):
|
||||
"""A standard interface for listening for incoming connections.
|
||||
|
||||
:class:`Listener` objects also implement the :class:`AsyncResource`
|
||||
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
|
||||
or using an ``async with`` block.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def accept(self) -> T_resource:
|
||||
"""Wait until an incoming connection arrives, and then return it.
|
||||
|
||||
Returns:
|
||||
AsyncResource: An object representing the incoming connection. In
|
||||
practice this is generally some kind of :class:`Stream`,
|
||||
but in principle you could also define a :class:`Listener` that
|
||||
returned, say, channel objects.
|
||||
|
||||
Raises:
|
||||
trio.BusyResourceError: if two tasks attempt to call
|
||||
:meth:`accept` on the same listener at the same time.
|
||||
trio.ClosedResourceError: if you previously closed this listener
|
||||
object, or if another task closes this listener object while
|
||||
:meth:`accept` is running.
|
||||
|
||||
Listeners don't generally raise :exc:`~trio.BrokenResourceError`,
|
||||
because for listeners there is no general condition of "the
|
||||
network/remote peer broke the connection" that can be handled in a
|
||||
generic way, like there is for streams. Other errors *can* occur and
|
||||
be raised from :meth:`accept` – for example, if you run out of file
|
||||
descriptors then you might get an :class:`OSError` with its errno set
|
||||
to ``EMFILE``.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SendChannel(AsyncResource, Generic[SendType]):
|
||||
"""A standard interface for sending Python objects to some receiver.
|
||||
|
||||
`SendChannel` objects also implement the `AsyncResource` interface, so
|
||||
they can be closed by calling `~AsyncResource.aclose` or using an ``async
|
||||
with`` block.
|
||||
|
||||
If you want to send raw bytes rather than Python objects, see
|
||||
`SendStream`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, value: SendType) -> None:
|
||||
"""Attempt to send an object through the channel, blocking if necessary.
|
||||
|
||||
Args:
|
||||
value (object): The object to send.
|
||||
|
||||
Raises:
|
||||
trio.BrokenResourceError: if something has gone wrong, and the
|
||||
channel is broken. For example, you may get this if the receiver
|
||||
has already been closed.
|
||||
trio.ClosedResourceError: if you previously closed this
|
||||
:class:`SendChannel` object, or if another task closes it while
|
||||
:meth:`send` is running.
|
||||
trio.BusyResourceError: some channels allow multiple tasks to call
|
||||
`send` at the same time, but others don't. If you try to call
|
||||
`send` simultaneously from multiple tasks on a channel that
|
||||
doesn't support it, then you can get `~trio.BusyResourceError`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ReceiveChannel(AsyncResource, Generic[ReceiveType]):
|
||||
"""A standard interface for receiving Python objects from some sender.
|
||||
|
||||
You can iterate over a :class:`ReceiveChannel` using an ``async for``
|
||||
loop::
|
||||
|
||||
async for value in receive_channel:
|
||||
...
|
||||
|
||||
This is equivalent to calling :meth:`receive` repeatedly. The loop exits
|
||||
without error when `receive` raises `~trio.EndOfChannel`.
|
||||
|
||||
`ReceiveChannel` objects also implement the `AsyncResource` interface, so
|
||||
they can be closed by calling `~AsyncResource.aclose` or using an ``async
|
||||
with`` block.
|
||||
|
||||
If you want to receive raw bytes rather than Python objects, see
|
||||
`ReceiveStream`.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> ReceiveType:
|
||||
"""Attempt to receive an incoming object, blocking if necessary.
|
||||
|
||||
Returns:
|
||||
object: Whatever object was received.
|
||||
|
||||
Raises:
|
||||
trio.EndOfChannel: if the sender has been closed cleanly, and no
|
||||
more objects are coming. This is not an error condition.
|
||||
trio.ClosedResourceError: if you previously closed this
|
||||
:class:`ReceiveChannel` object.
|
||||
trio.BrokenResourceError: if something has gone wrong, and the
|
||||
channel is broken.
|
||||
trio.BusyResourceError: some channels allow multiple tasks to call
|
||||
`receive` at the same time, but others don't. If you try to call
|
||||
`receive` simultaneously from multiple tasks on a channel that
|
||||
doesn't support it, then you can get `~trio.BusyResourceError`.
|
||||
|
||||
"""
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> ReceiveType:
|
||||
try:
|
||||
return await self.receive()
|
||||
except trio.EndOfChannel:
|
||||
raise StopAsyncIteration from None
|
||||
|
||||
|
||||
# these are necessary for Sphinx's :show-inheritance: with type args.
|
||||
# (this should be removed if possible)
|
||||
# see: https://github.com/python/cpython/issues/123250
|
||||
SendChannel.__module__ = SendChannel.__module__.replace("_abc", "abc")
|
||||
ReceiveChannel.__module__ = ReceiveChannel.__module__.replace("_abc", "abc")
|
||||
Listener.__module__ = Listener.__module__.replace("_abc", "abc")
|
||||
|
||||
|
||||
class Channel(SendChannel[T], ReceiveChannel[T]):
|
||||
"""A standard interface for interacting with bidirectional channels.
|
||||
|
||||
A `Channel` is an object that implements both the `SendChannel` and
|
||||
`ReceiveChannel` interfaces, so you can both send and receive objects.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
# see above
|
||||
Channel.__module__ = Channel.__module__.replace("_abc", "abc")
|
||||
@@ -0,0 +1,444 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict, deque
|
||||
from math import inf
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Generic,
|
||||
Tuple, # only needed for typechecking on <3.9
|
||||
)
|
||||
|
||||
import attrs
|
||||
from outcome import Error, Value
|
||||
|
||||
import trio
|
||||
|
||||
from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
|
||||
from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
|
||||
from ._util import NoPublicConstructor, final, generic_function
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
def _open_memory_channel(
|
||||
max_buffer_size: int | float, # noqa: PYI041
|
||||
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
|
||||
"""Open a channel for passing objects between tasks within a process.
|
||||
|
||||
Memory channels are lightweight, cheap to allocate, and entirely
|
||||
in-memory. They don't involve any operating-system resources, or any kind
|
||||
of serialization. They just pass Python objects directly between tasks
|
||||
(with a possible stop in an internal buffer along the way).
|
||||
|
||||
Channel objects can be closed by calling `~trio.abc.AsyncResource.aclose`
|
||||
or using ``async with``. They are *not* automatically closed when garbage
|
||||
collected. Closing memory channels isn't mandatory, but it is generally a
|
||||
good idea, because it helps avoid situations where tasks get stuck waiting
|
||||
on a channel when there's no-one on the other side. See
|
||||
:ref:`channel-shutdown` for details.
|
||||
|
||||
Memory channel operations are all atomic with respect to
|
||||
cancellation, either `~trio.abc.ReceiveChannel.receive` will
|
||||
successfully return an object, or it will raise :exc:`Cancelled`
|
||||
while leaving the channel unchanged.
|
||||
|
||||
Args:
|
||||
max_buffer_size (int or math.inf): The maximum number of items that can
|
||||
be buffered in the channel before :meth:`~trio.abc.SendChannel.send`
|
||||
blocks. Choosing a sensible value here is important to ensure that
|
||||
backpressure is communicated promptly and avoid unnecessary latency;
|
||||
see :ref:`channel-buffering` for more details. If in doubt, use 0.
|
||||
|
||||
Returns:
|
||||
A pair ``(send_channel, receive_channel)``. If you have
|
||||
trouble remembering which order these go in, remember: data
|
||||
flows from left → right.
|
||||
|
||||
In addition to the standard channel methods, all memory channel objects
|
||||
provide a ``statistics()`` method, which returns an object with the
|
||||
following fields:
|
||||
|
||||
* ``current_buffer_used``: The number of items currently stored in the
|
||||
channel buffer.
|
||||
* ``max_buffer_size``: The maximum number of items allowed in the buffer,
|
||||
as passed to :func:`open_memory_channel`.
|
||||
* ``open_send_channels``: The number of open
|
||||
:class:`MemorySendChannel` endpoints pointing to this channel.
|
||||
Initially 1, but can be increased by
|
||||
:meth:`MemorySendChannel.clone`.
|
||||
* ``open_receive_channels``: Likewise, but for open
|
||||
:class:`MemoryReceiveChannel` endpoints.
|
||||
* ``tasks_waiting_send``: The number of tasks blocked in ``send`` on this
|
||||
channel (summing over all clones).
|
||||
* ``tasks_waiting_receive``: The number of tasks blocked in ``receive`` on
|
||||
this channel (summing over all clones).
|
||||
|
||||
"""
|
||||
if max_buffer_size != inf and not isinstance(max_buffer_size, int):
|
||||
raise TypeError("max_buffer_size must be an integer or math.inf")
|
||||
if max_buffer_size < 0:
|
||||
raise ValueError("max_buffer_size must be >= 0")
|
||||
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
|
||||
return (
|
||||
MemorySendChannel[T]._create(state),
|
||||
MemoryReceiveChannel[T]._create(state),
|
||||
)
|
||||
|
||||
|
||||
# This workaround requires python3.9+, once older python versions are not supported
|
||||
# or there's a better way of achieving type-checking on a generic factory function,
|
||||
# it could replace the normal function header
|
||||
if TYPE_CHECKING:
|
||||
# written as a class so you can say open_memory_channel[int](5)
|
||||
# Need to use Tuple instead of tuple due to CI check running on 3.8
|
||||
class open_memory_channel(Tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]):
|
||||
def __new__( # type: ignore[misc] # "must return a subtype"
|
||||
cls,
|
||||
max_buffer_size: int | float, # noqa: PYI041
|
||||
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
|
||||
return _open_memory_channel(max_buffer_size)
|
||||
|
||||
def __init__(self, max_buffer_size: int | float): # noqa: PYI041
|
||||
...
|
||||
|
||||
else:
|
||||
# apply the generic_function decorator to make open_memory_channel indexable
|
||||
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
|
||||
open_memory_channel = generic_function(_open_memory_channel)
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class MemoryChannelStatistics:
|
||||
current_buffer_used: int
|
||||
max_buffer_size: int | float
|
||||
open_send_channels: int
|
||||
open_receive_channels: int
|
||||
tasks_waiting_send: int
|
||||
tasks_waiting_receive: int
|
||||
|
||||
|
||||
@attrs.define
|
||||
class MemoryChannelState(Generic[T]):
|
||||
max_buffer_size: int | float
|
||||
data: deque[T] = attrs.Factory(deque)
|
||||
# Counts of open endpoints using this state
|
||||
open_send_channels: int = 0
|
||||
open_receive_channels: int = 0
|
||||
# {task: value}
|
||||
send_tasks: OrderedDict[Task, T] = attrs.Factory(OrderedDict)
|
||||
# {task: None}
|
||||
receive_tasks: OrderedDict[Task, None] = attrs.Factory(OrderedDict)
|
||||
|
||||
def statistics(self) -> MemoryChannelStatistics:
|
||||
return MemoryChannelStatistics(
|
||||
current_buffer_used=len(self.data),
|
||||
max_buffer_size=self.max_buffer_size,
|
||||
open_send_channels=self.open_send_channels,
|
||||
open_receive_channels=self.open_receive_channels,
|
||||
tasks_waiting_send=len(self.send_tasks),
|
||||
tasks_waiting_receive=len(self.receive_tasks),
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False, repr=False, slots=False)
|
||||
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
|
||||
_state: MemoryChannelState[SendType]
|
||||
_closed: bool = False
|
||||
# This is just the tasks waiting on *this* object. As compared to
|
||||
# self._state.send_tasks, which includes tasks from this object and
|
||||
# all clones.
|
||||
_tasks: set[Task] = attrs.Factory(set)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._state.open_send_channels += 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<send channel at {id(self):#x}, using buffer at {id(self._state):#x}>"
|
||||
|
||||
def statistics(self) -> MemoryChannelStatistics:
|
||||
"""Returns a `MemoryChannelStatistics` for the memory channel this is
|
||||
associated with."""
|
||||
# XX should we also report statistics specific to this object?
|
||||
return self._state.statistics()
|
||||
|
||||
@enable_ki_protection
|
||||
def send_nowait(self, value: SendType) -> None:
|
||||
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
|
||||
full, raises `WouldBlock` instead of blocking.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
if self._state.open_receive_channels == 0:
|
||||
raise trio.BrokenResourceError
|
||||
if self._state.receive_tasks:
|
||||
assert not self._state.data
|
||||
task, _ = self._state.receive_tasks.popitem(last=False)
|
||||
task.custom_sleep_data._tasks.remove(task)
|
||||
trio.lowlevel.reschedule(task, Value(value))
|
||||
elif len(self._state.data) < self._state.max_buffer_size:
|
||||
self._state.data.append(value)
|
||||
else:
|
||||
raise trio.WouldBlock
|
||||
|
||||
@enable_ki_protection
|
||||
async def send(self, value: SendType) -> None:
|
||||
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
|
||||
|
||||
Memory channels allow multiple tasks to call `send` at the same time.
|
||||
|
||||
"""
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.send_nowait(value)
|
||||
except trio.WouldBlock:
|
||||
pass
|
||||
else:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
return
|
||||
|
||||
task = trio.lowlevel.current_task()
|
||||
self._tasks.add(task)
|
||||
self._state.send_tasks[task] = value
|
||||
task.custom_sleep_data = self
|
||||
|
||||
def abort_fn(_: RaiseCancelT) -> Abort:
|
||||
self._tasks.remove(task)
|
||||
del self._state.send_tasks[task]
|
||||
return trio.lowlevel.Abort.SUCCEEDED
|
||||
|
||||
await trio.lowlevel.wait_task_rescheduled(abort_fn)
|
||||
|
||||
# Return type must be stringified or use a TypeVar
|
||||
@enable_ki_protection
|
||||
def clone(self) -> MemorySendChannel[SendType]:
|
||||
"""Clone this send channel object.
|
||||
|
||||
This returns a new `MemorySendChannel` object, which acts as a
|
||||
duplicate of the original: sending on the new object does exactly the
|
||||
same thing as sending on the old object. (If you're familiar with
|
||||
`os.dup`, then this is a similar idea.)
|
||||
|
||||
However, closing one of the objects does not close the other, and
|
||||
receivers don't get `EndOfChannel` until *all* clones have been
|
||||
closed.
|
||||
|
||||
This is useful for communication patterns that involve multiple
|
||||
producers all sending objects to the same destination. If you give
|
||||
each producer its own clone of the `MemorySendChannel`, and then make
|
||||
sure to close each `MemorySendChannel` when it's finished, receivers
|
||||
will automatically get notified when all producers are finished. See
|
||||
:ref:`channel-mpmc` for examples.
|
||||
|
||||
Raises:
|
||||
trio.ClosedResourceError: if you already closed this
|
||||
`MemorySendChannel` object.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
return MemorySendChannel._create(self._state)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
@enable_ki_protection
|
||||
def close(self) -> None:
|
||||
"""Close this send channel object synchronously.
|
||||
|
||||
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
|
||||
Memory channels can also be closed synchronously. This has the same
|
||||
effect on the channel and other tasks using it, but `close` is not a
|
||||
trio checkpoint. This simplifies cleaning up in cancelled tasks.
|
||||
|
||||
Using ``with send_channel:`` will close the channel object on leaving
|
||||
the with block.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
for task in self._tasks:
|
||||
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
|
||||
del self._state.send_tasks[task]
|
||||
self._tasks.clear()
|
||||
self._state.open_send_channels -= 1
|
||||
if self._state.open_send_channels == 0:
|
||||
assert not self._state.send_tasks
|
||||
for task in self._state.receive_tasks:
|
||||
task.custom_sleep_data._tasks.remove(task)
|
||||
trio.lowlevel.reschedule(task, Error(trio.EndOfChannel()))
|
||||
self._state.receive_tasks.clear()
|
||||
|
||||
@enable_ki_protection
|
||||
async def aclose(self) -> None:
|
||||
"""Close this send channel object asynchronously.
|
||||
|
||||
See `MemorySendChannel.close`."""
|
||||
self.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False, repr=False, slots=False)
|
||||
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
|
||||
_state: MemoryChannelState[ReceiveType]
|
||||
_closed: bool = False
|
||||
_tasks: set[trio._core._run.Task] = attrs.Factory(set)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._state.open_receive_channels += 1
|
||||
|
||||
def statistics(self) -> MemoryChannelStatistics:
|
||||
"""Returns a `MemoryChannelStatistics` for the memory channel this is
|
||||
associated with."""
|
||||
return self._state.statistics()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<receive channel at {id(self):#x}, using buffer at {id(self._state):#x}>"
|
||||
)
|
||||
|
||||
@enable_ki_protection
|
||||
def receive_nowait(self) -> ReceiveType:
|
||||
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
|
||||
ready to receive, raises `WouldBlock` instead of blocking.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
if self._state.send_tasks:
|
||||
task, value = self._state.send_tasks.popitem(last=False)
|
||||
task.custom_sleep_data._tasks.remove(task)
|
||||
trio.lowlevel.reschedule(task)
|
||||
self._state.data.append(value)
|
||||
# Fall through
|
||||
if self._state.data:
|
||||
return self._state.data.popleft()
|
||||
if not self._state.open_send_channels:
|
||||
raise trio.EndOfChannel
|
||||
raise trio.WouldBlock
|
||||
|
||||
@enable_ki_protection
|
||||
async def receive(self) -> ReceiveType:
|
||||
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
|
||||
|
||||
Memory channels allow multiple tasks to call `receive` at the same
|
||||
time. The first task will get the first item sent, the second task
|
||||
will get the second item sent, and so on.
|
||||
|
||||
"""
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
value = self.receive_nowait()
|
||||
except trio.WouldBlock:
|
||||
pass
|
||||
else:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
return value
|
||||
|
||||
task = trio.lowlevel.current_task()
|
||||
self._tasks.add(task)
|
||||
self._state.receive_tasks[task] = None
|
||||
task.custom_sleep_data = self
|
||||
|
||||
def abort_fn(_: RaiseCancelT) -> Abort:
|
||||
self._tasks.remove(task)
|
||||
del self._state.receive_tasks[task]
|
||||
return trio.lowlevel.Abort.SUCCEEDED
|
||||
|
||||
# Not strictly guaranteed to return ReceiveType, but will do so unless
|
||||
# you intentionally reschedule with a bad value.
|
||||
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return]
|
||||
|
||||
@enable_ki_protection
|
||||
def clone(self) -> MemoryReceiveChannel[ReceiveType]:
|
||||
"""Clone this receive channel object.
|
||||
|
||||
This returns a new `MemoryReceiveChannel` object, which acts as a
|
||||
duplicate of the original: receiving on the new object does exactly
|
||||
the same thing as receiving on the old object.
|
||||
|
||||
However, closing one of the objects does not close the other, and the
|
||||
underlying channel is not closed until all clones are closed. (If
|
||||
you're familiar with `os.dup`, then this is a similar idea.)
|
||||
|
||||
This is useful for communication patterns that involve multiple
|
||||
consumers all receiving objects from the same underlying channel. See
|
||||
:ref:`channel-mpmc` for examples.
|
||||
|
||||
.. warning:: The clones all share the same underlying channel.
|
||||
Whenever a clone :meth:`receive`\\s a value, it is removed from the
|
||||
channel and the other clones do *not* receive that value. If you
|
||||
want to send multiple copies of the same stream of values to
|
||||
multiple destinations, like :func:`itertools.tee`, then you need to
|
||||
find some other solution; this method does *not* do that.
|
||||
|
||||
Raises:
|
||||
trio.ClosedResourceError: if you already closed this
|
||||
`MemoryReceiveChannel` object.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise trio.ClosedResourceError
|
||||
return MemoryReceiveChannel._create(self._state)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
@enable_ki_protection
|
||||
def close(self) -> None:
|
||||
"""Close this receive channel object synchronously.
|
||||
|
||||
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
|
||||
Memory channels can also be closed synchronously. This has the same
|
||||
effect on the channel and other tasks using it, but `close` is not a
|
||||
trio checkpoint. This simplifies cleaning up in cancelled tasks.
|
||||
|
||||
Using ``with receive_channel:`` will close the channel object on
|
||||
leaving the with block.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
for task in self._tasks:
|
||||
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
|
||||
del self._state.receive_tasks[task]
|
||||
self._tasks.clear()
|
||||
self._state.open_receive_channels -= 1
|
||||
if self._state.open_receive_channels == 0:
|
||||
assert not self._state.receive_tasks
|
||||
for task in self._state.send_tasks:
|
||||
task.custom_sleep_data._tasks.remove(task)
|
||||
trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError()))
|
||||
self._state.send_tasks.clear()
|
||||
self._state.data.clear()
|
||||
|
||||
@enable_ki_protection
|
||||
async def aclose(self) -> None:
|
||||
"""Close this receive channel object asynchronously.
|
||||
|
||||
See `MemoryReceiveChannel.close`."""
|
||||
self.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
This namespace represents the core functionality that has to be built-in
|
||||
and deal with private internal data structures. Things in this namespace
|
||||
are publicly available in either trio, trio.lowlevel, or trio.testing.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from ._entry_queue import TrioToken
|
||||
from ._exceptions import (
|
||||
BrokenResourceError,
|
||||
BusyResourceError,
|
||||
Cancelled,
|
||||
ClosedResourceError,
|
||||
EndOfChannel,
|
||||
RunFinishedError,
|
||||
TrioInternalError,
|
||||
WouldBlock,
|
||||
)
|
||||
from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection
|
||||
from ._local import RunVar, RunVarToken
|
||||
from ._mock_clock import MockClock
|
||||
from ._parking_lot import (
|
||||
ParkingLot,
|
||||
ParkingLotStatistics,
|
||||
add_parking_lot_breaker,
|
||||
remove_parking_lot_breaker,
|
||||
)
|
||||
|
||||
# Imports that always exist
|
||||
from ._run import (
|
||||
TASK_STATUS_IGNORED,
|
||||
CancelScope,
|
||||
Nursery,
|
||||
RunStatistics,
|
||||
Task,
|
||||
TaskStatus,
|
||||
add_instrument,
|
||||
checkpoint,
|
||||
checkpoint_if_cancelled,
|
||||
current_clock,
|
||||
current_effective_deadline,
|
||||
current_root_task,
|
||||
current_statistics,
|
||||
current_task,
|
||||
current_time,
|
||||
current_trio_token,
|
||||
notify_closing,
|
||||
open_nursery,
|
||||
remove_instrument,
|
||||
reschedule,
|
||||
run,
|
||||
spawn_system_task,
|
||||
start_guest_run,
|
||||
wait_all_tasks_blocked,
|
||||
wait_readable,
|
||||
wait_writable,
|
||||
)
|
||||
from ._thread_cache import start_thread_soon
|
||||
|
||||
# Has to come after _run to resolve a circular import
|
||||
from ._traps import (
|
||||
Abort,
|
||||
RaiseCancelT,
|
||||
cancel_shielded_checkpoint,
|
||||
permanently_detach_coroutine_object,
|
||||
reattach_detached_coroutine_object,
|
||||
temporarily_detach_coroutine_object,
|
||||
wait_task_rescheduled,
|
||||
)
|
||||
from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics
|
||||
|
||||
# Windows imports
|
||||
if sys.platform == "win32":
|
||||
from ._run import (
|
||||
current_iocp,
|
||||
monitor_completion_key,
|
||||
readinto_overlapped,
|
||||
register_with_iocp,
|
||||
wait_overlapped,
|
||||
write_overlapped,
|
||||
)
|
||||
# Kqueue imports
|
||||
elif sys.platform != "linux" and sys.platform != "win32":
|
||||
from ._run import current_kqueue, monitor_kevent, wait_kevent
|
||||
|
||||
del sys # It would be better to import sys as _sys, but mypy does not understand it
|
||||
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import attrs
|
||||
|
||||
from .. import _core
|
||||
from .._util import name_asyncgen
|
||||
from . import _run
|
||||
|
||||
# Used to log exceptions in async generator finalizers
|
||||
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import AsyncGeneratorType
|
||||
from typing import Set
|
||||
|
||||
_WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
|
||||
_ASYNC_GEN_SET = Set[AsyncGeneratorType[object, NoReturn]]
|
||||
else:
|
||||
_WEAK_ASYNC_GEN_SET = weakref.WeakSet
|
||||
_ASYNC_GEN_SET = set
|
||||
|
||||
|
||||
@attrs.define(eq=False)
|
||||
class AsyncGenerators:
|
||||
# Async generators are added to this set when first iterated. Any
|
||||
# left after the main task exits will be closed before trio.run()
|
||||
# returns. During most of the run, this is a WeakSet so GC works.
|
||||
# During shutdown, when we're finalizing all the remaining
|
||||
# asyncgens after the system nursery has been closed, it's a
|
||||
# regular set so we don't have to deal with GC firing at
|
||||
# unexpected times.
|
||||
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET)
|
||||
|
||||
# This collects async generators that get garbage collected during
|
||||
# the one-tick window between the system nursery closing and the
|
||||
# init task starting end-of-run asyncgen finalization.
|
||||
trailing_needs_finalize: _ASYNC_GEN_SET = attrs.Factory(_ASYNC_GEN_SET)
|
||||
|
||||
prev_hooks: sys._asyncgen_hooks = attrs.field(init=False)
|
||||
|
||||
def install_hooks(self, runner: _run.Runner) -> None:
|
||||
def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
|
||||
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
|
||||
self.alive.add(agen)
|
||||
else:
|
||||
# An async generator first iterated outside of a Trio
|
||||
# task doesn't belong to Trio. Probably we're in guest
|
||||
# mode and the async generator belongs to our host.
|
||||
# The locals dictionary is the only good place to
|
||||
# remember this fact, at least until
|
||||
# https://bugs.python.org/issue40916 is implemented.
|
||||
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
|
||||
if self.prev_hooks.firstiter is not None:
|
||||
self.prev_hooks.firstiter(agen)
|
||||
|
||||
def finalize_in_trio_context(
|
||||
agen: AsyncGeneratorType[object, NoReturn],
|
||||
agen_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
runner.spawn_system_task(
|
||||
self._finalize_one,
|
||||
agen,
|
||||
agen_name,
|
||||
name=f"close asyncgen {agen_name} (abandoned)",
|
||||
)
|
||||
except RuntimeError:
|
||||
# There is a one-tick window where the system nursery
|
||||
# is closed but the init task hasn't yet made
|
||||
# self.asyncgens a strong set to disable GC. We seem to
|
||||
# have hit it.
|
||||
self.trailing_needs_finalize.add(agen)
|
||||
|
||||
def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
|
||||
agen_name = name_asyncgen(agen)
|
||||
try:
|
||||
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
|
||||
except AttributeError: # pragma: no cover
|
||||
is_ours = True
|
||||
|
||||
if is_ours:
|
||||
runner.entry_queue.run_sync_soon(
|
||||
finalize_in_trio_context,
|
||||
agen,
|
||||
agen_name,
|
||||
)
|
||||
|
||||
# Do this last, because it might raise an exception
|
||||
# depending on the user's warnings filter. (That
|
||||
# exception will be printed to the terminal and
|
||||
# ignored, since we're running in GC context.)
|
||||
warnings.warn(
|
||||
f"Async generator {agen_name!r} was garbage collected before it "
|
||||
"had been exhausted. Surround its use in 'async with "
|
||||
"aclosing(...):' to ensure that it gets cleaned up as soon as "
|
||||
"you're done using it.",
|
||||
ResourceWarning,
|
||||
stacklevel=2,
|
||||
source=agen,
|
||||
)
|
||||
else:
|
||||
# Not ours -> forward to the host loop's async generator finalizer
|
||||
if self.prev_hooks.finalizer is not None:
|
||||
self.prev_hooks.finalizer(agen)
|
||||
else:
|
||||
# Host has no finalizer. Reimplement the default
|
||||
# Python behavior with no hooks installed: throw in
|
||||
# GeneratorExit, step once, raise RuntimeError if
|
||||
# it doesn't exit.
|
||||
closer = agen.aclose()
|
||||
try:
|
||||
# If the next thing is a yield, this will raise RuntimeError
|
||||
# which we allow to propagate
|
||||
closer.send(None)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
# If the next thing is an await, we get here. Give a nicer
|
||||
# error than the default "async generator ignored GeneratorExit"
|
||||
raise RuntimeError(
|
||||
f"Non-Trio async generator {agen_name!r} awaited something "
|
||||
"during finalization; install a finalization hook to "
|
||||
"support this, or wrap it in 'async with aclosing(...):'",
|
||||
)
|
||||
|
||||
self.prev_hooks = sys.get_asyncgen_hooks()
|
||||
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) # type: ignore[arg-type] # Finalizer doesn't use AsyncGeneratorType
|
||||
|
||||
async def finalize_remaining(self, runner: _run.Runner) -> None:
|
||||
# This is called from init after shutting down the system nursery.
|
||||
# The only tasks running at this point are init and
|
||||
# the run_sync_soon task, and since the system nursery is closed,
|
||||
# there's no way for user code to spawn more.
|
||||
assert _core.current_task() is runner.init_task
|
||||
assert len(runner.tasks) == 2
|
||||
|
||||
# To make async generator finalization easier to reason
|
||||
# about, we'll shut down asyncgen garbage collection by turning
|
||||
# the alive WeakSet into a regular set.
|
||||
self.alive = set(self.alive)
|
||||
|
||||
# Process all pending run_sync_soon callbacks, in case one of
|
||||
# them was an asyncgen finalizer that snuck in under the wire.
|
||||
runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task)
|
||||
await _core.wait_task_rescheduled(
|
||||
lambda _: _core.Abort.FAILED, # pragma: no cover
|
||||
)
|
||||
self.alive.update(self.trailing_needs_finalize)
|
||||
self.trailing_needs_finalize.clear()
|
||||
|
||||
# None of the still-living tasks use async generators, so
|
||||
# every async generator must be suspended at a yield point --
|
||||
# there's no one to be doing the iteration. That's good,
|
||||
# because aclose() only works on an asyncgen that's suspended
|
||||
# at a yield point. (If it's suspended at an event loop trap,
|
||||
# because someone is in the middle of iterating it, then you
|
||||
# get a RuntimeError on 3.8+, and a nasty surprise on earlier
|
||||
# versions due to https://bugs.python.org/issue32526.)
|
||||
#
|
||||
# However, once we start aclose() of one async generator, it
|
||||
# might start fetching the next value from another, thus
|
||||
# preventing us from closing that other (at least until
|
||||
# aclose() of the first one is complete). This constraint
|
||||
# effectively requires us to finalize the remaining asyncgens
|
||||
# in arbitrary order, rather than doing all of them at the
|
||||
# same time. On 3.8+ we could defer any generator with
|
||||
# ag_running=True to a later batch, but that only catches
|
||||
# the case where our aclose() starts after the user's
|
||||
# asend()/etc. If our aclose() starts first, then the
|
||||
# user's asend()/etc will raise RuntimeError, since they're
|
||||
# probably not checking ag_running.
|
||||
#
|
||||
# It might be possible to allow some parallelized cleanup if
|
||||
# we can determine that a certain set of asyncgens have no
|
||||
# interdependencies, using gc.get_referents() and such.
|
||||
# But just doing one at a time will typically work well enough
|
||||
# (since each aclose() executes in a cancelled scope) and
|
||||
# is much easier to reason about.
|
||||
|
||||
# It's possible that that cleanup code will itself create
|
||||
# more async generators, so we iterate repeatedly until
|
||||
# all are gone.
|
||||
while self.alive:
|
||||
batch = self.alive
|
||||
self.alive = _ASYNC_GEN_SET()
|
||||
for agen in batch:
|
||||
await self._finalize_one(agen, name_asyncgen(agen))
|
||||
|
||||
def close(self) -> None:
|
||||
sys.set_asyncgen_hooks(*self.prev_hooks)
|
||||
|
||||
async def _finalize_one(
|
||||
self,
|
||||
agen: AsyncGeneratorType[object, NoReturn],
|
||||
name: object,
|
||||
) -> None:
|
||||
try:
|
||||
# This shield ensures that finalize_asyncgen never exits
|
||||
# with an exception, not even a Cancelled. The inside
|
||||
# is cancelled so there's no deadlock risk.
|
||||
with _core.CancelScope(shield=True) as cancel_scope:
|
||||
cancel_scope.cancel()
|
||||
await agen.aclose()
|
||||
except BaseException:
|
||||
ASYNCGEN_LOGGER.exception(
|
||||
"Exception ignored during finalization of async generator %r -- "
|
||||
"surround your use of the generator in 'async with aclosing(...):' "
|
||||
"to raise exceptions like this in the context where they're generated",
|
||||
name,
|
||||
)
|
||||
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
################################################################
|
||||
# concat_tb
|
||||
################################################################
|
||||
|
||||
# We need to compute a new traceback that is the concatenation of two existing
|
||||
# tracebacks. This requires copying the entries in 'head' and then pointing
|
||||
# the final tb_next to 'tail'.
|
||||
#
|
||||
# NB: 'tail' might be None, which requires some special handling in the ctypes
|
||||
# version.
|
||||
#
|
||||
# The complication here is that Python doesn't actually support copying or
|
||||
# modifying traceback objects, so we have to get creative...
|
||||
#
|
||||
# On CPython, we use ctypes. On PyPy, we use "transparent proxies".
|
||||
#
|
||||
# Jinja2 is a useful source of inspiration:
|
||||
# https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py
|
||||
|
||||
try:
|
||||
import tputil
|
||||
except ImportError:
|
||||
# ctypes it is
|
||||
# How to handle refcounting? I don't want to use ctypes.py_object because
|
||||
# I don't understand or trust it, and I don't want to use
|
||||
# ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code
|
||||
# that also tries to use them but with different types. So private _ctypes
|
||||
# APIs it is!
|
||||
import _ctypes
|
||||
import ctypes
|
||||
|
||||
class CTraceback(ctypes.Structure):
|
||||
_fields_: ClassVar = [
|
||||
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
|
||||
("tb_next", ctypes.c_void_p),
|
||||
("tb_frame", ctypes.c_void_p),
|
||||
("tb_lasti", ctypes.c_int),
|
||||
("tb_lineno", ctypes.c_int),
|
||||
]
|
||||
|
||||
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
|
||||
# TracebackType has no public constructor, so allocate one the hard way
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError as exc:
|
||||
new_tb = exc.__traceback__
|
||||
assert new_tb is not None
|
||||
c_new_tb = CTraceback.from_address(id(new_tb))
|
||||
|
||||
# At the C level, tb_next either points to the next traceback or is
|
||||
# NULL. c_void_p and the .tb_next accessor both convert NULL to None,
|
||||
# but we shouldn't DECREF None just because we assigned to a NULL
|
||||
# pointer! Here we know that our new traceback has only 1 frame in it,
|
||||
# so we can assume the tb_next field is NULL.
|
||||
assert c_new_tb.tb_next is None
|
||||
# If tb_next is None, then we want to set c_new_tb.tb_next to NULL,
|
||||
# which it already is, so we're done. Otherwise, we have to actually
|
||||
# do some work:
|
||||
if tb_next is not None:
|
||||
_ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined]
|
||||
c_new_tb.tb_next = id(tb_next)
|
||||
|
||||
assert c_new_tb.tb_frame is not None
|
||||
_ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined]
|
||||
old_tb_frame = new_tb.tb_frame
|
||||
c_new_tb.tb_frame = id(base_tb.tb_frame)
|
||||
_ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined]
|
||||
|
||||
c_new_tb.tb_lasti = base_tb.tb_lasti
|
||||
c_new_tb.tb_lineno = base_tb.tb_lineno
|
||||
|
||||
try:
|
||||
return new_tb
|
||||
finally:
|
||||
# delete references from locals to avoid creating cycles
|
||||
# see test_cancel_scope_exit_doesnt_create_cyclic_garbage
|
||||
del new_tb, old_tb_frame
|
||||
|
||||
else:
|
||||
# http://doc.pypy.org/en/latest/objspace-proxies.html
|
||||
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
|
||||
# tputil.ProxyOperation is PyPy-only, and there's no way to specify
|
||||
# cpython/pypy in current type checkers.
|
||||
def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported]
|
||||
# Rationale for pragma: I looked fairly carefully and tried a few
|
||||
# things, and AFAICT it's not actually possible to get any
|
||||
# 'opname' that isn't __getattr__ or __getattribute__. So there's
|
||||
# no missing test we could add, and no value in coverage nagging
|
||||
# us about adding one.
|
||||
if (
|
||||
operation.opname
|
||||
in {
|
||||
"__getattribute__",
|
||||
"__getattr__",
|
||||
}
|
||||
and operation.args[0] == "tb_next"
|
||||
): # pragma: no cover
|
||||
return tb_next
|
||||
return operation.delegate() # Delegate is reverting to original behaviour
|
||||
|
||||
return cast(
|
||||
TracebackType,
|
||||
tputil.make_proxy(controller, type(base_tb), base_tb),
|
||||
) # Returns proxy to traceback
|
||||
|
||||
|
||||
# this is used for collapsing single-exception ExceptionGroups when using
|
||||
# `strict_exception_groups=False`. Once that is retired this function and its helper can
|
||||
# be removed as well.
|
||||
def concat_tb(
|
||||
head: TracebackType | None,
|
||||
tail: TracebackType | None,
|
||||
) -> TracebackType | None:
|
||||
# We have to use an iterative algorithm here, because in the worst case
|
||||
# this might be a RecursionError stack that is by definition too deep to
|
||||
# process by recursion!
|
||||
head_tbs = []
|
||||
pointer = head
|
||||
while pointer is not None:
|
||||
head_tbs.append(pointer)
|
||||
pointer = pointer.tb_next
|
||||
current_head = tail
|
||||
for head_tb in reversed(head_tbs):
|
||||
current_head = copy_tb(head_tb, tb_next=current_head)
|
||||
return current_head
|
||||
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Callable, NoReturn, Tuple
|
||||
|
||||
import attrs
|
||||
|
||||
from .. import _core
|
||||
from .._util import NoPublicConstructor, final
|
||||
from ._wakeup_socketpair import WakeupSocketpair
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
PosArgsT = TypeVarTuple("PosArgsT")
|
||||
|
||||
Function = Callable[..., object]
|
||||
Job = Tuple[Function, Tuple[object, ...]]
|
||||
|
||||
|
||||
@attrs.define
|
||||
class EntryQueue:
|
||||
# This used to use a queue.Queue. but that was broken, because Queues are
|
||||
# implemented in Python, and not reentrant -- so it was thread-safe, but
|
||||
# not signal-safe. deque is implemented in C, so each operation is atomic
|
||||
# WRT threads (and this is guaranteed in the docs), AND each operation is
|
||||
# atomic WRT signal delivery (signal handlers can run on either side, but
|
||||
# not *during* a deque operation). dict makes similar guarantees - and
|
||||
# it's even ordered!
|
||||
queue: deque[Job] = attrs.Factory(deque)
|
||||
idempotent_queue: dict[Job, None] = attrs.Factory(dict)
|
||||
|
||||
wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
|
||||
done: bool = False
|
||||
# Must be a reentrant lock, because it's acquired from signal handlers.
|
||||
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
|
||||
# lock is effectively *disabled* when we enter from signal context. The
|
||||
# way we use the lock this is OK though, because when
|
||||
# run_sync_soon is called from a signal it's atomic WRT the
|
||||
# main thread -- it just might happen at some inconvenient place. But if
|
||||
# you look at the one place where the main thread holds the lock, it's
|
||||
# just to make 1 assignment, so that's atomic WRT a signal anyway.
|
||||
lock: threading.RLock = attrs.Factory(threading.RLock)
|
||||
|
||||
async def task(self) -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
# RLock has two implementations: a signal-safe version in _thread, and
|
||||
# and signal-UNsafe version in threading. We need the signal safe
|
||||
# version. Python 3.2 and later should always use this anyway, but,
|
||||
# since the symptoms if this goes wrong are just "weird rare
|
||||
# deadlocks", then let's make a little check.
|
||||
# See:
|
||||
# https://bugs.python.org/issue13697#msg237140
|
||||
assert self.lock.__class__.__module__ == "_thread"
|
||||
|
||||
def run_cb(job: Job) -> None:
|
||||
# We run this with KI protection enabled; it's the callback's
|
||||
# job to disable it if it wants it disabled. Exceptions are
|
||||
# treated like system task exceptions (i.e., converted into
|
||||
# TrioInternalError and cause everything to shut down).
|
||||
sync_fn, args = job
|
||||
try:
|
||||
sync_fn(*args)
|
||||
except BaseException as exc:
|
||||
|
||||
async def kill_everything(exc: BaseException) -> NoReturn:
|
||||
raise exc
|
||||
|
||||
try:
|
||||
_core.spawn_system_task(kill_everything, exc)
|
||||
except RuntimeError:
|
||||
# We're quite late in the shutdown process and the
|
||||
# system nursery is already closed.
|
||||
# TODO(2020-06): this is a gross hack and should
|
||||
# be fixed soon when we address #1607.
|
||||
parent_nursery = _core.current_task().parent_nursery
|
||||
if parent_nursery is None:
|
||||
raise AssertionError(
|
||||
"Internal error: `parent_nursery` should never be `None`",
|
||||
) from exc # pragma: no cover
|
||||
parent_nursery.start_soon(kill_everything, exc)
|
||||
|
||||
# This has to be carefully written to be safe in the face of new items
|
||||
# being queued while we iterate, and to do a bounded amount of work on
|
||||
# each pass:
|
||||
def run_all_bounded() -> None:
|
||||
for _ in range(len(self.queue)):
|
||||
run_cb(self.queue.popleft())
|
||||
for job in list(self.idempotent_queue):
|
||||
del self.idempotent_queue[job]
|
||||
run_cb(job)
|
||||
|
||||
try:
|
||||
while True:
|
||||
run_all_bounded()
|
||||
if not self.queue and not self.idempotent_queue:
|
||||
await self.wakeup.wait_woken()
|
||||
else:
|
||||
await _core.checkpoint()
|
||||
except _core.Cancelled:
|
||||
# Keep the work done with this lock held as minimal as possible,
|
||||
# because it doesn't protect us against concurrent signal delivery
|
||||
# (see the comment above). Notice that this code would still be
|
||||
# correct if written like:
|
||||
# self.done = True
|
||||
# with self.lock:
|
||||
# pass
|
||||
# because all we want is to force run_sync_soon
|
||||
# to either be completely before or completely after the write to
|
||||
# done. That's why we don't need the lock to protect
|
||||
# against signal handlers.
|
||||
with self.lock:
|
||||
self.done = True
|
||||
# No more jobs will be submitted, so just clear out any residual
|
||||
# ones:
|
||||
run_all_bounded()
|
||||
assert not self.queue
|
||||
assert not self.idempotent_queue
|
||||
|
||||
def close(self) -> None:
|
||||
self.wakeup.close()
|
||||
|
||||
def size(self) -> int:
|
||||
return len(self.queue) + len(self.idempotent_queue)
|
||||
|
||||
def run_sync_soon(
|
||||
self,
|
||||
sync_fn: Callable[[Unpack[PosArgsT]], object],
|
||||
*args: Unpack[PosArgsT],
|
||||
idempotent: bool = False,
|
||||
) -> None:
|
||||
with self.lock:
|
||||
if self.done:
|
||||
raise _core.RunFinishedError("run() has exited")
|
||||
# We have to hold the lock all the way through here, because
|
||||
# otherwise the main thread might exit *while* we're doing these
|
||||
# calls, and then our queue item might not be processed, or the
|
||||
# wakeup call might trigger an OSError b/c the IO manager has
|
||||
# already been shut down.
|
||||
if idempotent:
|
||||
self.idempotent_queue[(sync_fn, args)] = None
|
||||
else:
|
||||
self.queue.append((sync_fn, args))
|
||||
self.wakeup.wakeup_thread_and_signal_safe()
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False)
|
||||
class TrioToken(metaclass=NoPublicConstructor):
|
||||
"""An opaque object representing a single call to :func:`trio.run`.
|
||||
|
||||
It has no public constructor; instead, see :func:`current_trio_token`.
|
||||
|
||||
This object has two uses:
|
||||
|
||||
1. It lets you re-enter the Trio run loop from external threads or signal
|
||||
handlers. This is the low-level primitive that :func:`trio.to_thread`
|
||||
and `trio.from_thread` use to communicate with worker threads, that
|
||||
`trio.open_signal_receiver` uses to receive notifications about
|
||||
signals, and so forth.
|
||||
|
||||
2. Each call to :func:`trio.run` has exactly one associated
|
||||
:class:`TrioToken` object, so you can use it to identify a particular
|
||||
call.
|
||||
|
||||
"""
|
||||
|
||||
_reentry_queue: EntryQueue
|
||||
|
||||
def run_sync_soon(
|
||||
self,
|
||||
sync_fn: Callable[[Unpack[PosArgsT]], object],
|
||||
*args: Unpack[PosArgsT],
|
||||
idempotent: bool = False,
|
||||
) -> None:
|
||||
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
|
||||
Trio task.
|
||||
|
||||
This is safe to call from the main thread, from other threads, and
|
||||
from signal handlers. This is the fundamental primitive used to
|
||||
re-enter the Trio run loop from outside of it.
|
||||
|
||||
The call will happen "soon", but there's no guarantee about exactly
|
||||
when, and no mechanism provided for finding out when it's happened.
|
||||
If you need this, you'll have to build your own.
|
||||
|
||||
The call is effectively run as part of a system task (see
|
||||
:func:`~trio.lowlevel.spawn_system_task`). In particular this means
|
||||
that:
|
||||
|
||||
* :exc:`KeyboardInterrupt` protection is *enabled* by default; if
|
||||
you want ``sync_fn`` to be interruptible by control-C, then you
|
||||
need to use :func:`~trio.lowlevel.disable_ki_protection`
|
||||
explicitly.
|
||||
|
||||
* If ``sync_fn`` raises an exception, then it's converted into a
|
||||
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You
|
||||
should be careful that ``sync_fn`` doesn't crash.
|
||||
|
||||
All calls with ``idempotent=False`` are processed in strict
|
||||
first-in first-out order.
|
||||
|
||||
If ``idempotent=True``, then ``sync_fn`` and ``args`` must be
|
||||
hashable, and Trio will make a best-effort attempt to discard any
|
||||
call submission which is equal to an already-pending call. Trio
|
||||
will process these in first-in first-out order.
|
||||
|
||||
Any ordering guarantees apply separately to ``idempotent=False``
|
||||
and ``idempotent=True`` calls; there's no rule for how calls in the
|
||||
different categories are ordered with respect to each other.
|
||||
|
||||
:raises trio.RunFinishedError:
|
||||
if the associated call to :func:`trio.run`
|
||||
has already exited. (Any call that *doesn't* raise this error
|
||||
is guaranteed to be fully processed before :func:`trio.run`
|
||||
exits.)
|
||||
|
||||
"""
|
||||
self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent)
|
||||
@@ -0,0 +1,113 @@
|
||||
from trio._util import NoPublicConstructor, final
|
||||
|
||||
|
||||
class TrioInternalError(Exception):
|
||||
"""Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a
|
||||
misuse of one of the low-level :mod:`trio.lowlevel` APIs.
|
||||
|
||||
This should never happen! If you get this error, please file a bug.
|
||||
|
||||
Unfortunately, if you get this error it also means that all bets are off –
|
||||
Trio doesn't know what is going on and its normal invariants may be void.
|
||||
(For example, we might have "lost track" of a task. Or lost track of all
|
||||
tasks.) Again, though, this shouldn't happen.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class RunFinishedError(RuntimeError):
|
||||
"""Raised by `trio.from_thread.run` and similar functions if the
|
||||
corresponding call to :func:`trio.run` has already finished.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class WouldBlock(Exception):
|
||||
"""Raised by ``X_nowait`` functions if ``X`` would block."""
|
||||
|
||||
|
||||
@final
|
||||
class Cancelled(BaseException, metaclass=NoPublicConstructor):
|
||||
"""Raised by blocking calls if the surrounding scope has been cancelled.
|
||||
|
||||
You should let this exception propagate, to be caught by the relevant
|
||||
cancel scope. To remind you of this, it inherits from :exc:`BaseException`
|
||||
instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and
|
||||
:exc:`SystemExit` do. This means that if you write something like::
|
||||
|
||||
try:
|
||||
...
|
||||
except Exception:
|
||||
...
|
||||
|
||||
then this *won't* catch a :exc:`Cancelled` exception.
|
||||
|
||||
You cannot raise :exc:`Cancelled` yourself. Attempting to do so
|
||||
will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel()
|
||||
<trio.CancelScope.cancel>` instead.
|
||||
|
||||
.. note::
|
||||
|
||||
In the US it's also common to see this word spelled "canceled", with
|
||||
only one "l". This is a `recent
|
||||
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
|
||||
and `US-specific
|
||||
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
|
||||
innovation, and even in the US both forms are still commonly used. So
|
||||
for consistency with the rest of the world and with "cancellation"
|
||||
(which always has two "l"s), Trio uses the two "l" spelling
|
||||
everywhere.
|
||||
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Cancelled"
|
||||
|
||||
|
||||
class BusyResourceError(Exception):
|
||||
"""Raised when a task attempts to use a resource that some other task is
|
||||
already using, and this would lead to bugs and nonsense.
|
||||
|
||||
For example, if two tasks try to send data through the same socket at the
|
||||
same time, Trio will raise :class:`BusyResourceError` instead of letting
|
||||
the data get scrambled.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ClosedResourceError(Exception):
|
||||
"""Raised when attempting to use a resource after it has been closed.
|
||||
|
||||
Note that "closed" here means that *your* code closed the resource,
|
||||
generally by calling a method with a name like ``close`` or ``aclose``, or
|
||||
by exiting a context manager. If a problem arises elsewhere – for example,
|
||||
because of a network failure, or because a remote peer closed their end of
|
||||
a connection – then that should be indicated by a different exception
|
||||
class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class BrokenResourceError(Exception):
|
||||
"""Raised when an attempt to use a resource fails due to external
|
||||
circumstances.
|
||||
|
||||
For example, you might get this if you try to send data on a stream where
|
||||
the remote side has already closed the connection.
|
||||
|
||||
You *don't* get this error if *you* closed the resource – in that case you
|
||||
get :class:`ClosedResourceError`.
|
||||
|
||||
This exception's ``__cause__`` attribute will often contain more
|
||||
information about the underlying error.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class EndOfChannel(Exception):
|
||||
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
|
||||
that has no more data to receive.
|
||||
|
||||
This is analogous to an "end-of-file" condition, but for channels.
|
||||
|
||||
"""
|
||||
@@ -0,0 +1,51 @@
|
||||
# ***********************************************************
|
||||
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
|
||||
# *************************************************************
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
|
||||
from ._run import GLOBAL_RUN_CONTEXT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._instrumentation import Instrument
|
||||
|
||||
__all__ = ["add_instrument", "remove_instrument"]
|
||||
|
||||
|
||||
def add_instrument(instrument: Instrument) -> None:
|
||||
"""Start instrumenting the current run loop with the given instrument.
|
||||
|
||||
Args:
|
||||
instrument (trio.abc.Instrument): The instrument to activate.
|
||||
|
||||
If ``instrument`` is already active, does nothing.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def remove_instrument(instrument: Instrument) -> None:
|
||||
"""Stop instrumenting the current run loop with the given instrument.
|
||||
|
||||
Args:
|
||||
instrument (trio.abc.Instrument): The instrument to de-activate.
|
||||
|
||||
Raises:
|
||||
KeyError: if the instrument is not currently active. This could
|
||||
occur either because you never added it, or because you added it
|
||||
and then it raised an unhandled exception and was automatically
|
||||
deactivated.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
@@ -0,0 +1,98 @@
|
||||
# ***********************************************************
|
||||
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
|
||||
# *************************************************************
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
|
||||
from ._run import GLOBAL_RUN_CONTEXT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._file_io import _HasFileNo
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "linux"
|
||||
|
||||
|
||||
__all__ = ["notify_closing", "wait_readable", "wait_writable"]
|
||||
|
||||
|
||||
async def wait_readable(fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is readable.
|
||||
|
||||
On Unix systems, ``fd`` must either be an integer file descriptor,
|
||||
or else an object with a ``.fileno()`` method which returns an
|
||||
integer file descriptor. Any kind of file descriptor can be passed,
|
||||
though the exact semantics will depend on your kernel. For example,
|
||||
this probably won't do anything useful for on-disk files.
|
||||
|
||||
On Windows systems, ``fd`` must either be an integer ``SOCKET``
|
||||
handle, or else an object with a ``.fileno()`` method which returns
|
||||
an integer ``SOCKET`` handle. File descriptors aren't supported,
|
||||
and neither are handles that refer to anything besides a
|
||||
``SOCKET``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become readable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_writable(fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is writable.
|
||||
|
||||
See `wait_readable` for the definition of ``fd``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become writable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def notify_closing(fd: int | _HasFileNo) -> None:
|
||||
"""Notify waiters of the given object that it will be closed.
|
||||
|
||||
Call this before closing a file descriptor (on Unix) or socket (on
|
||||
Windows). This will cause any `wait_readable` or `wait_writable`
|
||||
calls on the given object to immediately wake up and raise
|
||||
`~trio.ClosedResourceError`.
|
||||
|
||||
This doesn't actually close the object – you still have to do that
|
||||
yourself afterwards. Also, you want to be careful to make sure no
|
||||
new tasks start waiting on the object in between when you call this
|
||||
and when it's actually closed. So to close something properly, you
|
||||
usually want to do these steps in order:
|
||||
|
||||
1. Explicitly mark the object as closed, so that any new attempts
|
||||
to use it will abort before they start.
|
||||
2. Call `notify_closing` to wake up any already-existing users.
|
||||
3. Actually close the object.
|
||||
|
||||
It's also possible to do them in a different order if that's more
|
||||
convenient, *but only if* you make sure not to have any checkpoints in
|
||||
between the steps. This way they all happen in a single atomic
|
||||
step, so other tasks won't be able to tell what order they happened
|
||||
in anyway.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
@@ -0,0 +1,156 @@
|
||||
# ***********************************************************
|
||||
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
|
||||
# *************************************************************
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Callable, ContextManager
|
||||
|
||||
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
|
||||
from ._run import GLOBAL_RUN_CONTEXT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import select
|
||||
|
||||
from .. import _core
|
||||
from .._file_io import _HasFileNo
|
||||
from ._traps import Abort, RaiseCancelT
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "darwin"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"current_kqueue",
|
||||
"monitor_kevent",
|
||||
"notify_closing",
|
||||
"wait_kevent",
|
||||
"wait_readable",
|
||||
"wait_writable",
|
||||
]
|
||||
|
||||
|
||||
def current_kqueue() -> select.kqueue:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def monitor_kevent(
|
||||
ident: int,
|
||||
filter: int,
|
||||
) -> ContextManager[_core.UnboundedQueue[select.kevent]]:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_kevent(
|
||||
ident: int,
|
||||
filter: int,
|
||||
abort_func: Callable[[RaiseCancelT], Abort],
|
||||
) -> Abort:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(
|
||||
ident,
|
||||
filter,
|
||||
abort_func,
|
||||
)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_readable(fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is readable.
|
||||
|
||||
On Unix systems, ``fd`` must either be an integer file descriptor,
|
||||
or else an object with a ``.fileno()`` method which returns an
|
||||
integer file descriptor. Any kind of file descriptor can be passed,
|
||||
though the exact semantics will depend on your kernel. For example,
|
||||
this probably won't do anything useful for on-disk files.
|
||||
|
||||
On Windows systems, ``fd`` must either be an integer ``SOCKET``
|
||||
handle, or else an object with a ``.fileno()`` method which returns
|
||||
an integer ``SOCKET`` handle. File descriptors aren't supported,
|
||||
and neither are handles that refer to anything besides a
|
||||
``SOCKET``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become readable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_writable(fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is writable.
|
||||
|
||||
See `wait_readable` for the definition of ``fd``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become writable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def notify_closing(fd: int | _HasFileNo) -> None:
|
||||
"""Notify waiters of the given object that it will be closed.
|
||||
|
||||
Call this before closing a file descriptor (on Unix) or socket (on
|
||||
Windows). This will cause any `wait_readable` or `wait_writable`
|
||||
calls on the given object to immediately wake up and raise
|
||||
`~trio.ClosedResourceError`.
|
||||
|
||||
This doesn't actually close the object – you still have to do that
|
||||
yourself afterwards. Also, you want to be careful to make sure no
|
||||
new tasks start waiting on the object in between when you call this
|
||||
and when it's actually closed. So to close something properly, you
|
||||
usually want to do these steps in order:
|
||||
|
||||
1. Explicitly mark the object as closed, so that any new attempts
|
||||
to use it will abort before they start.
|
||||
2. Call `notify_closing` to wake up any already-existing users.
|
||||
3. Actually close the object.
|
||||
|
||||
It's also possible to do them in a different order if that's more
|
||||
convenient, *but only if* you make sure not to have any checkpoints in
|
||||
between the steps. This way they all happen in a single atomic
|
||||
step, so other tasks won't be able to tell what order they happened
|
||||
in anyway.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
@@ -0,0 +1,209 @@
|
||||
# ***********************************************************
|
||||
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
|
||||
# *************************************************************
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, ContextManager
|
||||
|
||||
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
|
||||
from ._run import GLOBAL_RUN_CONTEXT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Buffer
|
||||
|
||||
from .._file_io import _HasFileNo
|
||||
from ._unbounded_queue import UnboundedQueue
|
||||
from ._windows_cffi import CData, Handle
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "win32"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"current_iocp",
|
||||
"monitor_completion_key",
|
||||
"notify_closing",
|
||||
"readinto_overlapped",
|
||||
"register_with_iocp",
|
||||
"wait_overlapped",
|
||||
"wait_readable",
|
||||
"wait_writable",
|
||||
"write_overlapped",
|
||||
]
|
||||
|
||||
|
||||
async def wait_readable(sock: _HasFileNo | int) -> None:
|
||||
"""Block until the kernel reports that the given object is readable.
|
||||
|
||||
On Unix systems, ``sock`` must either be an integer file descriptor,
|
||||
or else an object with a ``.fileno()`` method which returns an
|
||||
integer file descriptor. Any kind of file descriptor can be passed,
|
||||
though the exact semantics will depend on your kernel. For example,
|
||||
this probably won't do anything useful for on-disk files.
|
||||
|
||||
On Windows systems, ``sock`` must either be an integer ``SOCKET``
|
||||
handle, or else an object with a ``.fileno()`` method which returns
|
||||
an integer ``SOCKET`` handle. File descriptors aren't supported,
|
||||
and neither are handles that refer to anything besides a
|
||||
``SOCKET``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become readable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_writable(sock: _HasFileNo | int) -> None:
|
||||
"""Block until the kernel reports that the given object is writable.
|
||||
|
||||
See `wait_readable` for the definition of ``sock``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become writable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def notify_closing(handle: Handle | int | _HasFileNo) -> None:
|
||||
"""Notify waiters of the given object that it will be closed.
|
||||
|
||||
Call this before closing a file descriptor (on Unix) or socket (on
|
||||
Windows). This will cause any `wait_readable` or `wait_writable`
|
||||
calls on the given object to immediately wake up and raise
|
||||
`~trio.ClosedResourceError`.
|
||||
|
||||
This doesn't actually close the object – you still have to do that
|
||||
yourself afterwards. Also, you want to be careful to make sure no
|
||||
new tasks start waiting on the object in between when you call this
|
||||
and when it's actually closed. So to close something properly, you
|
||||
usually want to do these steps in order:
|
||||
|
||||
1. Explicitly mark the object as closed, so that any new attempts
|
||||
to use it will abort before they start.
|
||||
2. Call `notify_closing` to wake up any already-existing users.
|
||||
3. Actually close the object.
|
||||
|
||||
It's also possible to do them in a different order if that's more
|
||||
convenient, *but only if* you make sure not to have any checkpoints in
|
||||
between the steps. This way they all happen in a single atomic
|
||||
step, so other tasks won't be able to tell what order they happened
|
||||
in anyway.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def register_with_iocp(handle: int | CData) -> None:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(
|
||||
handle_,
|
||||
lpOverlapped,
|
||||
)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def write_overlapped(
|
||||
handle: int | CData,
|
||||
data: Buffer,
|
||||
file_offset: int = 0,
|
||||
) -> int:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(
|
||||
handle,
|
||||
data,
|
||||
file_offset,
|
||||
)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def readinto_overlapped(
|
||||
handle: int | CData,
|
||||
buffer: Buffer,
|
||||
file_offset: int = 0,
|
||||
) -> int:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(
|
||||
handle,
|
||||
buffer,
|
||||
file_offset,
|
||||
)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def current_iocp() -> int:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__ and `#52
|
||||
<https://github.com/python-trio/trio/issues/52>`__.
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
@@ -0,0 +1,273 @@
|
||||
# ***********************************************************
|
||||
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
|
||||
# *************************************************************
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
|
||||
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import contextvars
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from outcome import Outcome
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from .._abc import Clock
|
||||
from ._entry_queue import TrioToken
|
||||
from ._run import PosArgT
|
||||
|
||||
|
||||
__all__ = [
|
||||
"current_clock",
|
||||
"current_root_task",
|
||||
"current_statistics",
|
||||
"current_time",
|
||||
"current_trio_token",
|
||||
"reschedule",
|
||||
"spawn_system_task",
|
||||
"wait_all_tasks_blocked",
|
||||
]
|
||||
|
||||
|
||||
def current_statistics() -> RunStatistics:
|
||||
"""Returns ``RunStatistics``, which contains run-loop-level debugging information.
|
||||
|
||||
Currently, the following fields are defined:
|
||||
|
||||
* ``tasks_living`` (int): The number of tasks that have been spawned
|
||||
and not yet exited.
|
||||
* ``tasks_runnable`` (int): The number of tasks that are currently
|
||||
queued on the run queue (as opposed to blocked waiting for something
|
||||
to happen).
|
||||
* ``seconds_to_next_deadline`` (float): The time until the next
|
||||
pending cancel scope deadline. May be negative if the deadline has
|
||||
expired but we haven't yet processed cancellations. May be
|
||||
:data:`~math.inf` if there are no pending deadlines.
|
||||
* ``run_sync_soon_queue_size`` (int): The number of
|
||||
unprocessed callbacks queued via
|
||||
:meth:`trio.lowlevel.TrioToken.run_sync_soon`.
|
||||
* ``io_statistics`` (object): Some statistics from Trio's I/O
|
||||
backend. This always has an attribute ``backend`` which is a string
|
||||
naming which operating-system-specific I/O backend is in use; the
|
||||
other attributes vary between backends.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.current_statistics()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def current_time() -> float:
|
||||
"""Returns the current time according to Trio's internal clock.
|
||||
|
||||
Returns:
|
||||
float: The current time.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if not inside a call to :func:`trio.run`.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.current_time()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def current_clock() -> Clock:
|
||||
"""Returns the current :class:`~trio.abc.Clock`."""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.current_clock()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def current_root_task() -> Task | None:
|
||||
"""Returns the current root :class:`Task`.
|
||||
|
||||
This is the task that is the ultimate parent of all other tasks.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.current_root_task()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:
|
||||
"""Reschedule the given task with the given
|
||||
:class:`outcome.Outcome`.
|
||||
|
||||
See :func:`wait_task_rescheduled` for the gory details.
|
||||
|
||||
There must be exactly one call to :func:`reschedule` for every call to
|
||||
:func:`wait_task_rescheduled`. (And when counting, keep in mind that
|
||||
returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
|
||||
to calling :func:`reschedule` once.)
|
||||
|
||||
Args:
|
||||
task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
|
||||
in a call to :func:`wait_task_rescheduled`.
|
||||
next_send (outcome.Outcome): the value (or error) to return (or
|
||||
raise) from :func:`wait_task_rescheduled`.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def spawn_system_task(
|
||||
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
|
||||
*args: Unpack[PosArgT],
|
||||
name: object = None,
|
||||
context: contextvars.Context | None = None,
|
||||
) -> Task:
|
||||
"""Spawn a "system" task.
|
||||
|
||||
System tasks have a few differences from regular tasks:
|
||||
|
||||
* They don't need an explicit nursery; instead they go into the
|
||||
internal "system nursery".
|
||||
|
||||
* If a system task raises an exception, then it's converted into a
|
||||
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
|
||||
write a system task, you should be careful to make sure it doesn't
|
||||
crash.
|
||||
|
||||
* System tasks are automatically cancelled when the main task exits.
|
||||
|
||||
* By default, system tasks have :exc:`KeyboardInterrupt` protection
|
||||
*enabled*. If you want your task to be interruptible by control-C,
|
||||
then you need to use :func:`disable_ki_protection` explicitly (and
|
||||
come up with some plan for what to do with a
|
||||
:exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
|
||||
raise exceptions).
|
||||
|
||||
* System tasks do not inherit context variables from their creator.
|
||||
|
||||
Towards the end of a call to :meth:`trio.run`, after the main
|
||||
task and all system tasks have exited, the system nursery
|
||||
becomes closed. At this point, new calls to
|
||||
:func:`spawn_system_task` will raise ``RuntimeError("Nursery
|
||||
is closed to new arrivals")`` instead of creating a system
|
||||
task. It's possible to encounter this state either in
|
||||
a ``finally`` block in an async generator, or in a callback
|
||||
passed to :meth:`TrioToken.run_sync_soon` at the right moment.
|
||||
|
||||
Args:
|
||||
async_fn: An async callable.
|
||||
args: Positional arguments for ``async_fn``. If you want to pass
|
||||
keyword arguments, use :func:`functools.partial`.
|
||||
name: The name for this task. Only used for debugging/introspection
|
||||
(e.g. ``repr(task_obj)``). If this isn't a string,
|
||||
:func:`spawn_system_task` will try to make it one. A common use
|
||||
case is if you're wrapping a function before spawning a new
|
||||
task, you might pass the original function as the ``name=`` to
|
||||
make debugging easier.
|
||||
context: An optional ``contextvars.Context`` object with context variables
|
||||
to use for this task. You would normally get a copy of the current
|
||||
context with ``context = contextvars.copy_context()`` and then you would
|
||||
pass that ``context`` object here.
|
||||
|
||||
Returns:
|
||||
Task: the newly spawned task
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(
|
||||
async_fn,
|
||||
*args,
|
||||
name=name,
|
||||
context=context,
|
||||
)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
def current_trio_token() -> TrioToken:
|
||||
"""Retrieve the :class:`TrioToken` for the current call to
|
||||
:func:`trio.run`.
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
|
||||
|
||||
async def wait_all_tasks_blocked(cushion: float = 0.0) -> None:
|
||||
"""Block until there are no runnable tasks.
|
||||
|
||||
This is useful in testing code when you want to give other tasks a
|
||||
chance to "settle down". The calling task is blocked, and doesn't wake
|
||||
up until all other tasks are also blocked for at least ``cushion``
|
||||
seconds. (Setting a non-zero ``cushion`` is intended to handle cases
|
||||
like two tasks talking to each other over a local socket, where we
|
||||
want to ignore the potential brief moment between a send and receive
|
||||
when all tasks are blocked.)
|
||||
|
||||
Note that ``cushion`` is measured in *real* time, not the Trio clock
|
||||
time.
|
||||
|
||||
If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
|
||||
then the one with the shortest ``cushion`` is the one woken (and
|
||||
this task becoming unblocked resets the timers for the remaining
|
||||
tasks). If there are multiple tasks that have exactly the same
|
||||
``cushion``, then all are woken.
|
||||
|
||||
You should also consider :class:`trio.testing.Sequencer`, which
|
||||
provides a more explicit way to control execution ordering within a
|
||||
test, and will often produce more readable tests.
|
||||
|
||||
Example:
|
||||
Here's an example of one way to test that Trio's locks are fair: we
|
||||
take the lock in the parent, start a child, wait for the child to be
|
||||
blocked waiting for the lock (!), and then check that we can't
|
||||
release and immediately re-acquire the lock::
|
||||
|
||||
async def lock_taker(lock):
|
||||
await lock.acquire()
|
||||
lock.release()
|
||||
|
||||
async def test_lock_fairness():
|
||||
lock = trio.Lock()
|
||||
await lock.acquire()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock_taker, lock)
|
||||
# child hasn't run yet, we have the lock
|
||||
assert lock.locked()
|
||||
assert lock._owner is trio.lowlevel.current_task()
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
# now the child has run and is blocked on lock.acquire(), we
|
||||
# still have the lock
|
||||
assert lock.locked()
|
||||
assert lock._owner is trio.lowlevel.current_task()
|
||||
lock.release()
|
||||
try:
|
||||
# The child has a prior claim, so we can't have it
|
||||
lock.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
assert lock._owner is not trio.lowlevel.current_task()
|
||||
print("PASS")
|
||||
else:
|
||||
print("FAIL")
|
||||
|
||||
"""
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
|
||||
try:
|
||||
return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
|
||||
except AttributeError:
|
||||
raise RuntimeError("must be called from async context") from None
|
||||
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
import types
|
||||
from typing import Any, Callable, Dict, Sequence, TypeVar
|
||||
|
||||
from .._abc import Instrument
|
||||
|
||||
# Used to log exceptions in instruments
|
||||
INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument")
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
# Decorator to mark methods public. This does nothing by itself, but
|
||||
# trio/_tools/gen_exports.py looks for it.
|
||||
def _public(fn: F) -> F:
|
||||
return fn
|
||||
|
||||
|
||||
class Instruments(Dict[str, Dict[Instrument, None]]):
|
||||
"""A collection of `trio.abc.Instrument` organized by hook.
|
||||
|
||||
Instrumentation calls are rather expensive, and we don't want a
|
||||
rarely-used instrument (like before_run()) to slow down hot
|
||||
operations (like before_task_step()). Thus, we cache the set of
|
||||
instruments to be called for each hook, and skip the instrumentation
|
||||
call if there's nothing currently installed for that hook.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, incoming: Sequence[Instrument]):
|
||||
self["_all"] = {}
|
||||
for instrument in incoming:
|
||||
self.add_instrument(instrument)
|
||||
|
||||
@_public
|
||||
def add_instrument(self, instrument: Instrument) -> None:
|
||||
"""Start instrumenting the current run loop with the given instrument.
|
||||
|
||||
Args:
|
||||
instrument (trio.abc.Instrument): The instrument to activate.
|
||||
|
||||
If ``instrument`` is already active, does nothing.
|
||||
|
||||
"""
|
||||
if instrument in self["_all"]:
|
||||
return
|
||||
self["_all"][instrument] = None
|
||||
try:
|
||||
for name in dir(instrument):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
prototype = getattr(Instrument, name)
|
||||
except AttributeError:
|
||||
continue
|
||||
impl = getattr(instrument, name)
|
||||
if isinstance(impl, types.MethodType) and impl.__func__ is prototype:
|
||||
# Inherited unchanged from _abc.Instrument
|
||||
continue
|
||||
self.setdefault(name, {})[instrument] = None
|
||||
except:
|
||||
self.remove_instrument(instrument)
|
||||
raise
|
||||
|
||||
@_public
|
||||
def remove_instrument(self, instrument: Instrument) -> None:
|
||||
"""Stop instrumenting the current run loop with the given instrument.
|
||||
|
||||
Args:
|
||||
instrument (trio.abc.Instrument): The instrument to de-activate.
|
||||
|
||||
Raises:
|
||||
KeyError: if the instrument is not currently active. This could
|
||||
occur either because you never added it, or because you added it
|
||||
and then it raised an unhandled exception and was automatically
|
||||
deactivated.
|
||||
|
||||
"""
|
||||
# If instrument isn't present, the KeyError propagates out
|
||||
self["_all"].pop(instrument)
|
||||
for hookname, instruments in list(self.items()):
|
||||
if instrument in instruments:
|
||||
del instruments[instrument]
|
||||
if not instruments:
|
||||
del self[hookname]
|
||||
|
||||
def call(self, hookname: str, *args: Any) -> None:
|
||||
"""Call hookname(*args) on each applicable instrument.
|
||||
|
||||
You must first check whether there are any instruments installed for
|
||||
that hook, e.g.::
|
||||
|
||||
if "before_task_step" in instruments:
|
||||
instruments.call("before_task_step", task)
|
||||
"""
|
||||
for instrument in list(self[hookname]):
|
||||
try:
|
||||
getattr(instrument, hookname)(*args)
|
||||
except BaseException:
|
||||
self.remove_instrument(instrument)
|
||||
INSTRUMENT_LOGGER.exception(
|
||||
"Exception raised when calling %r on instrument %r. "
|
||||
"Instrument has been disabled.",
|
||||
hookname,
|
||||
instrument,
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import outcome
|
||||
|
||||
from .. import _core
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._io_epoll import EpollWaiters
|
||||
from ._io_windows import AFDWaiters
|
||||
|
||||
|
||||
# Utility function shared between _io_epoll and _io_windows
|
||||
def wake_all(waiters: EpollWaiters | AFDWaiters, exc: BaseException) -> None:
|
||||
try:
|
||||
current_task = _core.current_task()
|
||||
except RuntimeError:
|
||||
current_task = None
|
||||
raise_at_end = False
|
||||
for attr_name in ["read_task", "write_task"]:
|
||||
task = getattr(waiters, attr_name)
|
||||
if task is not None:
|
||||
if task is current_task:
|
||||
raise_at_end = True
|
||||
else:
|
||||
_core.reschedule(task, outcome.Error(copy.copy(exc)))
|
||||
setattr(waiters, attr_name, None)
|
||||
if raise_at_end:
|
||||
raise exc
|
||||
@@ -0,0 +1,387 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import select
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import attrs
|
||||
|
||||
from .. import _core
|
||||
from ._io_common import wake_all
|
||||
from ._run import Task, _public
|
||||
from ._wakeup_socketpair import WakeupSocketpair
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .._core import Abort, RaiseCancelT
|
||||
from .._file_io import _HasFileNo
|
||||
|
||||
|
||||
@attrs.define(eq=False)
|
||||
class EpollWaiters:
|
||||
read_task: Task | None = None
|
||||
write_task: Task | None = None
|
||||
current_flags: int = 0
|
||||
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "linux"
|
||||
|
||||
|
||||
EventResult: TypeAlias = "list[tuple[int, int]]"
|
||||
|
||||
|
||||
@attrs.frozen(eq=False)
|
||||
class _EpollStatistics:
|
||||
tasks_waiting_read: int
|
||||
tasks_waiting_write: int
|
||||
backend: Literal["epoll"] = attrs.field(init=False, default="epoll")
|
||||
|
||||
|
||||
# Some facts about epoll
|
||||
# ----------------------
|
||||
#
|
||||
# Internally, an epoll object is sort of like a WeakKeyDictionary where the
|
||||
# keys are tuples of (fd number, file object). When you call epoll_ctl, you
|
||||
# pass in an fd; that gets converted to an (fd number, file object) tuple by
|
||||
# looking up the fd in the process's fd table at the time of the call. When an
|
||||
# event happens on the file object, epoll_wait drops the file object part, and
|
||||
# just returns the fd number in its event. So from the outside it looks like
|
||||
# it's keeping a table of fds, but really it's a bit more complicated. This
|
||||
# has some subtle consequences.
|
||||
#
|
||||
# In general, file objects inside the kernel are reference counted. Each entry
|
||||
# in a process's fd table holds a strong reference to the corresponding file
|
||||
# object, and most operations that use file objects take a temporary strong
|
||||
# reference while they're working. So when you call close() on an fd, that
|
||||
# might or might not cause the file object to be deallocated -- it depends on
|
||||
# whether there are any other references to that file object. Some common ways
|
||||
# this can happen:
|
||||
#
|
||||
# - after calling dup(), you have two fds in the same process referring to the
|
||||
# same file object. Even if you close one fd (= remove that entry from the
|
||||
# fd table), the file object will be kept alive by the other fd.
|
||||
# - when calling fork(), the child inherits a copy of the parent's fd table,
|
||||
# so all the file objects get another reference. (But if the fork() is
|
||||
# followed by exec(), then all of the child's fds that have the CLOEXEC flag
|
||||
# set will be closed at that point.)
|
||||
# - most syscalls that work on fds take a strong reference to the underlying
|
||||
# file object while they're using it. So there's one thread blocked in
|
||||
# read(fd), and then another thread calls close() on the last fd referring
|
||||
# to that object, the underlying file won't actually be closed until
|
||||
# after read() returns.
|
||||
#
|
||||
# However, epoll does *not* take a reference to any of the file objects in its
|
||||
# interest set (that's what makes it similar to a WeakKeyDictionary). File
|
||||
# objects inside an epoll interest set will be deallocated if all *other*
|
||||
# references to them are closed. And when that happens, the epoll object will
|
||||
# automatically deregister that file object and stop reporting events on it.
|
||||
# So that's quite handy.
|
||||
#
|
||||
# But, what happens if we do this?
|
||||
#
|
||||
# fd1 = open(...)
|
||||
# epoll_ctl(EPOLL_CTL_ADD, fd1, ...)
|
||||
# fd2 = dup(fd1)
|
||||
# close(fd1)
|
||||
#
|
||||
# In this case, the dup() keeps the underlying file object alive, so it
|
||||
# remains registered in the epoll object's interest set, as the tuple (fd1,
|
||||
# file object). But, fd1 no longer refers to this file object! You might think
|
||||
# there was some magic to handle this, but unfortunately no; the consequences
|
||||
# are totally predictable from what I said above:
|
||||
#
|
||||
# If any events occur on the file object, then epoll will report them as
|
||||
# happening on fd1, even though that doesn't make sense.
|
||||
#
|
||||
# Perhaps we would like to deregister fd1 to stop getting nonsensical events.
|
||||
# But how? When we call epoll_ctl, we have to pass an fd number, which will
|
||||
# get expanded to an (fd number, file object) tuple. We can't pass fd1,
|
||||
# because when epoll_ctl tries to look it up, it won't find our file object.
|
||||
# And we can't pass fd2, because that will get expanded to (fd2, file object),
|
||||
# which is a different lookup key. In fact, it's *impossible* to de-register
|
||||
# this fd!
|
||||
#
|
||||
# We could even have fd1 get assigned to another file object, and then we can
|
||||
# have multiple keys registered simultaneously using the same fd number, like:
|
||||
# (fd1, file object 1), (fd1, file object 2). And if events happen on either
|
||||
# file object, then epoll will happily report that something happened to
|
||||
# "fd1".
|
||||
#
|
||||
# Now here's what makes this especially nasty: suppose the old file object
|
||||
# becomes, say, readable. That means that every time we call epoll_wait, it
|
||||
# will return immediately to tell us that "fd1" is readable. Normally, we
|
||||
# would handle this by de-registering fd1, waking up the corresponding call to
|
||||
# wait_readable, then the user will call read() or recv() or something, and
|
||||
# we're fine. But if this happens on a stale fd where we can't remove the
|
||||
# registration, then we might get stuck in a state where epoll_wait *always*
|
||||
# returns immediately, so our event loop becomes unable to sleep, and now our
|
||||
# program is burning 100% of the CPU doing nothing, with no way out.
|
||||
#
|
||||
#
|
||||
# What does this mean for Trio?
|
||||
# -----------------------------
|
||||
#
|
||||
# Since we don't control the user's code, we have no way to guarantee that we
|
||||
# don't get stuck with stale fd's in our epoll interest set. For example, a
|
||||
# user could call wait_readable(fd) in one task, and then while that's
|
||||
# running, they might close(fd) from another task. In this situation, they're
|
||||
# *supposed* to call notify_closing(fd) to let us know what's happening, so we
|
||||
# can interrupt the wait_readable() call and avoid getting into this mess. And
|
||||
# that's the only thing that can possibly work correctly in all cases. But
|
||||
# sometimes user code has bugs. So if this does happen, we'd like to degrade
|
||||
# gracefully, and survive without corrupting Trio's internal state or
|
||||
# otherwise causing the whole program to explode messily.
|
||||
#
|
||||
# Our solution: we always use EPOLLONESHOT. This way, we might get *one*
|
||||
# spurious event on a stale fd, but then epoll will automatically silence it
|
||||
# until we explicitly say that we want more events... and if we have a stale
|
||||
# fd, then we actually can't re-enable it! So we can't get stuck in an
|
||||
# infinite busy-loop. If there's a stale fd hanging around, then it might
|
||||
# cause a spurious `BusyResourceError`, or cause one wait_* call to return
|
||||
# before it should have... but in general, the wait_* functions are allowed to
|
||||
# have some spurious wakeups; the user code will just attempt the operation,
|
||||
# get EWOULDBLOCK, and call wait_* again. And the program as a whole will
|
||||
# survive, any exceptions will propagate, etc.
|
||||
#
|
||||
# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds
|
||||
# on the normal wakeup path, so it's a bit more efficient in general.
|
||||
#
|
||||
# However, EPOLLONESHOT has a few trade-offs to consider:
|
||||
#
|
||||
# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad
|
||||
# in one somewhat rare case: if you have a multi-process server where a group
|
||||
# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be
|
||||
# used to avoid "thundering herd" problems when a new connection comes in. But
|
||||
# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us
|
||||
# anyway:
|
||||
#
|
||||
# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering
|
||||
#
|
||||
# And it's not clear that EPOLLEXCLUSIVE is a great approach either:
|
||||
#
|
||||
# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/
|
||||
#
|
||||
# And if we do need to support this, we could always add support through some
|
||||
# more-specialized API in the future. So this isn't a blocker to using
|
||||
# EPOLLONESHOT.
|
||||
#
|
||||
# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering
|
||||
# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but
|
||||
# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to
|
||||
# all-zeros. So we could still end up with an fd hanging around in the
|
||||
# interest set for a long time, even if we're not using it.
|
||||
#
|
||||
# Fortunately, this isn't a problem, because it's only a weak reference – if
|
||||
# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a
|
||||
# tiny bit of kernel memory remembering this fd that can never be revived, but
|
||||
# when the underlying file object is eventually closed, that memory will be
|
||||
# reclaimed. So that's OK.
|
||||
#
|
||||
# The other issue is that when someone calls wait_*, using EPOLLONESHOT means
|
||||
# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD
|
||||
# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do
|
||||
# we know which one to use? There's no reasonable way to track which fds are
|
||||
# currently registered -- remember, we're assuming the user might have gone
|
||||
# and rearranged their fds without telling us!
|
||||
#
|
||||
# Fortunately, this also has a simple solution: if we wait on a socket or
|
||||
# other fd once, then we'll probably wait on it lots of times. And the epoll
|
||||
# object itself knows which fds it already has registered. So when an fd comes
|
||||
# in, we optimistically assume that it's been waited on before, and try doing
|
||||
# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again
|
||||
# with EPOLL_CTL_ADD.
|
||||
#
|
||||
# So that's why this code is the way it is. And now you know more than you
|
||||
# wanted to about how epoll works.
|
||||
|
||||
|
||||
@attrs.define(eq=False)
|
||||
class EpollIOManager:
|
||||
# Using lambda here because otherwise crash on import with gevent monkey patching
|
||||
# See https://github.com/python-trio/trio/issues/2848
|
||||
_epoll: select.epoll = attrs.Factory(lambda: select.epoll())
|
||||
# {fd: EpollWaiters}
|
||||
_registered: defaultdict[int, EpollWaiters] = attrs.Factory(
|
||||
lambda: defaultdict(EpollWaiters),
|
||||
)
|
||||
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
|
||||
_force_wakeup_fd: int | None = None
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
|
||||
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
|
||||
|
||||
def statistics(self) -> _EpollStatistics:
|
||||
tasks_waiting_read = 0
|
||||
tasks_waiting_write = 0
|
||||
for waiter in self._registered.values():
|
||||
if waiter.read_task is not None:
|
||||
tasks_waiting_read += 1
|
||||
if waiter.write_task is not None:
|
||||
tasks_waiting_write += 1
|
||||
return _EpollStatistics(
|
||||
tasks_waiting_read=tasks_waiting_read,
|
||||
tasks_waiting_write=tasks_waiting_write,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self._epoll.close()
|
||||
self._force_wakeup.close()
|
||||
|
||||
def force_wakeup(self) -> None:
|
||||
self._force_wakeup.wakeup_thread_and_signal_safe()
|
||||
|
||||
# Return value must be False-y IFF the timeout expired, NOT if any I/O
|
||||
# happened or force_wakeup was called. Otherwise it can be anything; gets
|
||||
# passed straight through to process_events.
|
||||
def get_events(self, timeout: float) -> EventResult:
|
||||
# max_events must be > 0 or epoll gets cranky
|
||||
# accessing self._registered from a thread looks dangerous, but it's
|
||||
# OK because it doesn't matter if our value is a little bit off.
|
||||
max_events = max(1, len(self._registered))
|
||||
return self._epoll.poll(timeout, max_events)
|
||||
|
||||
def process_events(self, events: EventResult) -> None:
|
||||
for fd, flags in events:
|
||||
if fd == self._force_wakeup_fd:
|
||||
self._force_wakeup.drain()
|
||||
continue
|
||||
waiters = self._registered[fd]
|
||||
# EPOLLONESHOT always clears the flags when an event is delivered
|
||||
waiters.current_flags = 0
|
||||
# Clever hack stolen from selectors.EpollSelector: an event
|
||||
# with EPOLLHUP or EPOLLERR flags wakes both readers and
|
||||
# writers.
|
||||
if flags & ~select.EPOLLIN and waiters.write_task is not None:
|
||||
_core.reschedule(waiters.write_task)
|
||||
waiters.write_task = None
|
||||
if flags & ~select.EPOLLOUT and waiters.read_task is not None:
|
||||
_core.reschedule(waiters.read_task)
|
||||
waiters.read_task = None
|
||||
self._update_registrations(fd)
|
||||
|
||||
def _update_registrations(self, fd: int) -> None:
|
||||
waiters = self._registered[fd]
|
||||
wanted_flags = 0
|
||||
if waiters.read_task is not None:
|
||||
wanted_flags |= select.EPOLLIN
|
||||
if waiters.write_task is not None:
|
||||
wanted_flags |= select.EPOLLOUT
|
||||
if wanted_flags != waiters.current_flags:
|
||||
try:
|
||||
try:
|
||||
# First try EPOLL_CTL_MOD
|
||||
self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT)
|
||||
except OSError:
|
||||
# If that fails, it might be a new fd; try EPOLL_CTL_ADD
|
||||
self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT)
|
||||
waiters.current_flags = wanted_flags
|
||||
except OSError as exc:
|
||||
# If everything fails, probably it's a bad fd, e.g. because
|
||||
# the fd was closed behind our back. In this case we don't
|
||||
# want to try to unregister the fd, because that will probably
|
||||
# fail too. Just clear our state and wake everyone up.
|
||||
del self._registered[fd]
|
||||
# This could raise (in case we're calling this inside one of
|
||||
# the to-be-woken tasks), so we have to do it last.
|
||||
wake_all(waiters, exc)
|
||||
return
|
||||
if not wanted_flags:
|
||||
del self._registered[fd]
|
||||
|
||||
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
|
||||
if not isinstance(fd, int):
|
||||
fd = fd.fileno()
|
||||
waiters = self._registered[fd]
|
||||
if getattr(waiters, attr_name) is not None:
|
||||
raise _core.BusyResourceError(
|
||||
"another task is already reading / writing this fd",
|
||||
)
|
||||
setattr(waiters, attr_name, _core.current_task())
|
||||
self._update_registrations(fd)
|
||||
|
||||
def abort(_: RaiseCancelT) -> Abort:
|
||||
setattr(waiters, attr_name, None)
|
||||
self._update_registrations(fd)
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
await _core.wait_task_rescheduled(abort)
|
||||
|
||||
@_public
|
||||
async def wait_readable(self, fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is readable.
|
||||
|
||||
On Unix systems, ``fd`` must either be an integer file descriptor,
|
||||
or else an object with a ``.fileno()`` method which returns an
|
||||
integer file descriptor. Any kind of file descriptor can be passed,
|
||||
though the exact semantics will depend on your kernel. For example,
|
||||
this probably won't do anything useful for on-disk files.
|
||||
|
||||
On Windows systems, ``fd`` must either be an integer ``SOCKET``
|
||||
handle, or else an object with a ``.fileno()`` method which returns
|
||||
an integer ``SOCKET`` handle. File descriptors aren't supported,
|
||||
and neither are handles that refer to anything besides a
|
||||
``SOCKET``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become readable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
await self._epoll_wait(fd, "read_task")
|
||||
|
||||
@_public
|
||||
async def wait_writable(self, fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is writable.
|
||||
|
||||
See `wait_readable` for the definition of ``fd``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become writable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
await self._epoll_wait(fd, "write_task")
|
||||
|
||||
@_public
|
||||
def notify_closing(self, fd: int | _HasFileNo) -> None:
|
||||
"""Notify waiters of the given object that it will be closed.
|
||||
|
||||
Call this before closing a file descriptor (on Unix) or socket (on
|
||||
Windows). This will cause any `wait_readable` or `wait_writable`
|
||||
calls on the given object to immediately wake up and raise
|
||||
`~trio.ClosedResourceError`.
|
||||
|
||||
This doesn't actually close the object – you still have to do that
|
||||
yourself afterwards. Also, you want to be careful to make sure no
|
||||
new tasks start waiting on the object in between when you call this
|
||||
and when it's actually closed. So to close something properly, you
|
||||
usually want to do these steps in order:
|
||||
|
||||
1. Explicitly mark the object as closed, so that any new attempts
|
||||
to use it will abort before they start.
|
||||
2. Call `notify_closing` to wake up any already-existing users.
|
||||
3. Actually close the object.
|
||||
|
||||
It's also possible to do them in a different order if that's more
|
||||
convenient, *but only if* you make sure not to have any checkpoints in
|
||||
between the steps. This way they all happen in a single atomic
|
||||
step, so other tasks won't be able to tell what order they happened
|
||||
in anyway.
|
||||
"""
|
||||
if not isinstance(fd, int):
|
||||
fd = fd.fileno()
|
||||
wake_all(
|
||||
self._registered[fd],
|
||||
_core.ClosedResourceError("another task closed this fd"),
|
||||
)
|
||||
del self._registered[fd]
|
||||
with contextlib.suppress(OSError, ValueError):
|
||||
self._epoll.unregister(fd)
|
||||
@@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import select
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Literal
|
||||
|
||||
import attrs
|
||||
import outcome
|
||||
|
||||
from .. import _core
|
||||
from ._run import _public
|
||||
from ._wakeup_socketpair import WakeupSocketpair
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
|
||||
from .._file_io import _HasFileNo
|
||||
|
||||
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
|
||||
|
||||
EventResult: TypeAlias = "list[select.kevent]"
|
||||
|
||||
|
||||
@attrs.frozen(eq=False)
|
||||
class _KqueueStatistics:
|
||||
tasks_waiting: int
|
||||
monitors: int
|
||||
backend: Literal["kqueue"] = attrs.field(init=False, default="kqueue")
|
||||
|
||||
|
||||
@attrs.define(eq=False)
|
||||
class KqueueIOManager:
|
||||
_kqueue: select.kqueue = attrs.Factory(select.kqueue)
|
||||
# {(ident, filter): Task or UnboundedQueue}
|
||||
_registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = (
|
||||
attrs.Factory(dict)
|
||||
)
|
||||
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
|
||||
_force_wakeup_fd: int | None = None
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
force_wakeup_event = select.kevent(
|
||||
self._force_wakeup.wakeup_sock,
|
||||
select.KQ_FILTER_READ,
|
||||
select.KQ_EV_ADD,
|
||||
)
|
||||
self._kqueue.control([force_wakeup_event], 0)
|
||||
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
|
||||
|
||||
def statistics(self) -> _KqueueStatistics:
|
||||
tasks_waiting = 0
|
||||
monitors = 0
|
||||
for receiver in self._registered.values():
|
||||
if type(receiver) is _core.Task:
|
||||
tasks_waiting += 1
|
||||
else:
|
||||
monitors += 1
|
||||
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
|
||||
|
||||
def close(self) -> None:
|
||||
self._kqueue.close()
|
||||
self._force_wakeup.close()
|
||||
|
||||
def force_wakeup(self) -> None:
|
||||
self._force_wakeup.wakeup_thread_and_signal_safe()
|
||||
|
||||
def get_events(self, timeout: float) -> EventResult:
|
||||
# max_events must be > 0 or kqueue gets cranky
|
||||
# and we generally want this to be strictly larger than the actual
|
||||
# number of events we get, so that we can tell that we've gotten
|
||||
# all the events in just 1 call.
|
||||
max_events = len(self._registered) + 1
|
||||
events = []
|
||||
while True:
|
||||
batch = self._kqueue.control([], max_events, timeout)
|
||||
events += batch
|
||||
if len(batch) < max_events:
|
||||
break
|
||||
else:
|
||||
timeout = 0
|
||||
# and loop back to the start
|
||||
return events
|
||||
|
||||
def process_events(self, events: EventResult) -> None:
|
||||
for event in events:
|
||||
key = (event.ident, event.filter)
|
||||
if event.ident == self._force_wakeup_fd:
|
||||
self._force_wakeup.drain()
|
||||
continue
|
||||
receiver = self._registered[key]
|
||||
if event.flags & select.KQ_EV_ONESHOT:
|
||||
del self._registered[key]
|
||||
if isinstance(receiver, _core.Task):
|
||||
_core.reschedule(receiver, outcome.Value(event))
|
||||
else:
|
||||
receiver.put_nowait(event)
|
||||
|
||||
# kevent registration is complicated -- e.g. aio submission can
|
||||
# implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will
|
||||
# automatically register filters for child processes. So our lowlevel
|
||||
# API is *very* low-level: we expose the kqueue itself for adding
|
||||
# events or sticking into AIO submission structs, and split waiting
|
||||
# off into separate methods. It's your responsibility to make sure
|
||||
# that handle_io never receives an event without a corresponding
|
||||
# registration! This may be challenging if you want to be careful
|
||||
# about e.g. KeyboardInterrupt. Possibly this API could be improved to
|
||||
# be more ergonomic...
|
||||
|
||||
@_public
|
||||
def current_kqueue(self) -> select.kqueue:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
return self._kqueue
|
||||
|
||||
@contextmanager
|
||||
@_public
|
||||
def monitor_kevent(
|
||||
self,
|
||||
ident: int,
|
||||
filter: int,
|
||||
) -> Iterator[_core.UnboundedQueue[select.kevent]]:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
key = (ident, filter)
|
||||
if key in self._registered:
|
||||
raise _core.BusyResourceError(
|
||||
"attempt to register multiple listeners for same ident/filter pair",
|
||||
)
|
||||
q = _core.UnboundedQueue[select.kevent]()
|
||||
self._registered[key] = q
|
||||
try:
|
||||
yield q
|
||||
finally:
|
||||
del self._registered[key]
|
||||
|
||||
@_public
|
||||
async def wait_kevent(
|
||||
self,
|
||||
ident: int,
|
||||
filter: int,
|
||||
abort_func: Callable[[RaiseCancelT], Abort],
|
||||
) -> Abort:
|
||||
"""TODO: these are implemented, but are currently more of a sketch than
|
||||
anything real. See `#26
|
||||
<https://github.com/python-trio/trio/issues/26>`__.
|
||||
"""
|
||||
key = (ident, filter)
|
||||
if key in self._registered:
|
||||
raise _core.BusyResourceError(
|
||||
"attempt to register multiple listeners for same ident/filter pair",
|
||||
)
|
||||
self._registered[key] = _core.current_task()
|
||||
|
||||
def abort(raise_cancel: RaiseCancelT) -> Abort:
|
||||
r = abort_func(raise_cancel)
|
||||
if r is _core.Abort.SUCCEEDED:
|
||||
del self._registered[key]
|
||||
return r
|
||||
|
||||
# wait_task_rescheduled does not have its return type typed
|
||||
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
|
||||
|
||||
async def _wait_common(
|
||||
self,
|
||||
fd: int | _HasFileNo,
|
||||
filter: int,
|
||||
) -> None:
|
||||
if not isinstance(fd, int):
|
||||
fd = fd.fileno()
|
||||
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
|
||||
event = select.kevent(fd, filter, flags)
|
||||
self._kqueue.control([event], 0)
|
||||
|
||||
def abort(_: RaiseCancelT) -> Abort:
|
||||
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
|
||||
try:
|
||||
self._kqueue.control([event], 0)
|
||||
except OSError as exc:
|
||||
# kqueue tracks individual fds (*not* the underlying file
|
||||
# object, see _io_epoll.py for a long discussion of why this
|
||||
# distinction matters), and automatically deregisters an event
|
||||
# if the fd is closed. So if kqueue.control says that it
|
||||
# doesn't know about this event, then probably it's because
|
||||
# the fd was closed behind our backs. (Too bad we can't ask it
|
||||
# to wake us up when this happens, versus discovering it after
|
||||
# the fact... oh well, you can't have everything.)
|
||||
#
|
||||
# FreeBSD reports this using EBADF. macOS uses ENOENT.
|
||||
if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch
|
||||
pass
|
||||
else: # pragma: no cover
|
||||
# As far as we know, this branch can't happen.
|
||||
raise
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
await self.wait_kevent(fd, filter, abort)
|
||||
|
||||
@_public
|
||||
async def wait_readable(self, fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is readable.
|
||||
|
||||
On Unix systems, ``fd`` must either be an integer file descriptor,
|
||||
or else an object with a ``.fileno()`` method which returns an
|
||||
integer file descriptor. Any kind of file descriptor can be passed,
|
||||
though the exact semantics will depend on your kernel. For example,
|
||||
this probably won't do anything useful for on-disk files.
|
||||
|
||||
On Windows systems, ``fd`` must either be an integer ``SOCKET``
|
||||
handle, or else an object with a ``.fileno()`` method which returns
|
||||
an integer ``SOCKET`` handle. File descriptors aren't supported,
|
||||
and neither are handles that refer to anything besides a
|
||||
``SOCKET``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become readable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
await self._wait_common(fd, select.KQ_FILTER_READ)
|
||||
|
||||
@_public
|
||||
async def wait_writable(self, fd: int | _HasFileNo) -> None:
|
||||
"""Block until the kernel reports that the given object is writable.
|
||||
|
||||
See `wait_readable` for the definition of ``fd``.
|
||||
|
||||
:raises trio.BusyResourceError:
|
||||
if another task is already waiting for the given socket to
|
||||
become writable.
|
||||
:raises trio.ClosedResourceError:
|
||||
if another task calls :func:`notify_closing` while this
|
||||
function is still working.
|
||||
"""
|
||||
await self._wait_common(fd, select.KQ_FILTER_WRITE)
|
||||
|
||||
@_public
|
||||
def notify_closing(self, fd: int | _HasFileNo) -> None:
|
||||
"""Notify waiters of the given object that it will be closed.
|
||||
|
||||
Call this before closing a file descriptor (on Unix) or socket (on
|
||||
Windows). This will cause any `wait_readable` or `wait_writable`
|
||||
calls on the given object to immediately wake up and raise
|
||||
`~trio.ClosedResourceError`.
|
||||
|
||||
This doesn't actually close the object – you still have to do that
|
||||
yourself afterwards. Also, you want to be careful to make sure no
|
||||
new tasks start waiting on the object in between when you call this
|
||||
and when it's actually closed. So to close something properly, you
|
||||
usually want to do these steps in order:
|
||||
|
||||
1. Explicitly mark the object as closed, so that any new attempts
|
||||
to use it will abort before they start.
|
||||
2. Call `notify_closing` to wake up any already-existing users.
|
||||
3. Actually close the object.
|
||||
|
||||
It's also possible to do them in a different order if that's more
|
||||
convenient, *but only if* you make sure not to have any checkpoints in
|
||||
between the steps. This way they all happen in a single atomic
|
||||
step, so other tasks won't be able to tell what order they happened
|
||||
in anyway.
|
||||
"""
|
||||
if not isinstance(fd, int):
|
||||
fd = fd.fileno()
|
||||
|
||||
for filter_ in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]:
|
||||
key = (fd, filter_)
|
||||
receiver = self._registered.get(key)
|
||||
|
||||
if receiver is None:
|
||||
continue
|
||||
|
||||
if type(receiver) is _core.Task:
|
||||
event = select.kevent(fd, filter_, select.KQ_EV_DELETE)
|
||||
self._kqueue.control([event], 0)
|
||||
exc = _core.ClosedResourceError("another task closed this fd")
|
||||
_core.reschedule(receiver, outcome.Error(exc))
|
||||
del self._registered[key]
|
||||
else:
|
||||
# XX this is an interesting example of a case where being able
|
||||
# to close a queue would be useful...
|
||||
raise NotImplementedError(
|
||||
"can't close an fd that monitor_kevent is using",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,237 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import signal
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Final, Protocol, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
from .._util import is_main_thread
|
||||
|
||||
CallableT = TypeVar("CallableT", bound="Callable[..., object]")
|
||||
RetT = TypeVar("RetT")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
from collections.abc import Callable
|
||||
|
||||
from typing_extensions import ParamSpec, TypeGuard
|
||||
|
||||
ArgsT = ParamSpec("ArgsT")
|
||||
|
||||
# In ordinary single-threaded Python code, when you hit control-C, it raises
|
||||
# an exception and automatically does all the regular unwinding stuff.
|
||||
#
|
||||
# In Trio code, we would like hitting control-C to raise an exception and
|
||||
# automatically do all the regular unwinding stuff. In particular, we would
|
||||
# like to maintain our invariant that all tasks always run to completion (one
|
||||
# way or another), by unwinding all of them.
|
||||
#
|
||||
# But it's basically impossible to write the core task running code in such a
|
||||
# way that it can maintain this invariant in the face of KeyboardInterrupt
|
||||
# exceptions arising at arbitrary bytecode positions. Similarly, if a
|
||||
# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
|
||||
# inter-task synchronization or I/O primitives, then the system state could
|
||||
# get corrupted and prevent our being able to clean up properly.
|
||||
#
|
||||
# So, we need a way to defer KeyboardInterrupt processing from these critical
|
||||
# sections.
|
||||
#
|
||||
# Things that don't work:
|
||||
#
|
||||
# - Listen for SIGINT and process it in a system task: works fine for
|
||||
# well-behaved programs that regularly pass through the event loop, but if
|
||||
# user-code goes into an infinite loop then it can't be interrupted. Which
|
||||
# is unfortunate, since dealing with infinite loops is what
|
||||
# KeyboardInterrupt is for!
|
||||
#
|
||||
# - Use pthread_sigmask to disable signal delivery during critical section:
|
||||
# (a) windows has no pthread_sigmask, (b) python threads start with all
|
||||
# signals unblocked, so if there are any threads around they'll receive the
|
||||
# signal and then tell the main thread to run the handler, even if the main
|
||||
# thread has that signal blocked.
|
||||
#
|
||||
# - Install a signal handler which checks a global variable to decide whether
|
||||
# to raise the exception immediately (if we're in a non-critical section),
|
||||
# or to schedule it on the event loop (if we're in a critical section). The
|
||||
# problem here is that it's impossible to transition safely out of user code:
|
||||
#
|
||||
# with keyboard_interrupt_enabled:
|
||||
# msg = coro.send(value)
|
||||
#
|
||||
# If this raises a KeyboardInterrupt, it might be because the coroutine got
|
||||
# interrupted and has unwound... or it might be the KeyboardInterrupt
|
||||
# arrived just *after* 'send' returned, so the coroutine is still running,
|
||||
# but we just lost the message it sent. (And worse, in our actual task
|
||||
# runner, the send is hidden inside a utility function etc.)
|
||||
#
|
||||
# Solution:
|
||||
#
|
||||
# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
|
||||
# the signal handler check which kind of frame we're currently in when
|
||||
# deciding whether to raise or schedule the exception.
|
||||
#
|
||||
# There are still some cases where this can fail, like if someone hits
|
||||
# control-C while the process is in the event loop, and then it immediately
|
||||
# enters an infinite loop in user code. In this case the user has to hit
|
||||
# control-C a second time. And of course if the user code is written so that
|
||||
# it doesn't actually exit after a task crashes and everything gets cancelled,
|
||||
# then there's not much to be done. (Hitting control-C repeatedly might help,
|
||||
# but in general the solution is to kill the process some other way, just like
|
||||
# for any Python program that's written to catch and ignore
|
||||
# KeyboardInterrupt.)
|
||||
|
||||
# We use this special string as a unique key into the frame locals dictionary.
|
||||
# The @ ensures it is not a valid identifier and can't clash with any possible
|
||||
# real local name. See: https://github.com/python-trio/trio/issues/469
|
||||
LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED"
|
||||
|
||||
|
||||
# NB: according to the signal.signal docs, 'frame' can be None on entry to
|
||||
# this function:
|
||||
def ki_protection_enabled(frame: types.FrameType | None) -> bool:
|
||||
while frame is not None:
|
||||
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
|
||||
return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED])
|
||||
if frame.f_code.co_name == "__del__":
|
||||
return True
|
||||
frame = frame.f_back
|
||||
return True
|
||||
|
||||
|
||||
def currently_ki_protected() -> bool:
|
||||
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
|
||||
enabled.
|
||||
|
||||
It's surprisingly easy to think that one's :exc:`KeyboardInterrupt`
|
||||
protection is enabled when it isn't, or vice-versa. This function tells
|
||||
you what Trio thinks of the matter, which makes it useful for ``assert``\s
|
||||
and unit tests.
|
||||
|
||||
Returns:
|
||||
bool: True if protection is enabled, and False otherwise.
|
||||
|
||||
"""
|
||||
return ki_protection_enabled(sys._getframe())
|
||||
|
||||
|
||||
# This is to support the async_generator package necessary for aclosing on <3.10
|
||||
# functions decorated @async_generator are given this magic property that's a
|
||||
# reference to the object itself
|
||||
# see python-trio/async_generator/async_generator/_impl.py
|
||||
def legacy_isasyncgenfunction(
|
||||
obj: object,
|
||||
) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]:
|
||||
return getattr(obj, "_async_gen_function", None) == id(obj)
|
||||
|
||||
|
||||
def _ki_protection_decorator(
|
||||
enabled: bool,
|
||||
) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]:
|
||||
# The "ignore[return-value]" below is because the inspect functions cast away the
|
||||
# original return type of fn, making it just CoroutineType[Any, Any, Any] etc.
|
||||
# ignore[misc] is because @wraps() is passed a callable with Any in the return type.
|
||||
def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]:
|
||||
# In some version of Python, isgeneratorfunction returns true for
|
||||
# coroutine functions, so we have to check for coroutine functions
|
||||
# first.
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
|
||||
# See the comment for regular generators below
|
||||
coro = fn(*args, **kwargs)
|
||||
assert coro.cr_frame is not None, "Coroutine frame should exist"
|
||||
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
|
||||
return coro # type: ignore[return-value]
|
||||
|
||||
return wrapper
|
||||
elif inspect.isgeneratorfunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
|
||||
# It's important that we inject this directly into the
|
||||
# generator's locals, as opposed to setting it here and then
|
||||
# doing 'yield from'. The reason is, if a generator is
|
||||
# throw()n into, then it may magically pop to the top of the
|
||||
# stack. And @contextmanager generators in particular are a
|
||||
# case where we often want KI protection, and which are often
|
||||
# thrown into! See:
|
||||
# https://bugs.python.org/issue29590
|
||||
gen = fn(*args, **kwargs)
|
||||
gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
|
||||
return gen # type: ignore[return-value]
|
||||
|
||||
return wrapper
|
||||
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
|
||||
|
||||
@wraps(fn) # type: ignore[arg-type]
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
|
||||
# See the comment for regular generators above
|
||||
agen = fn(*args, **kwargs)
|
||||
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
|
||||
return agen # type: ignore[return-value]
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
|
||||
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# pyright workaround: https://github.com/microsoft/pyright/issues/5866
|
||||
class KIProtectionSignature(Protocol):
|
||||
__name__: str
|
||||
|
||||
def __call__(self, f: CallableT, /) -> CallableT:
|
||||
pass
|
||||
|
||||
|
||||
# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads
|
||||
enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment]
|
||||
enable_ki_protection.__name__ = "enable_ki_protection"
|
||||
|
||||
disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment]
|
||||
disable_ki_protection.__name__ = "disable_ki_protection"
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class KIManager:
|
||||
handler: Callable[[int, types.FrameType | None], None] | None = None
|
||||
|
||||
def install(
|
||||
self,
|
||||
deliver_cb: Callable[[], object],
|
||||
restrict_keyboard_interrupt_to_checkpoints: bool,
|
||||
) -> None:
|
||||
assert self.handler is None
|
||||
if (
|
||||
not is_main_thread()
|
||||
or signal.getsignal(signal.SIGINT) != signal.default_int_handler
|
||||
):
|
||||
return
|
||||
|
||||
def handler(signum: int, frame: types.FrameType | None) -> None:
|
||||
assert signum == signal.SIGINT
|
||||
protection_enabled = ki_protection_enabled(frame)
|
||||
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
|
||||
deliver_cb()
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
self.handler = handler
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handler is not None:
|
||||
if signal.getsignal(signal.SIGINT) is self.handler:
|
||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
self.handler = None
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
# Runvar implementations
|
||||
import attrs
|
||||
|
||||
from .._util import NoPublicConstructor, final
|
||||
from . import _run
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@final
|
||||
class _NoValue: ...
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False)
|
||||
class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
|
||||
_var: RunVar[T]
|
||||
previous_value: T | type[_NoValue] = _NoValue
|
||||
redeemed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
@classmethod
|
||||
def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
|
||||
return cls._create(var)
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False, repr=False)
|
||||
class RunVar(Generic[T]):
|
||||
"""The run-local variant of a context variable.
|
||||
|
||||
:class:`RunVar` objects are similar to context variable objects,
|
||||
except that they are shared across a single call to :func:`trio.run`
|
||||
rather than a single task.
|
||||
|
||||
"""
|
||||
|
||||
_name: str
|
||||
_default: T | type[_NoValue] = _NoValue
|
||||
|
||||
def get(self, default: T | type[_NoValue] = _NoValue) -> T:
|
||||
"""Gets the value of this :class:`RunVar` for the current run call."""
|
||||
try:
|
||||
return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self])
|
||||
except AttributeError:
|
||||
raise RuntimeError("Cannot be used outside of a run context") from None
|
||||
except KeyError:
|
||||
# contextvars consistency
|
||||
# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
|
||||
if default is not _NoValue:
|
||||
return default # type: ignore[return-value]
|
||||
|
||||
if self._default is not _NoValue:
|
||||
return self._default # type: ignore[return-value]
|
||||
|
||||
raise LookupError(self) from None
|
||||
|
||||
def set(self, value: T) -> RunVarToken[T]:
|
||||
"""Sets the value of this :class:`RunVar` for this current run
|
||||
call.
|
||||
|
||||
"""
|
||||
try:
|
||||
old_value = self.get()
|
||||
except LookupError:
|
||||
token = RunVarToken._empty(self)
|
||||
else:
|
||||
token = RunVarToken[T]._create(self, old_value)
|
||||
|
||||
# This can't fail, because if we weren't in Trio context then the
|
||||
# get() above would have failed.
|
||||
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
|
||||
return token
|
||||
|
||||
def reset(self, token: RunVarToken[T]) -> None:
|
||||
"""Resets the value of this :class:`RunVar` to what it was
|
||||
previously specified by the token.
|
||||
|
||||
"""
|
||||
if token is None:
|
||||
raise TypeError("token must not be none")
|
||||
|
||||
if token.redeemed:
|
||||
raise ValueError("token has already been used")
|
||||
|
||||
if token._var is not self:
|
||||
raise ValueError("token is not for us")
|
||||
|
||||
previous = token.previous_value
|
||||
try:
|
||||
if previous is _NoValue:
|
||||
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
|
||||
else:
|
||||
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
|
||||
except AttributeError:
|
||||
raise RuntimeError("Cannot be used outside of a run context") from None
|
||||
|
||||
token.redeemed = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<RunVar name={self._name!r}>"
|
||||
@@ -0,0 +1,164 @@
|
||||
import time
|
||||
from math import inf
|
||||
|
||||
from .. import _core
|
||||
from .._abc import Clock
|
||||
from .._util import final
|
||||
from ._run import GLOBAL_RUN_CONTEXT
|
||||
|
||||
################################################################
|
||||
# The glorious MockClock
|
||||
################################################################
|
||||
|
||||
|
||||
# Prior art:
|
||||
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
|
||||
# https://github.com/ztellman/manifold/issues/57
|
||||
@final
|
||||
class MockClock(Clock):
|
||||
"""A user-controllable clock suitable for writing tests.
|
||||
|
||||
Args:
|
||||
rate (float): the initial :attr:`rate`.
|
||||
autojump_threshold (float): the initial :attr:`autojump_threshold`.
|
||||
|
||||
.. attribute:: rate
|
||||
|
||||
How many seconds of clock time pass per second of real time. Default is
|
||||
0.0, i.e. the clock only advances through manuals calls to :meth:`jump`
|
||||
or when the :attr:`autojump_threshold` is triggered. You can assign to
|
||||
this attribute to change it.
|
||||
|
||||
.. attribute:: autojump_threshold
|
||||
|
||||
The clock keeps an eye on the run loop, and if at any point it detects
|
||||
that all tasks have been blocked for this many real seconds (i.e.,
|
||||
according to the actual clock, not this clock), then the clock
|
||||
automatically jumps ahead to the run loop's next scheduled
|
||||
timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
|
||||
assign to this attribute to change it.
|
||||
|
||||
Basically the idea is that if you have code or tests that use sleeps
|
||||
and timeouts, you can use this to make it run much faster, totally
|
||||
automatically. (At least, as long as those sleeps/timeouts are
|
||||
happening inside Trio; if your test involves talking to external
|
||||
service and waiting for it to timeout then obviously we can't help you
|
||||
there.)
|
||||
|
||||
You should set this to the smallest value that lets you reliably avoid
|
||||
"false alarms" where some I/O is in flight (e.g. between two halves of
|
||||
a socketpair) but the threshold gets triggered and time gets advanced
|
||||
anyway. This will depend on the details of your tests and test
|
||||
environment. If you aren't doing any I/O (like in our sleeping example
|
||||
above) then just set it to zero, and the clock will jump whenever all
|
||||
tasks are blocked.
|
||||
|
||||
.. note:: If you use ``autojump_threshold`` and
|
||||
`wait_all_tasks_blocked` at the same time, then you might wonder how
|
||||
they interact, since they both cause things to happen after the run
|
||||
loop goes idle for some time. The answer is:
|
||||
`wait_all_tasks_blocked` takes priority. If there's a task blocked
|
||||
in `wait_all_tasks_blocked`, then the autojump feature treats that
|
||||
as active task and does *not* jump the clock.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
|
||||
# when the real clock said 'real_base', the virtual time was
|
||||
# 'virtual_base', and since then it's advanced at 'rate' virtual
|
||||
# seconds per real second.
|
||||
self._real_base = 0.0
|
||||
self._virtual_base = 0.0
|
||||
self._rate = 0.0
|
||||
self._autojump_threshold = 0.0
|
||||
# kept as an attribute so that our tests can monkeypatch it
|
||||
self._real_clock = time.perf_counter
|
||||
|
||||
# use the property update logic to set initial values
|
||||
self.rate = rate
|
||||
self.autojump_threshold = autojump_threshold
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<MockClock, time={self.current_time():.7f}, rate={self._rate} @ {id(self):#x}>"
|
||||
|
||||
@property
|
||||
def rate(self) -> float:
|
||||
return self._rate
|
||||
|
||||
@rate.setter
|
||||
def rate(self, new_rate: float) -> None:
|
||||
if new_rate < 0:
|
||||
raise ValueError("rate must be >= 0")
|
||||
else:
|
||||
real = self._real_clock()
|
||||
virtual = self._real_to_virtual(real)
|
||||
self._virtual_base = virtual
|
||||
self._real_base = real
|
||||
self._rate = float(new_rate)
|
||||
|
||||
@property
|
||||
def autojump_threshold(self) -> float:
|
||||
return self._autojump_threshold
|
||||
|
||||
@autojump_threshold.setter
|
||||
def autojump_threshold(self, new_autojump_threshold: float) -> None:
|
||||
self._autojump_threshold = float(new_autojump_threshold)
|
||||
self._try_resync_autojump_threshold()
|
||||
|
||||
# runner.clock_autojump_threshold is an internal API that isn't easily
|
||||
# usable by custom third-party Clock objects. If you need access to this
|
||||
# functionality, let us know, and we'll figure out how to make a public
|
||||
# API. Discussion:
|
||||
#
|
||||
# https://github.com/python-trio/trio/issues/1587
|
||||
def _try_resync_autojump_threshold(self) -> None:
|
||||
try:
|
||||
runner = GLOBAL_RUN_CONTEXT.runner
|
||||
if runner.is_guest:
|
||||
runner.force_guest_tick_asap()
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
runner.clock_autojump_threshold = self._autojump_threshold
|
||||
|
||||
# Invoked by the run loop when runner.clock_autojump_threshold is
|
||||
# exceeded.
|
||||
def _autojump(self) -> None:
|
||||
statistics = _core.current_statistics()
|
||||
jump = statistics.seconds_to_next_deadline
|
||||
if 0 < jump < inf:
|
||||
self.jump(jump)
|
||||
|
||||
def _real_to_virtual(self, real: float) -> float:
|
||||
real_offset = real - self._real_base
|
||||
virtual_offset = self._rate * real_offset
|
||||
return self._virtual_base + virtual_offset
|
||||
|
||||
def start_clock(self) -> None:
|
||||
self._try_resync_autojump_threshold()
|
||||
|
||||
def current_time(self) -> float:
|
||||
return self._real_to_virtual(self._real_clock())
|
||||
|
||||
def deadline_to_sleep_time(self, deadline: float) -> float:
|
||||
virtual_timeout = deadline - self.current_time()
|
||||
if virtual_timeout <= 0:
|
||||
return 0
|
||||
elif self._rate > 0:
|
||||
return virtual_timeout / self._rate
|
||||
else:
|
||||
return 999999999
|
||||
|
||||
def jump(self, seconds: float) -> None:
|
||||
"""Manually advance the clock by the given number of seconds.
|
||||
|
||||
Args:
|
||||
seconds (float): the number of seconds to jump the clock forward.
|
||||
|
||||
Raises:
|
||||
ValueError: if you try to pass a negative value for ``seconds``.
|
||||
|
||||
"""
|
||||
if seconds < 0:
|
||||
raise ValueError("time can't go backwards")
|
||||
self._virtual_base += seconds
|
||||
@@ -0,0 +1,317 @@
|
||||
# ParkingLot provides an abstraction for a fair waitqueue with cancellation
|
||||
# and requeueing support. Inspiration:
|
||||
#
|
||||
# https://webkit.org/blog/6161/locking-in-webkit/
|
||||
# https://amanieu.github.io/parking_lot/
|
||||
#
|
||||
# which were in turn heavily influenced by
|
||||
#
|
||||
# http://gee.cs.oswego.edu/dl/papers/aqs.pdf
|
||||
#
|
||||
# Compared to these, our use of cooperative scheduling allows some
|
||||
# simplifications (no need for internal locking). On the other hand, the need
|
||||
# to support Trio's strong cancellation semantics adds some complications
|
||||
# (tasks need to know where they're queued so they can cancel). Also, in the
|
||||
# above work, the ParkingLot is a global structure that holds a collection of
|
||||
# waitqueues keyed by lock address, and which are opportunistically allocated
|
||||
# and destroyed as contention arises; this allows the worst-case memory usage
|
||||
# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue
|
||||
# for each synchronization object, so we're O(#objects + #tasks). This isn't
|
||||
# *so* bad since compared to our synchronization objects are heavier than
|
||||
# theirs and our tasks are lighter, so for us #objects is smaller and #tasks
|
||||
# is larger.
|
||||
#
|
||||
# This is in the core because for two reasons. First, it's used by
|
||||
# UnboundedQueue, and UnboundedQueue is used for a number of things in the
|
||||
# core. And second, it's responsible for providing fairness to all of our
|
||||
# high-level synchronization primitives (locks, queues, etc.). For now with
|
||||
# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue),
|
||||
# but in the future we ever start support task priorities or fair scheduling
|
||||
#
|
||||
# https://github.com/python-trio/trio/issues/32
|
||||
#
|
||||
# then all we'll have to do is update this. (Well, full-fledged task
|
||||
# priorities might also require priority inheritance, which would require more
|
||||
# work.)
|
||||
#
|
||||
# For discussion of data structures to use here, see:
|
||||
#
|
||||
# https://github.com/dabeaz/curio/issues/136
|
||||
#
|
||||
# (and also the articles above). Currently we use a SortedDict ordered by a
|
||||
# global monotonic counter that ensures FIFO ordering. The main advantage of
|
||||
# this is that it's easy to implement :-). An intrusive doubly-linked list
|
||||
# would also be a natural approach, so long as we only handle FIFO ordering.
|
||||
#
|
||||
# XX: should we switch to the shared global ParkingLot approach?
|
||||
#
|
||||
# XX: we should probably add support for "parking tokens" to allow for
|
||||
# task-fair RWlock (basically: when parking a task needs to be able to mark
|
||||
# itself as a reader or a writer, and then a task-fair wakeup policy is, wake
|
||||
# the next task, and if it's a reader than keep waking tasks so long as they
|
||||
# are readers). Without this I think you can implement write-biased or
|
||||
# read-biased RWlocks (by using two parking lots and drawing from whichever is
|
||||
# preferred), but not task-fair -- and task-fair plays much more nicely with
|
||||
# WFQ. (Consider what happens in the two-lot implementation if you're
|
||||
# write-biased but all the pending writers are blocked at the scheduler level
|
||||
# by the WFQ logic...)
|
||||
# ...alternatively, "phase-fair" RWlocks are pretty interesting:
|
||||
# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf
|
||||
# Useful summary:
|
||||
# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html
|
||||
#
|
||||
# XX: if we do add WFQ, then we might have to drop the current feature where
|
||||
# unpark returns the tasks that were unparked. Rationale: suppose that at the
|
||||
# time we call unpark, the next task is deprioritized... and then, before it
|
||||
# becomes runnable, a new task parks which *is* runnable. Ideally we should
|
||||
# immediately wake the new task, and leave the old task on the queue for
|
||||
# later. But this means we can't commit to which task we are unparking when
|
||||
# unpark is called.
|
||||
#
|
||||
# See: https://github.com/python-trio/trio/issues/53
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import attrs
|
||||
import outcome
|
||||
|
||||
from .. import _core
|
||||
from .._util import final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from ._run import Task
|
||||
|
||||
|
||||
GLOBAL_PARKING_LOT_BREAKER: dict[Task, list[ParkingLot]] = {}
|
||||
|
||||
|
||||
def add_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
|
||||
"""Register a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`.
|
||||
|
||||
raises:
|
||||
trio.BrokenResourceError: if the task has already exited.
|
||||
"""
|
||||
if inspect.getcoroutinestate(task.coro) == inspect.CORO_CLOSED:
|
||||
raise _core._exceptions.BrokenResourceError(
|
||||
"Attempted to add already exited task as lot breaker.",
|
||||
)
|
||||
if task not in GLOBAL_PARKING_LOT_BREAKER:
|
||||
GLOBAL_PARKING_LOT_BREAKER[task] = [lot]
|
||||
else:
|
||||
GLOBAL_PARKING_LOT_BREAKER[task].append(lot)
|
||||
|
||||
|
||||
def remove_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
|
||||
"""Deregister a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`"""
|
||||
try:
|
||||
GLOBAL_PARKING_LOT_BREAKER[task].remove(lot)
|
||||
except (KeyError, ValueError):
|
||||
raise RuntimeError(
|
||||
"Attempted to remove task as breaker for a lot it is not registered for",
|
||||
) from None
|
||||
if not GLOBAL_PARKING_LOT_BREAKER[task]:
|
||||
del GLOBAL_PARKING_LOT_BREAKER[task]
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class ParkingLotStatistics:
|
||||
"""An object containing debugging information for a ParkingLot.
|
||||
|
||||
Currently, the following fields are defined:
|
||||
|
||||
* ``tasks_waiting`` (int): The number of tasks blocked on this lot's
|
||||
:meth:`trio.lowlevel.ParkingLot.park` method.
|
||||
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False)
|
||||
class ParkingLot:
|
||||
"""A fair wait queue with cancellation and requeueing.
|
||||
|
||||
This class encapsulates the tricky parts of implementing a wait
|
||||
queue. It's useful for implementing higher-level synchronization
|
||||
primitives like queues and locks.
|
||||
|
||||
In addition to the methods below, you can use ``len(parking_lot)`` to get
|
||||
the number of parked tasks, and ``if parking_lot: ...`` to check whether
|
||||
there are any parked tasks.
|
||||
|
||||
"""
|
||||
|
||||
# {task: None}, we just want a deque where we can quickly delete random
|
||||
# items
|
||||
_parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False)
|
||||
broken_by: list[Task] = attrs.field(factory=list, init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of parked tasks."""
|
||||
return len(self._parked)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""True if there are parked tasks, False otherwise."""
|
||||
return bool(self._parked)
|
||||
|
||||
# XX this currently returns None
|
||||
# if we ever add the ability to repark while one's resuming place in
|
||||
# line (for false wakeups), then we could have it return a ticket that
|
||||
# abstracts the "place in line" concept.
|
||||
@_core.enable_ki_protection
|
||||
async def park(self) -> None:
|
||||
"""Park the current task until woken by a call to :meth:`unpark` or
|
||||
:meth:`unpark_all`.
|
||||
|
||||
Raises:
|
||||
BrokenResourceError: if attempting to park in a broken lot, or the lot
|
||||
breaks before we get to unpark.
|
||||
|
||||
"""
|
||||
if self.broken_by:
|
||||
raise _core.BrokenResourceError(
|
||||
f"Attempted to park in parking lot broken by {self.broken_by}",
|
||||
)
|
||||
task = _core.current_task()
|
||||
self._parked[task] = None
|
||||
task.custom_sleep_data = self
|
||||
|
||||
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
|
||||
del task.custom_sleep_data._parked[task]
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
await _core.wait_task_rescheduled(abort_fn)
|
||||
|
||||
def _pop_several(self, count: int | float) -> Iterator[Task]: # noqa: PYI041
|
||||
if isinstance(count, float):
|
||||
if math.isinf(count):
|
||||
count = len(self._parked)
|
||||
else:
|
||||
raise ValueError("Cannot pop a non-integer number of tasks.")
|
||||
else:
|
||||
count = min(count, len(self._parked))
|
||||
for _ in range(count):
|
||||
task, _ = self._parked.popitem(last=False)
|
||||
yield task
|
||||
|
||||
@_core.enable_ki_protection
|
||||
def unpark(self, *, count: int | float = 1) -> list[Task]: # noqa: PYI041
|
||||
"""Unpark one or more tasks.
|
||||
|
||||
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
|
||||
there are fewer than ``count`` tasks parked, then wakes as many tasks
|
||||
are available and then returns successfully.
|
||||
|
||||
Args:
|
||||
count (int | math.inf): the number of tasks to unpark.
|
||||
|
||||
"""
|
||||
tasks = list(self._pop_several(count))
|
||||
for task in tasks:
|
||||
_core.reschedule(task)
|
||||
return tasks
|
||||
|
||||
def unpark_all(self) -> list[Task]:
|
||||
"""Unpark all parked tasks."""
|
||||
return self.unpark(count=len(self))
|
||||
|
||||
@_core.enable_ki_protection
|
||||
def repark(
|
||||
self,
|
||||
new_lot: ParkingLot,
|
||||
*,
|
||||
count: int | float = 1, # noqa: PYI041
|
||||
) -> None:
|
||||
"""Move parked tasks from one :class:`ParkingLot` object to another.
|
||||
|
||||
This dequeues ``count`` tasks from one lot, and requeues them on
|
||||
another, preserving order. For example::
|
||||
|
||||
async def parker(lot):
|
||||
print("sleeping")
|
||||
await lot.park()
|
||||
print("woken")
|
||||
|
||||
async def main():
|
||||
lot1 = trio.lowlevel.ParkingLot()
|
||||
lot2 = trio.lowlevel.ParkingLot()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(parker, lot1)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
assert len(lot1) == 1
|
||||
assert len(lot2) == 0
|
||||
lot1.repark(lot2)
|
||||
assert len(lot1) == 0
|
||||
assert len(lot2) == 1
|
||||
# This wakes up the task that was originally parked in lot1
|
||||
lot2.unpark()
|
||||
|
||||
If there are fewer than ``count`` tasks parked, then reparks as many
|
||||
tasks as are available and then returns successfully.
|
||||
|
||||
Args:
|
||||
new_lot (ParkingLot): the parking lot to move tasks to.
|
||||
count (int|math.inf): the number of tasks to move.
|
||||
|
||||
"""
|
||||
if not isinstance(new_lot, ParkingLot):
|
||||
raise TypeError("new_lot must be a ParkingLot")
|
||||
for task in self._pop_several(count):
|
||||
new_lot._parked[task] = None
|
||||
task.custom_sleep_data = new_lot
|
||||
|
||||
def repark_all(self, new_lot: ParkingLot) -> None:
|
||||
"""Move all parked tasks from one :class:`ParkingLot` object to
|
||||
another.
|
||||
|
||||
See :meth:`repark` for details.
|
||||
|
||||
"""
|
||||
return self.repark(new_lot, count=len(self))
|
||||
|
||||
def break_lot(self, task: Task | None = None) -> None:
|
||||
"""Break this lot, with ``task`` noted as the task that broke it.
|
||||
|
||||
This causes all parked tasks to raise an error, and any
|
||||
future tasks attempting to park to error. Unpark & repark become no-ops as the
|
||||
parking lot is empty.
|
||||
|
||||
The error raised contains a reference to the task sent as a parameter. The task
|
||||
is also saved in the parking lot in the ``broken_by`` attribute.
|
||||
"""
|
||||
if task is None:
|
||||
task = _core.current_task()
|
||||
|
||||
# if lot is already broken, just mark this as another breaker and return
|
||||
if self.broken_by:
|
||||
self.broken_by.append(task)
|
||||
return
|
||||
|
||||
self.broken_by.append(task)
|
||||
|
||||
for parked_task in self._parked:
|
||||
_core.reschedule(
|
||||
parked_task,
|
||||
outcome.Error(
|
||||
_core.BrokenResourceError(f"Parking lot broken by {task}"),
|
||||
),
|
||||
)
|
||||
self._parked.clear()
|
||||
|
||||
def statistics(self) -> ParkingLotStatistics:
|
||||
"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting``: The number of tasks blocked on this lot's
|
||||
:meth:`park` method.
|
||||
|
||||
"""
|
||||
return ParkingLotStatistics(tasks_waiting=len(self._parked))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,333 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import weakref
|
||||
from math import inf
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from ... import _core
|
||||
from .tutil import gc_collect_harder, restore_unraisablehook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
def test_asyncgen_basics() -> None:
|
||||
collected = []
|
||||
|
||||
async def example(cause: str) -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
with contextlib.suppress(GeneratorExit):
|
||||
yield 42
|
||||
await _core.checkpoint()
|
||||
except _core.Cancelled:
|
||||
assert "exhausted" not in cause
|
||||
task_name = _core.current_task().name
|
||||
assert cause in task_name or task_name == "<init>"
|
||||
assert _core.current_effective_deadline() == -inf
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await _core.checkpoint()
|
||||
collected.append(cause)
|
||||
else:
|
||||
assert "async_main" in _core.current_task().name
|
||||
assert "exhausted" in cause
|
||||
assert _core.current_effective_deadline() == inf
|
||||
await _core.checkpoint()
|
||||
collected.append(cause)
|
||||
|
||||
saved = []
|
||||
|
||||
async def async_main() -> None:
|
||||
# GC'ed before exhausted
|
||||
with pytest.warns(
|
||||
ResourceWarning,
|
||||
match="Async generator.*collected before.*exhausted",
|
||||
):
|
||||
assert await example("abandoned").asend(None) == 42
|
||||
gc_collect_harder()
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert collected.pop() == "abandoned"
|
||||
|
||||
aiter_ = example("exhausted 1")
|
||||
try:
|
||||
assert await aiter_.asend(None) == 42
|
||||
finally:
|
||||
await aiter_.aclose()
|
||||
assert collected.pop() == "exhausted 1"
|
||||
|
||||
# Also fine if you exhaust it at point of use
|
||||
async for val in example("exhausted 2"):
|
||||
assert val == 42
|
||||
assert collected.pop() == "exhausted 2"
|
||||
|
||||
gc_collect_harder()
|
||||
|
||||
# No problems saving the geniter when using either of these patterns
|
||||
aiter_ = example("exhausted 3")
|
||||
try:
|
||||
saved.append(aiter_)
|
||||
assert await aiter_.asend(None) == 42
|
||||
finally:
|
||||
await aiter_.aclose()
|
||||
assert collected.pop() == "exhausted 3"
|
||||
|
||||
# Also fine if you exhaust it at point of use
|
||||
saved.append(example("exhausted 4"))
|
||||
async for val in saved[-1]:
|
||||
assert val == 42
|
||||
assert collected.pop() == "exhausted 4"
|
||||
|
||||
# Leave one referenced-but-unexhausted and make sure it gets cleaned up
|
||||
saved.append(example("outlived run"))
|
||||
assert await saved[-1].asend(None) == 42
|
||||
assert collected == []
|
||||
|
||||
_core.run(async_main)
|
||||
assert collected.pop() == "outlived run"
|
||||
for agen in saved:
|
||||
assert agen.ag_frame is None # all should now be exhausted
|
||||
|
||||
|
||||
async def test_asyncgen_throws_during_finalization(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
record = []
|
||||
|
||||
async def agen() -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
record.append("crashing")
|
||||
raise ValueError("oops")
|
||||
|
||||
with restore_unraisablehook():
|
||||
await agen().asend(None)
|
||||
gc_collect_harder()
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert record == ["crashing"]
|
||||
# Following type ignore is because typing for LogCaptureFixture is wrong
|
||||
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info # type: ignore[misc]
|
||||
assert exc_type is ValueError
|
||||
assert str(exc_value) == "oops"
|
||||
assert "during finalization of async generator" in caplog.records[0].message
|
||||
|
||||
|
||||
def test_firstiter_after_closing() -> None:
|
||||
saved = []
|
||||
record = []
|
||||
|
||||
async def funky_agen() -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield 1
|
||||
except GeneratorExit:
|
||||
record.append("cleanup 1")
|
||||
raise
|
||||
try:
|
||||
yield 2
|
||||
finally:
|
||||
record.append("cleanup 2")
|
||||
await funky_agen().asend(None)
|
||||
|
||||
async def async_main() -> None:
|
||||
aiter_ = funky_agen()
|
||||
saved.append(aiter_)
|
||||
assert await aiter_.asend(None) == 1
|
||||
assert await aiter_.asend(None) == 2
|
||||
|
||||
_core.run(async_main)
|
||||
assert record == ["cleanup 2", "cleanup 1"]
|
||||
|
||||
|
||||
def test_interdependent_asyncgen_cleanup_order() -> None:
|
||||
saved: list[AsyncGenerator[int, None]] = []
|
||||
record: list[int | str] = []
|
||||
|
||||
async def innermost() -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
record.append("innermost")
|
||||
|
||||
async def agen(
|
||||
label: int,
|
||||
inner: AsyncGenerator[int, None],
|
||||
) -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield await inner.asend(None)
|
||||
finally:
|
||||
# Either `inner` has already been cleaned up, or
|
||||
# we're about to exhaust it. Either way, we wind
|
||||
# up with `record` containing the labels in
|
||||
# innermost-to-outermost order.
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await inner.asend(None)
|
||||
record.append(label)
|
||||
|
||||
async def async_main() -> None:
|
||||
# This makes a chain of 101 interdependent asyncgens:
|
||||
# agen(99)'s cleanup will iterate agen(98)'s will iterate
|
||||
# ... agen(0)'s will iterate innermost()'s
|
||||
ag_chain = innermost()
|
||||
for idx in range(100):
|
||||
ag_chain = agen(idx, ag_chain)
|
||||
saved.append(ag_chain)
|
||||
assert await ag_chain.asend(None) == 1
|
||||
assert record == []
|
||||
|
||||
_core.run(async_main)
|
||||
assert record == ["innermost", *range(100)]
|
||||
|
||||
|
||||
@restore_unraisablehook()
|
||||
def test_last_minute_gc_edge_case() -> None:
|
||||
saved: list[AsyncGenerator[int, None]] = []
|
||||
record = []
|
||||
needs_retry = True
|
||||
|
||||
async def agen() -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
record.append("cleaned up")
|
||||
|
||||
def collect_at_opportune_moment(token: _core._entry_queue.TrioToken) -> None:
|
||||
runner = _core._run.GLOBAL_RUN_CONTEXT.runner
|
||||
assert runner.system_nursery is not None
|
||||
if runner.system_nursery._closed and isinstance(
|
||||
runner.asyncgens.alive,
|
||||
weakref.WeakSet,
|
||||
):
|
||||
saved.clear()
|
||||
record.append("final collection")
|
||||
gc_collect_harder()
|
||||
record.append("done")
|
||||
else:
|
||||
try:
|
||||
token.run_sync_soon(collect_at_opportune_moment, token)
|
||||
except _core.RunFinishedError: # pragma: no cover
|
||||
nonlocal needs_retry
|
||||
needs_retry = True
|
||||
|
||||
async def async_main() -> None:
|
||||
token = _core.current_trio_token()
|
||||
token.run_sync_soon(collect_at_opportune_moment, token)
|
||||
saved.append(agen())
|
||||
await saved[-1].asend(None)
|
||||
|
||||
# Actually running into the edge case requires that the run_sync_soon task
|
||||
# execute in between the system nursery's closure and the strong-ification
|
||||
# of runner.asyncgens. There's about a 25% chance that it doesn't
|
||||
# (if the run_sync_soon task runs before init on one tick and after init
|
||||
# on the next tick); if we try enough times, we can make the chance of
|
||||
# failure as small as we want.
|
||||
for _attempt in range(50):
|
||||
needs_retry = False
|
||||
del record[:]
|
||||
del saved[:]
|
||||
_core.run(async_main)
|
||||
if needs_retry: # pragma: no cover
|
||||
assert record == ["cleaned up"]
|
||||
else:
|
||||
assert record == ["final collection", "done", "cleaned up"]
|
||||
break
|
||||
else: # pragma: no cover
|
||||
pytest.fail(
|
||||
"Didn't manage to hit the trailing_finalizer_asyncgens case "
|
||||
f"despite trying {_attempt} times",
|
||||
)
|
||||
|
||||
|
||||
async def step_outside_async_context(aiter_: AsyncGenerator[int, None]) -> None:
|
||||
# abort_fns run outside of task context, at least if they're
|
||||
# triggered by a deadline expiry rather than a direct
|
||||
# cancellation. Thus, an asyncgen first iterated inside one
|
||||
# will appear non-Trio, and since no other hooks were installed,
|
||||
# will use the last-ditch fallback handling (that tries to mimic
|
||||
# CPython's behavior with no hooks).
|
||||
#
|
||||
# NB: the strangeness with aiter being an attribute of abort_fn is
|
||||
# to make it as easy as possible to ensure we don't hang onto a
|
||||
# reference to aiter inside the guts of the run loop.
|
||||
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
|
||||
with pytest.raises(StopIteration, match="42"):
|
||||
abort_fn.aiter.asend(None).send(None) # type: ignore[attr-defined] # Callables don't have attribute "aiter"
|
||||
del abort_fn.aiter # type: ignore[attr-defined]
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
abort_fn.aiter = aiter_ # type: ignore[attr-defined]
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(_core.wait_task_rescheduled, abort_fn)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.deadline = _core.current_time()
|
||||
|
||||
|
||||
async def test_fallback_when_no_hook_claims_it(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
async def well_behaved() -> AsyncGenerator[int, None]:
|
||||
yield 42
|
||||
|
||||
async def yields_after_yield() -> AsyncGenerator[int, None]:
|
||||
with pytest.raises(GeneratorExit):
|
||||
yield 42
|
||||
yield 100
|
||||
|
||||
async def awaits_after_yield() -> AsyncGenerator[int, None]:
|
||||
with pytest.raises(GeneratorExit):
|
||||
yield 42
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
with restore_unraisablehook():
|
||||
await step_outside_async_context(well_behaved())
|
||||
gc_collect_harder()
|
||||
assert capsys.readouterr().err == ""
|
||||
|
||||
await step_outside_async_context(yields_after_yield())
|
||||
gc_collect_harder()
|
||||
assert "ignored GeneratorExit" in capsys.readouterr().err
|
||||
|
||||
await step_outside_async_context(awaits_after_yield())
|
||||
gc_collect_harder()
|
||||
assert "awaited something during finalization" in capsys.readouterr().err
|
||||
|
||||
|
||||
def test_delegation_to_existing_hooks() -> None:
|
||||
record = []
|
||||
|
||||
def my_firstiter(agen: AsyncGenerator[object, NoReturn]) -> None:
|
||||
record.append("firstiter " + agen.ag_frame.f_locals["arg"])
|
||||
|
||||
def my_finalizer(agen: AsyncGenerator[object, NoReturn]) -> None:
|
||||
record.append("finalizer " + agen.ag_frame.f_locals["arg"])
|
||||
|
||||
async def example(arg: str) -> AsyncGenerator[int, None]:
|
||||
try:
|
||||
yield 42
|
||||
finally:
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await _core.checkpoint()
|
||||
record.append("trio collected " + arg)
|
||||
|
||||
async def async_main() -> None:
|
||||
await step_outside_async_context(example("theirs"))
|
||||
assert await example("ours").asend(None) == 42
|
||||
gc_collect_harder()
|
||||
assert record == ["firstiter theirs", "finalizer theirs"]
|
||||
record[:] = []
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert record == ["trio collected ours"]
|
||||
|
||||
with restore_unraisablehook():
|
||||
old_hooks = sys.get_asyncgen_hooks()
|
||||
sys.set_asyncgen_hooks(my_firstiter, my_finalizer)
|
||||
try:
|
||||
_core.run(async_main)
|
||||
finally:
|
||||
assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer)
|
||||
sys.set_asyncgen_hooks(*old_hooks)
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import sys
|
||||
from traceback import extract_tb
|
||||
from typing import TYPE_CHECKING, Callable, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from .._concat_tb import concat_tb
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
|
||||
def raiser1() -> NoReturn:
|
||||
raiser1_2()
|
||||
|
||||
|
||||
def raiser1_2() -> NoReturn:
|
||||
raiser1_3()
|
||||
|
||||
|
||||
def raiser1_3() -> NoReturn:
|
||||
raise ValueError("raiser1_string")
|
||||
|
||||
|
||||
def raiser2() -> NoReturn:
|
||||
raiser2_2()
|
||||
|
||||
|
||||
def raiser2_2() -> NoReturn:
|
||||
raise KeyError("raiser2_string")
|
||||
|
||||
|
||||
def get_exc(raiser: Callable[[], NoReturn]) -> Exception:
|
||||
try:
|
||||
raiser()
|
||||
except Exception as exc:
|
||||
return exc
|
||||
raise AssertionError("raiser should always raise") # pragma: no cover
|
||||
|
||||
|
||||
def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
|
||||
return get_exc(raiser).__traceback__
|
||||
|
||||
|
||||
def test_concat_tb() -> None:
|
||||
tb1 = get_tb(raiser1)
|
||||
tb2 = get_tb(raiser2)
|
||||
|
||||
# These return a list of (filename, lineno, fn name, text) tuples
|
||||
# https://docs.python.org/3/library/traceback.html#traceback.extract_tb
|
||||
entries1 = extract_tb(tb1)
|
||||
entries2 = extract_tb(tb2)
|
||||
|
||||
tb12 = concat_tb(tb1, tb2)
|
||||
assert extract_tb(tb12) == entries1 + entries2
|
||||
|
||||
tb21 = concat_tb(tb2, tb1)
|
||||
assert extract_tb(tb21) == entries2 + entries1
|
||||
|
||||
# Check degenerate cases
|
||||
assert extract_tb(concat_tb(None, tb1)) == entries1
|
||||
assert extract_tb(concat_tb(tb1, None)) == entries1
|
||||
assert concat_tb(None, None) is None
|
||||
|
||||
# Make sure the original tracebacks didn't get mutated by mistake
|
||||
assert extract_tb(get_tb(raiser1)) == entries1
|
||||
assert extract_tb(get_tb(raiser2)) == entries2
|
||||
|
||||
|
||||
# Unclear if this can still fail, removing the `del` from _concat_tb.copy_tb does not seem
|
||||
# to trigger it (on a platform where the `del` is executed)
|
||||
@pytest.mark.skipif(
|
||||
sys.implementation.name != "cpython",
|
||||
reason="Only makes sense with refcounting GC",
|
||||
)
|
||||
def test_ExceptionGroup_catch_doesnt_create_cyclic_garbage() -> None:
|
||||
# https://github.com/python-trio/trio/pull/2063
|
||||
gc.collect()
|
||||
old_flags = gc.get_debug()
|
||||
|
||||
def make_multi() -> NoReturn:
|
||||
raise ExceptionGroup("", [get_exc(raiser1), get_exc(raiser2)])
|
||||
|
||||
try:
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
with pytest.raises(ExceptionGroup) as excinfo:
|
||||
# covers ~~MultiErrorCatcher.__exit__ and~~ _concat_tb.copy_tb
|
||||
# TODO: is the above comment true anymore? as this no longer uses MultiError.catch
|
||||
raise make_multi()
|
||||
for exc in excinfo.value.exceptions:
|
||||
assert isinstance(exc, (ValueError, KeyError))
|
||||
gc.collect()
|
||||
assert not gc.garbage
|
||||
finally:
|
||||
gc.set_debug(old_flags)
|
||||
gc.garbage.clear()
|
||||
@@ -0,0 +1,666 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import contextvars
|
||||
import queue
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from functools import partial
|
||||
from math import inf
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
NoReturn,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from outcome import Outcome
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio.abc import Instrument
|
||||
|
||||
from ..._util import signal_raise
|
||||
from .tutil import gc_collect_harder, restore_unraisablehook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from trio._channel import MemorySendChannel
|
||||
|
||||
T = TypeVar("T")
|
||||
InHost: TypeAlias = Callable[[object], None]
|
||||
|
||||
|
||||
# The simplest possible "host" loop.
|
||||
# Nice features:
|
||||
# - we can run code "outside" of trio using the schedule function passed to
|
||||
# our main
|
||||
# - final result is returned
|
||||
# - any unhandled exceptions cause an immediate crash
|
||||
def trivial_guest_run(
|
||||
trio_fn: Callable[..., Awaitable[T]],
|
||||
*,
|
||||
in_host_after_start: Callable[[], None] | None = None,
|
||||
**start_guest_run_kwargs: Any,
|
||||
) -> T:
|
||||
todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue()
|
||||
|
||||
host_thread = threading.current_thread()
|
||||
|
||||
def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None:
|
||||
nonlocal todo
|
||||
if host_thread is threading.current_thread(): # pragma: no cover
|
||||
crash = partial(
|
||||
pytest.fail,
|
||||
"run_sync_soon_threadsafe called from host thread",
|
||||
)
|
||||
todo.put(("run", crash))
|
||||
todo.put(("run", fn))
|
||||
|
||||
def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None:
|
||||
nonlocal todo
|
||||
if host_thread is not threading.current_thread(): # pragma: no cover
|
||||
crash = partial(
|
||||
pytest.fail,
|
||||
"run_sync_soon_not_threadsafe called from worker thread",
|
||||
)
|
||||
todo.put(("run", crash))
|
||||
todo.put(("run", fn))
|
||||
|
||||
def done_callback(outcome: Outcome[T]) -> None:
|
||||
nonlocal todo
|
||||
todo.put(("unwrap", outcome))
|
||||
|
||||
trio.lowlevel.start_guest_run(
|
||||
trio_fn,
|
||||
run_sync_soon_not_threadsafe,
|
||||
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
|
||||
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
|
||||
done_callback=done_callback,
|
||||
**start_guest_run_kwargs,
|
||||
)
|
||||
if in_host_after_start is not None:
|
||||
in_host_after_start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
op, obj = todo.get()
|
||||
if op == "run":
|
||||
assert not isinstance(obj, Outcome)
|
||||
obj()
|
||||
elif op == "unwrap":
|
||||
assert isinstance(obj, Outcome)
|
||||
return obj.unwrap()
|
||||
else: # pragma: no cover
|
||||
raise NotImplementedError(f"{op!r} not handled")
|
||||
finally:
|
||||
# Make sure that exceptions raised here don't capture these, so that
|
||||
# if an exception does cause us to abandon a run then the Trio state
|
||||
# has a chance to be GC'ed and warn about it.
|
||||
del todo, run_sync_soon_threadsafe, done_callback
|
||||
|
||||
|
||||
def test_guest_trivial() -> None:
|
||||
async def trio_return(in_host: InHost) -> str:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return "ok"
|
||||
|
||||
assert trivial_guest_run(trio_return) == "ok"
|
||||
|
||||
async def trio_fail(in_host: InHost) -> NoReturn:
|
||||
raise KeyError("whoopsiedaisy")
|
||||
|
||||
with pytest.raises(KeyError, match="whoopsiedaisy"):
|
||||
trivial_guest_run(trio_fail)
|
||||
|
||||
|
||||
def test_guest_can_do_io() -> None:
|
||||
async def trio_main(in_host: InHost) -> None:
|
||||
record = []
|
||||
a, b = trio.socket.socketpair()
|
||||
with a, b:
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def do_receive() -> None:
|
||||
record.append(await a.recv(1))
|
||||
|
||||
nursery.start_soon(do_receive)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
await b.send(b"x")
|
||||
|
||||
assert record == [b"x"]
|
||||
|
||||
trivial_guest_run(trio_main)
|
||||
|
||||
|
||||
def test_guest_is_initialized_when_start_returns() -> None:
|
||||
trio_token = None
|
||||
record = []
|
||||
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
record.append("main task ran")
|
||||
await trio.lowlevel.checkpoint()
|
||||
assert trio.lowlevel.current_trio_token() is trio_token
|
||||
return "ok"
|
||||
|
||||
def after_start() -> None:
|
||||
# We should get control back before the main task executes any code
|
||||
assert record == []
|
||||
|
||||
nonlocal trio_token
|
||||
trio_token = trio.lowlevel.current_trio_token()
|
||||
trio_token.run_sync_soon(record.append, "run_sync_soon cb ran")
|
||||
|
||||
@trio.lowlevel.spawn_system_task
|
||||
async def early_task() -> None:
|
||||
record.append("system task ran")
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
res = trivial_guest_run(trio_main, in_host_after_start=after_start)
|
||||
assert res == "ok"
|
||||
assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"}
|
||||
|
||||
class BadClock:
|
||||
def start_clock(self) -> NoReturn:
|
||||
raise ValueError("whoops")
|
||||
|
||||
def after_start_never_runs() -> None: # pragma: no cover
|
||||
pytest.fail("shouldn't get here")
|
||||
|
||||
# Errors during initialization (which can only be TrioInternalErrors)
|
||||
# are raised out of start_guest_run, not out of the done_callback
|
||||
with pytest.raises(trio.TrioInternalError):
|
||||
trivial_guest_run(
|
||||
trio_main,
|
||||
clock=BadClock(),
|
||||
in_host_after_start=after_start_never_runs,
|
||||
)
|
||||
|
||||
|
||||
def test_host_can_directly_wake_trio_task() -> None:
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
ev = trio.Event()
|
||||
in_host(ev.set)
|
||||
await ev.wait()
|
||||
return "ok"
|
||||
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
|
||||
|
||||
def test_host_altering_deadlines_wakes_trio_up() -> None:
|
||||
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
|
||||
cscope.deadline = new_deadline
|
||||
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
with trio.CancelScope() as cscope:
|
||||
in_host(lambda: set_deadline(cscope, -inf))
|
||||
await trio.sleep_forever()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
with trio.CancelScope() as cscope:
|
||||
# also do a change that doesn't affect the next deadline, just to
|
||||
# exercise that path
|
||||
in_host(lambda: set_deadline(cscope, 1e6))
|
||||
in_host(lambda: set_deadline(cscope, -inf))
|
||||
await trio.sleep(999)
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
return "ok"
|
||||
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
|
||||
|
||||
def test_guest_mode_sniffio_integration() -> None:
|
||||
from sniffio import current_async_library, thread_local as sniffio_library
|
||||
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
async def synchronize() -> None:
|
||||
"""Wait for all in_host() calls issued so far to complete."""
|
||||
evt = trio.Event()
|
||||
in_host(evt.set)
|
||||
await evt.wait()
|
||||
|
||||
# Host and guest have separate sniffio_library contexts
|
||||
in_host(partial(setattr, sniffio_library, "name", "nullio"))
|
||||
await synchronize()
|
||||
assert current_async_library() == "trio"
|
||||
|
||||
record = []
|
||||
in_host(lambda: record.append(current_async_library()))
|
||||
await synchronize()
|
||||
assert record == ["nullio"]
|
||||
assert current_async_library() == "trio"
|
||||
|
||||
return "ok"
|
||||
|
||||
try:
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
finally:
|
||||
sniffio_library.name = None
|
||||
|
||||
|
||||
def test_warn_set_wakeup_fd_overwrite() -> None:
|
||||
assert signal.set_wakeup_fd(-1) == -1
|
||||
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
return "ok"
|
||||
|
||||
a, b = socket.socketpair()
|
||||
with a, b:
|
||||
a.setblocking(False)
|
||||
|
||||
# Warn if there's already a wakeup fd
|
||||
signal.set_wakeup_fd(a.fileno())
|
||||
try:
|
||||
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
finally:
|
||||
assert signal.set_wakeup_fd(-1) == a.fileno()
|
||||
|
||||
signal.set_wakeup_fd(a.fileno())
|
||||
try:
|
||||
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
|
||||
assert (
|
||||
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False)
|
||||
== "ok"
|
||||
)
|
||||
finally:
|
||||
assert signal.set_wakeup_fd(-1) == a.fileno()
|
||||
|
||||
# Don't warn if there isn't already a wakeup fd
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
assert (
|
||||
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True)
|
||||
== "ok"
|
||||
)
|
||||
|
||||
# If there's already a wakeup fd, but we've been told to trust it,
|
||||
# then it's left alone and there's no warning
|
||||
signal.set_wakeup_fd(a.fileno())
|
||||
try:
|
||||
|
||||
async def trio_check_wakeup_fd_unaltered(in_host: InHost) -> str:
|
||||
fd = signal.set_wakeup_fd(-1)
|
||||
assert fd == a.fileno()
|
||||
signal.set_wakeup_fd(fd)
|
||||
return "ok"
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
assert (
|
||||
trivial_guest_run(
|
||||
trio_check_wakeup_fd_unaltered,
|
||||
host_uses_signal_set_wakeup_fd=True,
|
||||
)
|
||||
== "ok"
|
||||
)
|
||||
finally:
|
||||
assert signal.set_wakeup_fd(-1) == a.fileno()
|
||||
|
||||
|
||||
def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None:
|
||||
# This is designed to hit the branch in unrolled_run where:
|
||||
# idle_primed=True
|
||||
# runner.runq is empty
|
||||
# events is Truth-y
|
||||
# ...and confirm that in this case, wait_all_tasks_blocked does not get
|
||||
# triggered.
|
||||
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
|
||||
print(f"setting deadline {new_deadline}")
|
||||
cscope.deadline = new_deadline
|
||||
|
||||
async def trio_main(in_host: InHost) -> str:
|
||||
async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None:
|
||||
with watb_cscope:
|
||||
# Overall point of this test is that this
|
||||
# wait_all_tasks_blocked should *not* return normally, but
|
||||
# only by cancellation.
|
||||
await trio.testing.wait_all_tasks_blocked(cushion=9999)
|
||||
raise AssertionError( # pragma: no cover
|
||||
"wait_all_tasks_blocked should *not* return normally, "
|
||||
"only by cancellation.",
|
||||
)
|
||||
assert watb_cscope.cancelled_caught
|
||||
|
||||
async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None:
|
||||
with trio.CancelScope() as cscope:
|
||||
print("scheduling stuff to happen")
|
||||
|
||||
# Altering the deadline from the host, to something in the
|
||||
# future, will cause the run loop to wake up, but then
|
||||
# discover that there is nothing to do and go back to sleep.
|
||||
# This should *not* trigger wait_all_tasks_blocked.
|
||||
#
|
||||
# So the 'before_io_wait' here will wait until we're blocking
|
||||
# with the wait_all_tasks_blocked primed, and then schedule a
|
||||
# deadline change. The critical test is that this should *not*
|
||||
# wake up 'sit_in_wait_all_tasks_blocked'.
|
||||
#
|
||||
# The after we've had a chance to wake up
|
||||
# 'sit_in_wait_all_tasks_blocked', we want the test to
|
||||
# actually end. So in after_io_wait we schedule a second host
|
||||
# call to tear things down.
|
||||
class InstrumentHelper(Instrument):
|
||||
def __init__(self) -> None:
|
||||
self.primed = False
|
||||
|
||||
def before_io_wait(self, timeout: float) -> None:
|
||||
print(f"before_io_wait({timeout})")
|
||||
if timeout == 9999: # pragma: no branch
|
||||
assert not self.primed
|
||||
in_host(lambda: set_deadline(cscope, 1e9))
|
||||
self.primed = True
|
||||
|
||||
def after_io_wait(self, timeout: float) -> None:
|
||||
if self.primed: # pragma: no branch
|
||||
print("instrument triggered")
|
||||
in_host(lambda: cscope.cancel())
|
||||
trio.lowlevel.remove_instrument(self)
|
||||
|
||||
trio.lowlevel.add_instrument(InstrumentHelper())
|
||||
await trio.sleep_forever()
|
||||
assert cscope.cancelled_caught
|
||||
watb_cscope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
watb_cscope = trio.CancelScope()
|
||||
nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
nursery.start_soon(get_woken_by_host_deadline, watb_cscope)
|
||||
|
||||
return "ok"
|
||||
|
||||
assert trivial_guest_run(trio_main) == "ok"
|
||||
|
||||
|
||||
@restore_unraisablehook()
|
||||
def test_guest_warns_if_abandoned() -> None:
|
||||
# This warning is emitted from the garbage collector. So we have to make
|
||||
# sure that our abandoned run is garbage. The easiest way to do this is to
|
||||
# put it into a function, so that we're sure all the local state,
|
||||
# traceback frames, etc. are garbage once it returns.
|
||||
def do_abandoned_guest_run() -> None:
|
||||
async def abandoned_main(in_host: InHost) -> None:
|
||||
in_host(lambda: 1 / 0)
|
||||
while True:
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
trivial_guest_run(abandoned_main)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"):
|
||||
do_abandoned_guest_run()
|
||||
gc_collect_harder()
|
||||
|
||||
# If you have problems some day figuring out what's holding onto a
|
||||
# reference to the unrolled_run generator and making this test fail,
|
||||
# then this might be useful to help track it down. (It assumes you
|
||||
# also hack start_guest_run so that it does 'global W; W =
|
||||
# weakref(unrolled_run_gen)'.)
|
||||
#
|
||||
# import gc
|
||||
# print(trio._core._run.W)
|
||||
# targets = [trio._core._run.W()]
|
||||
# for i in range(15):
|
||||
# new_targets = []
|
||||
# for target in targets:
|
||||
# new_targets += gc.get_referrers(target)
|
||||
# new_targets.remove(targets)
|
||||
# print("#####################")
|
||||
# print(f"depth {i}: {len(new_targets)}")
|
||||
# print(new_targets)
|
||||
# targets = new_targets
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
trio.current_time()
|
||||
|
||||
|
||||
def aiotrio_run(
|
||||
trio_fn: Callable[..., Awaitable[T]],
|
||||
*,
|
||||
pass_not_threadsafe: bool = True,
|
||||
**start_guest_run_kwargs: Any,
|
||||
) -> T:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
async def aio_main() -> T:
|
||||
trio_done_fut = loop.create_future()
|
||||
|
||||
def trio_done_callback(main_outcome: Outcome[object]) -> None:
|
||||
print(f"trio_fn finished: {main_outcome!r}")
|
||||
trio_done_fut.set_result(main_outcome)
|
||||
|
||||
if pass_not_threadsafe:
|
||||
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
|
||||
|
||||
trio.lowlevel.start_guest_run(
|
||||
trio_fn,
|
||||
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
|
||||
done_callback=trio_done_callback,
|
||||
**start_guest_run_kwargs,
|
||||
)
|
||||
|
||||
return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
return loop.run_until_complete(aio_main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_guest_mode_on_asyncio() -> None:
|
||||
async def trio_main() -> str:
|
||||
print("trio_main!")
|
||||
|
||||
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
|
||||
from_trio: asyncio.Queue[int] = asyncio.Queue()
|
||||
|
||||
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
|
||||
|
||||
# Make sure we have at least one tick where we don't need to go into
|
||||
# the thread
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
from_trio.put_nowait(0)
|
||||
|
||||
async for n in from_aio:
|
||||
print(f"trio got: {n}")
|
||||
from_trio.put_nowait(n + 1)
|
||||
if n >= 10:
|
||||
aio_task.cancel()
|
||||
return "trio-main-done"
|
||||
|
||||
raise AssertionError("should never be reached") # pragma: no cover
|
||||
|
||||
async def aio_pingpong(
|
||||
from_trio: asyncio.Queue[int],
|
||||
to_trio: MemorySendChannel[int],
|
||||
) -> None:
|
||||
print("aio_pingpong!")
|
||||
|
||||
try:
|
||||
while True:
|
||||
n = await from_trio.get()
|
||||
print(f"aio got: {n}")
|
||||
to_trio.send_nowait(n + 1)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except: # pragma: no cover
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
assert (
|
||||
aiotrio_run(
|
||||
trio_main,
|
||||
# Not all versions of asyncio we test on can actually be trusted,
|
||||
# but this test doesn't care about signal handling, and it's
|
||||
# easier to just avoid the warnings.
|
||||
host_uses_signal_set_wakeup_fd=True,
|
||||
)
|
||||
== "trio-main-done"
|
||||
)
|
||||
|
||||
assert (
|
||||
aiotrio_run(
|
||||
trio_main,
|
||||
# Also check that passing only call_soon_threadsafe works, via the
|
||||
# fallback path where we use it for everything.
|
||||
pass_not_threadsafe=False,
|
||||
host_uses_signal_set_wakeup_fd=True,
|
||||
)
|
||||
== "trio-main-done"
|
||||
)
|
||||
|
||||
|
||||
def test_guest_mode_internal_errors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
recwarn: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
async def crash_in_run_loop(in_host: InHost) -> None:
|
||||
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
|
||||
await trio.sleep(1)
|
||||
|
||||
with pytest.raises(trio.TrioInternalError):
|
||||
trivial_guest_run(crash_in_run_loop)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
async def crash_in_io(in_host: InHost) -> None:
|
||||
m.setattr("trio._core._run.TheIOManager.get_events", None)
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
with pytest.raises(trio.TrioInternalError):
|
||||
trivial_guest_run(crash_in_io)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
async def crash_in_worker_thread_io(in_host: InHost) -> None:
|
||||
t = threading.current_thread()
|
||||
old_get_events = trio._core._run.TheIOManager.get_events
|
||||
|
||||
def bad_get_events(*args: Any) -> object:
|
||||
if threading.current_thread() is not t:
|
||||
raise ValueError("oh no!")
|
||||
else:
|
||||
return old_get_events(*args)
|
||||
|
||||
m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)
|
||||
|
||||
await trio.sleep(1)
|
||||
|
||||
with pytest.raises(trio.TrioInternalError):
|
||||
trivial_guest_run(crash_in_worker_thread_io)
|
||||
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
def test_guest_mode_ki() -> None:
|
||||
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
|
||||
|
||||
# Check SIGINT in Trio func and in host func
|
||||
async def trio_main(in_host: InHost) -> None:
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
signal_raise(signal.SIGINT)
|
||||
|
||||
# Host SIGINT should get injected into Trio
|
||||
in_host(partial(signal_raise, signal.SIGINT))
|
||||
await trio.sleep(10)
|
||||
|
||||
with pytest.raises(KeyboardInterrupt) as excinfo:
|
||||
trivial_guest_run(trio_main)
|
||||
assert excinfo.value.__context__ is None
|
||||
# Signal handler should be restored properly on exit
|
||||
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
|
||||
|
||||
# Also check chaining in the case where KI is injected after main exits
|
||||
final_exc = KeyError("whoa")
|
||||
|
||||
async def trio_main_raising(in_host: InHost) -> NoReturn:
|
||||
in_host(partial(signal_raise, signal.SIGINT))
|
||||
raise final_exc
|
||||
|
||||
with pytest.raises(KeyboardInterrupt) as excinfo:
|
||||
trivial_guest_run(trio_main_raising)
|
||||
assert excinfo.value.__context__ is final_exc
|
||||
|
||||
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
|
||||
|
||||
|
||||
def test_guest_mode_autojump_clock_threshold_changing() -> None:
|
||||
# This is super obscure and probably no-one will ever notice, but
|
||||
# technically mutating the MockClock.autojump_threshold from the host
|
||||
# should wake up the guest, so let's test it.
|
||||
|
||||
clock = trio.testing.MockClock()
|
||||
|
||||
DURATION = 120
|
||||
|
||||
async def trio_main(in_host: InHost) -> None:
|
||||
assert trio.current_time() == 0
|
||||
in_host(lambda: setattr(clock, "autojump_threshold", 0))
|
||||
await trio.sleep(DURATION)
|
||||
assert trio.current_time() == DURATION
|
||||
|
||||
start = time.monotonic()
|
||||
trivial_guest_run(trio_main, clock=clock)
|
||||
end = time.monotonic()
|
||||
# Should be basically instantaneous, but we'll leave a generous buffer to
|
||||
# account for any CI weirdness
|
||||
assert end - start < DURATION / 2
|
||||
|
||||
|
||||
@restore_unraisablehook()
|
||||
def test_guest_mode_asyncgens() -> None:
|
||||
import sniffio
|
||||
|
||||
record = set()
|
||||
|
||||
async def agen(label: str) -> AsyncGenerator[int, None]:
|
||||
assert sniffio.current_async_library() == label
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
library = sniffio.current_async_library()
|
||||
with contextlib.suppress(trio.Cancelled):
|
||||
await sys.modules[library].sleep(0)
|
||||
record.add((label, library))
|
||||
|
||||
async def iterate_in_aio() -> None:
|
||||
await agen("asyncio").asend(None)
|
||||
|
||||
async def trio_main() -> None:
|
||||
task = asyncio.ensure_future(iterate_in_aio())
|
||||
done_evt = trio.Event()
|
||||
task.add_done_callback(lambda _: done_evt.set())
|
||||
with trio.fail_after(1):
|
||||
await done_evt.wait()
|
||||
|
||||
await agen("trio").asend(None)
|
||||
|
||||
gc_collect_harder()
|
||||
|
||||
# Ensure we don't pollute the thread-level context if run under
|
||||
# an asyncio without contextvars support (3.6)
|
||||
context = contextvars.copy_context()
|
||||
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
|
||||
|
||||
assert record == {("asyncio", "asyncio"), ("trio", "trio")}
|
||||
@@ -0,0 +1,266 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Container, Iterable, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from ... import _abc, _core
|
||||
from .tutil import check_sequence_matches
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...lowlevel import Task
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class TaskRecorder(_abc.Instrument):
|
||||
record: list[tuple[str, Task | None]] = attrs.Factory(list)
|
||||
|
||||
def before_run(self) -> None:
|
||||
self.record.append(("before_run", None))
|
||||
|
||||
def task_scheduled(self, task: Task) -> None:
|
||||
self.record.append(("schedule", task))
|
||||
|
||||
def before_task_step(self, task: Task) -> None:
|
||||
assert task is _core.current_task()
|
||||
self.record.append(("before", task))
|
||||
|
||||
def after_task_step(self, task: Task) -> None:
|
||||
assert task is _core.current_task()
|
||||
self.record.append(("after", task))
|
||||
|
||||
def after_run(self) -> None:
|
||||
self.record.append(("after_run", None))
|
||||
|
||||
def filter_tasks(self, tasks: Container[Task]) -> Iterable[tuple[str, Task | None]]:
|
||||
for item in self.record:
|
||||
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
|
||||
yield item
|
||||
if item[0] in ("before_run", "after_run"):
|
||||
yield item
|
||||
|
||||
|
||||
def test_instruments(recwarn: object) -> None:
|
||||
r1 = TaskRecorder()
|
||||
r2 = TaskRecorder()
|
||||
r3 = TaskRecorder()
|
||||
|
||||
task = None
|
||||
|
||||
# We use a child task for this, because the main task does some extra
|
||||
# bookkeeping stuff that can leak into the instrument results, and we
|
||||
# don't want to deal with it.
|
||||
async def task_fn() -> None:
|
||||
nonlocal task
|
||||
task = _core.current_task()
|
||||
|
||||
for _ in range(4):
|
||||
await _core.checkpoint()
|
||||
# replace r2 with r3, to test that we can manipulate them as we go
|
||||
_core.remove_instrument(r2)
|
||||
with pytest.raises(KeyError):
|
||||
_core.remove_instrument(r2)
|
||||
# add is idempotent
|
||||
_core.add_instrument(r3)
|
||||
_core.add_instrument(r3)
|
||||
for _ in range(1):
|
||||
await _core.checkpoint()
|
||||
|
||||
async def main() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(task_fn)
|
||||
|
||||
_core.run(main, instruments=[r1, r2])
|
||||
|
||||
# It sleeps 5 times, so it runs 6 times. Note that checkpoint()
|
||||
# reschedules the task immediately upon yielding, before the
|
||||
# after_task_step event fires.
|
||||
expected = (
|
||||
[("before_run", None), ("schedule", task)]
|
||||
+ [("before", task), ("schedule", task), ("after", task)] * 5
|
||||
+ [("before", task), ("after", task), ("after_run", None)]
|
||||
)
|
||||
assert r1.record == r2.record + r3.record
|
||||
assert task is not None
|
||||
assert list(r1.filter_tasks([task])) == expected
|
||||
|
||||
|
||||
def test_instruments_interleave() -> None:
|
||||
tasks = {}
|
||||
|
||||
async def two_step1() -> None:
|
||||
tasks["t1"] = _core.current_task()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def two_step2() -> None:
|
||||
tasks["t2"] = _core.current_task()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def main() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(two_step1)
|
||||
nursery.start_soon(two_step2)
|
||||
|
||||
r = TaskRecorder()
|
||||
_core.run(main, instruments=[r])
|
||||
|
||||
expected = [
|
||||
("before_run", None),
|
||||
("schedule", tasks["t1"]),
|
||||
("schedule", tasks["t2"]),
|
||||
{
|
||||
("before", tasks["t1"]),
|
||||
("schedule", tasks["t1"]),
|
||||
("after", tasks["t1"]),
|
||||
("before", tasks["t2"]),
|
||||
("schedule", tasks["t2"]),
|
||||
("after", tasks["t2"]),
|
||||
},
|
||||
{
|
||||
("before", tasks["t1"]),
|
||||
("after", tasks["t1"]),
|
||||
("before", tasks["t2"]),
|
||||
("after", tasks["t2"]),
|
||||
},
|
||||
("after_run", None),
|
||||
]
|
||||
print(list(r.filter_tasks(tasks.values())))
|
||||
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
|
||||
|
||||
|
||||
def test_null_instrument() -> None:
|
||||
# undefined instrument methods are skipped
|
||||
class NullInstrument(_abc.Instrument):
|
||||
def something_unrelated(self) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
async def main() -> None:
|
||||
await _core.checkpoint()
|
||||
|
||||
_core.run(main, instruments=[NullInstrument()])
|
||||
|
||||
|
||||
def test_instrument_before_after_run() -> None:
|
||||
record = []
|
||||
|
||||
class BeforeAfterRun(_abc.Instrument):
|
||||
def before_run(self) -> None:
|
||||
record.append("before_run")
|
||||
|
||||
def after_run(self) -> None:
|
||||
record.append("after_run")
|
||||
|
||||
async def main() -> None:
|
||||
pass
|
||||
|
||||
_core.run(main, instruments=[BeforeAfterRun()])
|
||||
assert record == ["before_run", "after_run"]
|
||||
|
||||
|
||||
def test_instrument_task_spawn_exit() -> None:
|
||||
record = []
|
||||
|
||||
class SpawnExitRecorder(_abc.Instrument):
|
||||
def task_spawned(self, task: Task) -> None:
|
||||
record.append(("spawned", task))
|
||||
|
||||
def task_exited(self, task: Task) -> None:
|
||||
record.append(("exited", task))
|
||||
|
||||
async def main() -> Task:
|
||||
return _core.current_task()
|
||||
|
||||
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
|
||||
assert ("spawned", main_task) in record
|
||||
assert ("exited", main_task) in record
|
||||
|
||||
|
||||
# This test also tests having a crash before the initial task is even spawned,
|
||||
# which is very difficult to handle.
|
||||
def test_instruments_crash(caplog: pytest.LogCaptureFixture) -> None:
|
||||
record = []
|
||||
|
||||
class BrokenInstrument(_abc.Instrument):
|
||||
def task_scheduled(self, task: Task) -> NoReturn:
|
||||
record.append("scheduled")
|
||||
raise ValueError("oops")
|
||||
|
||||
def close(self) -> None:
|
||||
# Shouldn't be called -- tests that the instrument disabling logic
|
||||
# works right.
|
||||
record.append("closed") # pragma: no cover
|
||||
|
||||
async def main() -> Task:
|
||||
record.append("main ran")
|
||||
return _core.current_task()
|
||||
|
||||
r = TaskRecorder()
|
||||
main_task = _core.run(main, instruments=[r, BrokenInstrument()])
|
||||
assert record == ["scheduled", "main ran"]
|
||||
# the TaskRecorder kept going throughout, even though the BrokenInstrument
|
||||
# was disabled
|
||||
assert ("after", main_task) in r.record
|
||||
assert ("after_run", None) in r.record
|
||||
# And we got a log message
|
||||
assert caplog.records[0].exc_info is not None
|
||||
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
|
||||
assert exc_type is ValueError
|
||||
assert str(exc_value) == "oops"
|
||||
assert "Instrument has been disabled" in caplog.records[0].message
|
||||
|
||||
|
||||
def test_instruments_monkeypatch() -> None:
|
||||
class NullInstrument(_abc.Instrument):
|
||||
pass
|
||||
|
||||
instrument = NullInstrument()
|
||||
|
||||
async def main() -> None:
|
||||
record: list[Task] = []
|
||||
|
||||
# Changing the set of hooks implemented by an instrument after
|
||||
# it's installed doesn't make them start being called right away
|
||||
instrument.before_task_step = ( # type: ignore[method-assign]
|
||||
record.append # type: ignore[assignment] # append is pos-only
|
||||
)
|
||||
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert len(record) == 0
|
||||
|
||||
# But if we remove and re-add the instrument, the new hooks are
|
||||
# picked up
|
||||
_core.remove_instrument(instrument)
|
||||
_core.add_instrument(instrument)
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert record.count(_core.current_task()) == 2
|
||||
|
||||
_core.remove_instrument(instrument)
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert record.count(_core.current_task()) == 2
|
||||
|
||||
_core.run(main, instruments=[instrument])
|
||||
|
||||
|
||||
def test_instrument_that_raises_on_getattr() -> None:
|
||||
class EvilInstrument(_abc.Instrument):
|
||||
def task_exited(self, task: Task) -> NoReturn:
|
||||
raise AssertionError("this should never happen") # pragma: no cover
|
||||
|
||||
@property
|
||||
def after_run(self) -> NoReturn:
|
||||
raise ValueError("oops")
|
||||
|
||||
async def main() -> None:
|
||||
with pytest.raises(ValueError, match="^oops$"):
|
||||
_core.add_instrument(EvilInstrument())
|
||||
|
||||
# Make sure the instrument is fully removed from the per-method lists
|
||||
runner = _core.current_task()._runner
|
||||
assert "after_run" not in runner.instruments
|
||||
assert "task_exited" not in runner.instruments
|
||||
|
||||
_core.run(main)
|
||||
@@ -0,0 +1,480 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import socket as stdlib_socket
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from ... import _core
|
||||
from ...testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
# Cross-platform tests for IO handling
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
ArgsT = ParamSpec("ArgsT")
|
||||
|
||||
|
||||
def fill_socket(sock: stdlib_socket.socket) -> None:
|
||||
try:
|
||||
while True:
|
||||
sock.send(b"x" * 65536)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
|
||||
|
||||
def drain_socket(sock: stdlib_socket.socket) -> None:
|
||||
try:
|
||||
while True:
|
||||
sock.recv(65536)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
|
||||
|
||||
WaitSocket = Callable[[stdlib_socket.socket], Awaitable[object]]
|
||||
SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket]
|
||||
RetT = TypeVar("RetT")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socketpair() -> Generator[SocketPair, None, None]:
|
||||
pair = stdlib_socket.socketpair()
|
||||
for sock in pair:
|
||||
sock.setblocking(False)
|
||||
yield pair
|
||||
for sock in pair:
|
||||
sock.close()
|
||||
|
||||
|
||||
def also_using_fileno(
|
||||
fn: Callable[[stdlib_socket.socket | int], RetT],
|
||||
) -> list[Callable[[stdlib_socket.socket], RetT]]:
|
||||
def fileno_wrapper(fileobj: stdlib_socket.socket) -> RetT:
|
||||
return fn(fileobj.fileno())
|
||||
|
||||
name = f"<{fn.__name__} on fileno>"
|
||||
fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name
|
||||
return [fn, fileno_wrapper]
|
||||
|
||||
|
||||
# Decorators that feed in different settings for wait_readable / wait_writable
|
||||
# / notify_closing.
|
||||
# Note that if you use all three decorators on the same test, it will run all
|
||||
# N**3 *combinations*
|
||||
read_socket_test = pytest.mark.parametrize(
|
||||
"wait_readable",
|
||||
also_using_fileno(trio.lowlevel.wait_readable),
|
||||
ids=lambda fn: fn.__name__,
|
||||
)
|
||||
write_socket_test = pytest.mark.parametrize(
|
||||
"wait_writable",
|
||||
also_using_fileno(trio.lowlevel.wait_writable),
|
||||
ids=lambda fn: fn.__name__,
|
||||
)
|
||||
notify_closing_test = pytest.mark.parametrize(
|
||||
"notify_closing",
|
||||
also_using_fileno(trio.lowlevel.notify_closing),
|
||||
ids=lambda fn: fn.__name__,
|
||||
)
|
||||
|
||||
|
||||
# XX These tests are all a bit dicey because they can't distinguish between
|
||||
# wait_on_{read,writ}able blocking the way it should, versus blocking
|
||||
# momentarily and then immediately resuming.
|
||||
@read_socket_test
|
||||
@write_socket_test
|
||||
async def test_wait_basic(
|
||||
socketpair: SocketPair,
|
||||
wait_readable: WaitSocket,
|
||||
wait_writable: WaitSocket,
|
||||
) -> None:
|
||||
a, b = socketpair
|
||||
|
||||
# They start out writable()
|
||||
with assert_checkpoints():
|
||||
await wait_writable(a)
|
||||
|
||||
# But readable() blocks until data arrives
|
||||
record = []
|
||||
|
||||
async def block_on_read() -> None:
|
||||
try:
|
||||
with assert_checkpoints():
|
||||
await wait_readable(a)
|
||||
except _core.Cancelled:
|
||||
record.append("cancelled")
|
||||
else:
|
||||
record.append("readable")
|
||||
assert a.recv(10) == b"x"
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(block_on_read)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == []
|
||||
b.send(b"x")
|
||||
|
||||
fill_socket(a)
|
||||
|
||||
# Now writable will block, but readable won't
|
||||
with assert_checkpoints():
|
||||
await wait_readable(b)
|
||||
record = []
|
||||
|
||||
async def block_on_write() -> None:
|
||||
try:
|
||||
with assert_checkpoints():
|
||||
await wait_writable(a)
|
||||
except _core.Cancelled:
|
||||
record.append("cancelled")
|
||||
else:
|
||||
record.append("writable")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(block_on_write)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == []
|
||||
drain_socket(b)
|
||||
|
||||
# check cancellation
|
||||
record = []
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(block_on_read)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["cancelled"]
|
||||
|
||||
fill_socket(a)
|
||||
record = []
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(block_on_write)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["cancelled"]
|
||||
|
||||
|
||||
@read_socket_test
|
||||
async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> None:
|
||||
a, b = socketpair
|
||||
|
||||
# You can't have two tasks trying to read from a socket at the same time
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_readable, a)
|
||||
await wait_all_tasks_blocked()
|
||||
with pytest.raises(_core.BusyResourceError):
|
||||
await wait_readable(a)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@write_socket_test
|
||||
async def test_double_write(socketpair: SocketPair, wait_writable: WaitSocket) -> None:
|
||||
a, b = socketpair
|
||||
|
||||
# You can't have two tasks trying to write to a socket at the same time
|
||||
fill_socket(a)
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_writable, a)
|
||||
await wait_all_tasks_blocked()
|
||||
with pytest.raises(_core.BusyResourceError):
|
||||
await wait_writable(a)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@read_socket_test
|
||||
@write_socket_test
|
||||
@notify_closing_test
|
||||
async def test_interrupted_by_close(
|
||||
socketpair: SocketPair,
|
||||
wait_readable: WaitSocket,
|
||||
wait_writable: WaitSocket,
|
||||
notify_closing: Callable[[stdlib_socket.socket], object],
|
||||
) -> None:
|
||||
a, b = socketpair
|
||||
|
||||
async def reader() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await wait_readable(a)
|
||||
|
||||
async def writer() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await wait_writable(a)
|
||||
|
||||
fill_socket(a)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(reader)
|
||||
nursery.start_soon(writer)
|
||||
await wait_all_tasks_blocked()
|
||||
notify_closing(a)
|
||||
|
||||
|
||||
@read_socket_test
|
||||
@write_socket_test
|
||||
async def test_socket_simultaneous_read_write(
|
||||
socketpair: SocketPair,
|
||||
wait_readable: WaitSocket,
|
||||
wait_writable: WaitSocket,
|
||||
) -> None:
|
||||
record: list[str] = []
|
||||
|
||||
async def r_task(sock: stdlib_socket.socket) -> None:
|
||||
await wait_readable(sock)
|
||||
record.append("r_task")
|
||||
|
||||
async def w_task(sock: stdlib_socket.socket) -> None:
|
||||
await wait_writable(sock)
|
||||
record.append("w_task")
|
||||
|
||||
a, b = socketpair
|
||||
fill_socket(a)
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(r_task, a)
|
||||
nursery.start_soon(w_task, a)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == []
|
||||
b.send(b"x")
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["r_task"]
|
||||
drain_socket(b)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["r_task", "w_task"]
|
||||
|
||||
|
||||
@read_socket_test
|
||||
@write_socket_test
|
||||
async def test_socket_actual_streaming(
|
||||
socketpair: SocketPair,
|
||||
wait_readable: WaitSocket,
|
||||
wait_writable: WaitSocket,
|
||||
) -> None:
|
||||
a, b = socketpair
|
||||
|
||||
# Use a small send buffer on one of the sockets to increase the chance of
|
||||
# getting partial writes
|
||||
a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)
|
||||
|
||||
N = 1000000 # 1 megabyte
|
||||
MAX_CHUNK = 65536
|
||||
|
||||
results: dict[str, int] = {}
|
||||
|
||||
async def sender(sock: stdlib_socket.socket, seed: int, key: str) -> None:
|
||||
r = random.Random(seed)
|
||||
sent = 0
|
||||
while sent < N:
|
||||
print("sent", sent)
|
||||
chunk = bytearray(r.randrange(MAX_CHUNK))
|
||||
while chunk:
|
||||
with assert_checkpoints():
|
||||
await wait_writable(sock)
|
||||
this_chunk_size = sock.send(chunk)
|
||||
sent += this_chunk_size
|
||||
del chunk[:this_chunk_size]
|
||||
sock.shutdown(stdlib_socket.SHUT_WR)
|
||||
results[key] = sent
|
||||
|
||||
async def receiver(sock: stdlib_socket.socket, key: str) -> None:
|
||||
received = 0
|
||||
while True:
|
||||
print("received", received)
|
||||
with assert_checkpoints():
|
||||
await wait_readable(sock)
|
||||
this_chunk_size = len(sock.recv(MAX_CHUNK))
|
||||
if not this_chunk_size:
|
||||
break
|
||||
received += this_chunk_size
|
||||
results[key] = received
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender, a, 0, "send_a")
|
||||
nursery.start_soon(sender, b, 1, "send_b")
|
||||
nursery.start_soon(receiver, a, "recv_a")
|
||||
nursery.start_soon(receiver, b, "recv_b")
|
||||
|
||||
assert results["send_a"] == results["recv_b"]
|
||||
assert results["send_b"] == results["recv_a"]
|
||||
|
||||
|
||||
async def test_notify_closing_on_invalid_object() -> None:
|
||||
# It should either be a no-op (generally on Unix, where we don't know
|
||||
# which fds are valid), or an OSError (on Windows, where we currently only
|
||||
# support sockets, so we have to do some validation to figure out whether
|
||||
# it's a socket or a regular handle).
|
||||
got_oserror = False
|
||||
got_no_error = False
|
||||
try:
|
||||
trio.lowlevel.notify_closing(-1)
|
||||
except OSError:
|
||||
got_oserror = True
|
||||
else:
|
||||
got_no_error = True
|
||||
assert got_oserror or got_no_error
|
||||
|
||||
|
||||
async def test_wait_on_invalid_object() -> None:
|
||||
# We definitely want to raise an error everywhere if you pass in an
|
||||
# invalid fd to wait_*
|
||||
for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]:
|
||||
with stdlib_socket.socket() as s:
|
||||
fileno = s.fileno()
|
||||
# We just closed the socket and don't do anything else in between, so
|
||||
# we can be confident that the fileno hasn't be reassigned.
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+] (Bad file descriptor|An operation was attempted on something that is not a socket)$",
|
||||
):
|
||||
await wait(fileno)
|
||||
|
||||
|
||||
async def test_io_manager_statistics() -> None:
|
||||
def check(*, expected_readers: int, expected_writers: int) -> None:
|
||||
statistics = _core.current_statistics()
|
||||
print(statistics)
|
||||
iostats = statistics.io_statistics
|
||||
if iostats.backend == "epoll" or iostats.backend == "windows":
|
||||
assert iostats.tasks_waiting_read == expected_readers
|
||||
assert iostats.tasks_waiting_write == expected_writers
|
||||
else:
|
||||
assert iostats.backend == "kqueue"
|
||||
assert iostats.tasks_waiting == expected_readers + expected_writers
|
||||
|
||||
a1, b1 = stdlib_socket.socketpair()
|
||||
a2, b2 = stdlib_socket.socketpair()
|
||||
a3, b3 = stdlib_socket.socketpair()
|
||||
for sock in [a1, b1, a2, b2, a3, b3]:
|
||||
sock.setblocking(False)
|
||||
with a1, b1, a2, b2, a3, b3:
|
||||
# let the call_soon_task settle down
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# 1 for call_soon_task
|
||||
check(expected_readers=1, expected_writers=0)
|
||||
|
||||
# We want:
|
||||
# - one socket with a writer blocked
|
||||
# - two sockets with a reader blocked
|
||||
# - a socket with both blocked
|
||||
fill_socket(a1)
|
||||
fill_socket(a3)
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(_core.wait_writable, a1)
|
||||
nursery.start_soon(_core.wait_readable, a2)
|
||||
nursery.start_soon(_core.wait_readable, b2)
|
||||
nursery.start_soon(_core.wait_writable, a3)
|
||||
nursery.start_soon(_core.wait_readable, a3)
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# +1 for call_soon_task
|
||||
check(expected_readers=3 + 1, expected_writers=2)
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# 1 for call_soon_task
|
||||
check(expected_readers=1, expected_writers=0)
|
||||
|
||||
|
||||
async def test_can_survive_unnotified_close() -> None:
|
||||
# An "unnotified" close is when the user closes an fd/socket/handle
|
||||
# directly, without calling notify_closing first. This should never happen
|
||||
# -- users should call notify_closing before closing things. But, just in
|
||||
# case they don't, we would still like to avoid exploding.
|
||||
#
|
||||
# Acceptable behaviors:
|
||||
# - wait_* never return, but can be cancelled cleanly
|
||||
# - wait_* exit cleanly
|
||||
# - wait_* raise an OSError
|
||||
#
|
||||
# Not acceptable:
|
||||
# - getting stuck in an uncancellable state
|
||||
# - TrioInternalError blowing up the whole run
|
||||
#
|
||||
# This test exercises some tricky "unnotified close" scenarios, to make
|
||||
# sure we get the "acceptable" behaviors.
|
||||
|
||||
async def allow_OSError(
|
||||
async_func: Callable[ArgsT, Awaitable[object]],
|
||||
*args: ArgsT.args,
|
||||
**kwargs: ArgsT.kwargs,
|
||||
) -> None:
|
||||
with suppress(OSError):
|
||||
await async_func(*args, **kwargs)
|
||||
|
||||
with stdlib_socket.socket() as s:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# We hit different paths on Windows depending on whether we close the last
|
||||
# handle to the object (which produces a LOCAL_CLOSE notification and
|
||||
# wakes up wait_readable), or only close one of the handles (which leaves
|
||||
# wait_readable pending until cancelled).
|
||||
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# A more elaborate case, with two tasks waiting. On windows and epoll,
|
||||
# the two tasks get muxed together onto a single underlying wait
|
||||
# operation. So when they're cancelled, there's a brief moment where one
|
||||
# of the tasks is cancelled but the other isn't, so we try to re-issue the
|
||||
# underlying wait operation. But here, the handle we were going to use to
|
||||
# do that has been pulled out from under our feet... so test that we can
|
||||
# survive this.
|
||||
a, b = stdlib_socket.socketpair()
|
||||
with a, b, a.dup() as a2:
|
||||
a.setblocking(False)
|
||||
b.setblocking(False)
|
||||
fill_socket(a)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
|
||||
await wait_all_tasks_blocked()
|
||||
a.close()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# A similar case, but now the single-task-wakeup happens due to I/O
|
||||
# arriving, not a cancellation, so the operation gets re-issued from
|
||||
# handle_io context rather than abort context.
|
||||
a, b = stdlib_socket.socketpair()
|
||||
with a, b, a.dup() as a2:
|
||||
print(f"a={a.fileno()}, b={b.fileno()}, a2={a2.fileno()}")
|
||||
a.setblocking(False)
|
||||
b.setblocking(False)
|
||||
fill_socket(a)
|
||||
e = trio.Event()
|
||||
|
||||
# We want to wait for the kernel to process the wakeup on 'a', if any.
|
||||
# But depending on the platform, we might not get a wakeup on 'a'. So
|
||||
# we put one task to sleep waiting on 'a', and we put a second task to
|
||||
# sleep waiting on 'a2', with the idea that the 'a2' notification will
|
||||
# definitely arrive, and when it does then we can assume that whatever
|
||||
# notification was going to arrive for 'a' has also arrived.
|
||||
async def wait_readable_a2_then_set() -> None:
|
||||
await trio.lowlevel.wait_readable(a2)
|
||||
e.set()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
|
||||
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
|
||||
nursery.start_soon(wait_readable_a2_then_set)
|
||||
await wait_all_tasks_blocked()
|
||||
a.close()
|
||||
b.send(b"x")
|
||||
# Make sure that the wakeup has been received and everything has
|
||||
# settled before cancelling the wait_writable.
|
||||
await e.wait()
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
@@ -0,0 +1,517 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import signal
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator
|
||||
|
||||
import outcome
|
||||
import pytest
|
||||
|
||||
from trio.testing import RaisesGroup
|
||||
|
||||
try:
|
||||
from async_generator import async_generator, yield_
|
||||
except ImportError: # pragma: no cover
|
||||
async_generator = yield_ = None
|
||||
|
||||
from ... import _core
|
||||
from ..._abc import Instrument
|
||||
from ..._timeouts import sleep
|
||||
from ..._util import signal_raise
|
||||
from ...testing import wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._core import Abort, RaiseCancelT
|
||||
|
||||
|
||||
def ki_self() -> None:
|
||||
signal_raise(signal.SIGINT)
|
||||
|
||||
|
||||
def test_ki_self() -> None:
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
ki_self()
|
||||
|
||||
|
||||
async def test_ki_enabled() -> None:
|
||||
# Regular tasks aren't KI-protected
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
# Low-level call-soon callbacks are KI-protected
|
||||
token = _core.current_trio_token()
|
||||
record = []
|
||||
|
||||
def check() -> None:
|
||||
record.append(_core.currently_ki_protected())
|
||||
|
||||
token.run_sync_soon(check)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [True]
|
||||
|
||||
@_core.enable_ki_protection
|
||||
def protected() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
unprotected()
|
||||
|
||||
@_core.disable_ki_protection
|
||||
def unprotected() -> None:
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
protected()
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def aprotected() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
await aunprotected()
|
||||
|
||||
@_core.disable_ki_protection
|
||||
async def aunprotected() -> None:
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
await aprotected()
|
||||
|
||||
# make sure that the decorator here overrides the automatic manipulation
|
||||
# that start_soon() does:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(aprotected)
|
||||
nursery.start_soon(aunprotected)
|
||||
|
||||
@_core.enable_ki_protection
|
||||
def gen_protected() -> Iterator[None]:
|
||||
assert _core.currently_ki_protected()
|
||||
yield
|
||||
|
||||
for _ in gen_protected():
|
||||
pass
|
||||
|
||||
@_core.disable_ki_protection
|
||||
def gen_unprotected() -> Iterator[None]:
|
||||
assert not _core.currently_ki_protected()
|
||||
yield
|
||||
|
||||
for _ in gen_unprotected():
|
||||
pass
|
||||
|
||||
|
||||
# This used to be broken due to
|
||||
#
|
||||
# https://bugs.python.org/issue29590
|
||||
#
|
||||
# Specifically, after a coroutine is resumed with .throw(), then the stack
|
||||
# makes it look like the immediate caller is the function that called
|
||||
# .throw(), not the actual caller. So child() here would have a caller deep in
|
||||
# the guts of the run loop, and always be protected, even when it shouldn't
|
||||
# have been. (Solution: we don't use .throw() anymore.)
|
||||
async def test_ki_enabled_after_yield_briefly() -> None:
|
||||
@_core.enable_ki_protection
|
||||
async def protected() -> None:
|
||||
await child(True)
|
||||
|
||||
@_core.disable_ki_protection
|
||||
async def unprotected() -> None:
|
||||
await child(False)
|
||||
|
||||
async def child(expected: bool) -> None:
|
||||
import traceback
|
||||
|
||||
traceback.print_stack()
|
||||
assert _core.currently_ki_protected() == expected
|
||||
await _core.checkpoint()
|
||||
traceback.print_stack()
|
||||
assert _core.currently_ki_protected() == expected
|
||||
|
||||
await protected()
|
||||
await unprotected()
|
||||
|
||||
|
||||
# This also used to be broken due to
|
||||
# https://bugs.python.org/issue29590
|
||||
async def test_generator_based_context_manager_throw() -> None:
|
||||
@contextlib.contextmanager
|
||||
@_core.enable_ki_protection
|
||||
def protected_manager() -> Iterator[None]:
|
||||
assert _core.currently_ki_protected()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
assert _core.currently_ki_protected()
|
||||
|
||||
with protected_manager():
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
# This is the one that used to fail
|
||||
with protected_manager():
|
||||
raise KeyError
|
||||
|
||||
|
||||
# the async_generator package isn't typed, hence all the type: ignores
|
||||
@pytest.mark.skipif(async_generator is None, reason="async_generator not installed")
|
||||
async def test_async_generator_agen_protection() -> None:
|
||||
@_core.enable_ki_protection
|
||||
@async_generator # type: ignore[misc] # untyped generator
|
||||
async def agen_protected1() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
try:
|
||||
await yield_()
|
||||
finally:
|
||||
assert _core.currently_ki_protected()
|
||||
|
||||
@_core.disable_ki_protection
|
||||
@async_generator # type: ignore[misc] # untyped generator
|
||||
async def agen_unprotected1() -> None:
|
||||
assert not _core.currently_ki_protected()
|
||||
try:
|
||||
await yield_()
|
||||
finally:
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
# Swap the order of the decorators:
|
||||
@async_generator # type: ignore[misc] # untyped generator
|
||||
@_core.enable_ki_protection
|
||||
async def agen_protected2() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
try:
|
||||
await yield_()
|
||||
finally:
|
||||
assert _core.currently_ki_protected()
|
||||
|
||||
@async_generator # type: ignore[misc] # untyped generator
|
||||
@_core.disable_ki_protection
|
||||
async def agen_unprotected2() -> None:
|
||||
assert not _core.currently_ki_protected()
|
||||
try:
|
||||
await yield_()
|
||||
finally:
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
await _check_agen(agen_protected1)
|
||||
await _check_agen(agen_protected2)
|
||||
await _check_agen(agen_unprotected1)
|
||||
await _check_agen(agen_unprotected2)
|
||||
|
||||
|
||||
async def test_native_agen_protection() -> None:
|
||||
# Native async generators
|
||||
@_core.enable_ki_protection
|
||||
async def agen_protected() -> AsyncIterator[None]:
|
||||
assert _core.currently_ki_protected()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
assert _core.currently_ki_protected()
|
||||
|
||||
@_core.disable_ki_protection
|
||||
async def agen_unprotected() -> AsyncIterator[None]:
|
||||
assert not _core.currently_ki_protected()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
await _check_agen(agen_protected)
|
||||
await _check_agen(agen_unprotected)
|
||||
|
||||
|
||||
async def _check_agen(agen_fn: Callable[[], AsyncIterator[None]]) -> None:
|
||||
async for _ in agen_fn():
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
# asynccontextmanager insists that the function passed must itself be an
|
||||
# async gen function, not a wrapper around one
|
||||
if inspect.isasyncgenfunction(agen_fn):
|
||||
async with contextlib.asynccontextmanager(agen_fn)():
|
||||
assert not _core.currently_ki_protected()
|
||||
|
||||
# Another case that's tricky due to:
|
||||
# https://bugs.python.org/issue29590
|
||||
with pytest.raises(KeyError):
|
||||
async with contextlib.asynccontextmanager(agen_fn)():
|
||||
raise KeyError
|
||||
|
||||
|
||||
# Test the case where there's no magic local anywhere in the call stack
|
||||
def test_ki_disabled_out_of_context() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
|
||||
|
||||
def test_ki_disabled_in_del() -> None:
|
||||
def nestedfunction() -> bool:
|
||||
return _core.currently_ki_protected()
|
||||
|
||||
def __del__() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
assert nestedfunction()
|
||||
|
||||
@_core.disable_ki_protection
|
||||
def outerfunction() -> None:
|
||||
assert not _core.currently_ki_protected()
|
||||
assert not nestedfunction()
|
||||
__del__()
|
||||
|
||||
__del__()
|
||||
outerfunction()
|
||||
assert nestedfunction()
|
||||
|
||||
|
||||
def test_ki_protection_works() -> None:
|
||||
async def sleeper(name: str, record: set[str]) -> None:
|
||||
try:
|
||||
while True:
|
||||
await _core.checkpoint()
|
||||
except _core.Cancelled:
|
||||
record.add(name + " ok")
|
||||
|
||||
async def raiser(name: str, record: set[str]) -> None:
|
||||
try:
|
||||
# os.kill runs signal handlers before returning, so we don't need
|
||||
# to worry that the handler will be delayed
|
||||
print("killing, protection =", _core.currently_ki_protected())
|
||||
ki_self()
|
||||
except KeyboardInterrupt:
|
||||
print("raised!")
|
||||
# Make sure we aren't getting cancelled as well as siginted
|
||||
await _core.checkpoint()
|
||||
record.add(name + " raise ok")
|
||||
raise
|
||||
else:
|
||||
print("didn't raise!")
|
||||
# If we didn't raise (b/c protected), then we *should* get
|
||||
# cancelled at the next opportunity
|
||||
try:
|
||||
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
|
||||
except _core.Cancelled:
|
||||
record.add(name + " cancel ok")
|
||||
|
||||
# simulated control-C during raiser, which is *unprotected*
|
||||
print("check 1")
|
||||
record_set: set[str] = set()
|
||||
|
||||
async def check_unprotected_kill() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sleeper, "s1", record_set)
|
||||
nursery.start_soon(sleeper, "s2", record_set)
|
||||
nursery.start_soon(raiser, "r1", record_set)
|
||||
|
||||
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
|
||||
with RaisesGroup(KeyboardInterrupt):
|
||||
_core.run(check_unprotected_kill)
|
||||
assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"}
|
||||
|
||||
# simulated control-C during raiser, which is *protected*, so the KI gets
|
||||
# delivered to the main task instead
|
||||
print("check 2")
|
||||
record_set = set()
|
||||
|
||||
async def check_protected_kill() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sleeper, "s1", record_set)
|
||||
nursery.start_soon(sleeper, "s2", record_set)
|
||||
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set)
|
||||
# __aexit__ blocks, and then receives the KI
|
||||
|
||||
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
|
||||
with RaisesGroup(KeyboardInterrupt):
|
||||
_core.run(check_protected_kill)
|
||||
assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"}
|
||||
|
||||
# kill at last moment still raises (run_sync_soon until it raises an
|
||||
# error, then kill)
|
||||
print("check 3")
|
||||
|
||||
async def check_kill_during_shutdown() -> None:
|
||||
token = _core.current_trio_token()
|
||||
|
||||
def kill_during_shutdown() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
try:
|
||||
token.run_sync_soon(kill_during_shutdown)
|
||||
except _core.RunFinishedError:
|
||||
# it's too late for regular handling! handle this!
|
||||
print("kill! kill!")
|
||||
ki_self()
|
||||
|
||||
token.run_sync_soon(kill_during_shutdown)
|
||||
|
||||
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
_core.run(check_kill_during_shutdown)
|
||||
|
||||
# KI arrives very early, before main is even spawned
|
||||
print("check 4")
|
||||
|
||||
class InstrumentOfDeath(Instrument):
|
||||
def before_run(self) -> None:
|
||||
ki_self()
|
||||
|
||||
async def main_1() -> None:
|
||||
await _core.checkpoint()
|
||||
|
||||
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
_core.run(main_1, instruments=[InstrumentOfDeath()])
|
||||
|
||||
# checkpoint_if_cancelled notices pending KI
|
||||
print("check 5")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_2() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
ki_self()
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await _core.checkpoint_if_cancelled()
|
||||
|
||||
_core.run(main_2)
|
||||
|
||||
# KI arrives while main task is not abortable, b/c already scheduled
|
||||
print("check 6")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_3() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
ki_self()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await _core.checkpoint()
|
||||
|
||||
_core.run(main_3)
|
||||
|
||||
# KI arrives while main task is not abortable, b/c refuses to be aborted
|
||||
print("check 7")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_4() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
ki_self()
|
||||
task = _core.current_task()
|
||||
|
||||
def abort(_: RaiseCancelT) -> Abort:
|
||||
_core.reschedule(task, outcome.Value(1))
|
||||
return _core.Abort.FAILED
|
||||
|
||||
assert await _core.wait_task_rescheduled(abort) == 1
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await _core.checkpoint()
|
||||
|
||||
_core.run(main_4)
|
||||
|
||||
# KI delivered via slow abort
|
||||
print("check 8")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_5() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
ki_self()
|
||||
task = _core.current_task()
|
||||
|
||||
def abort(raise_cancel: RaiseCancelT) -> Abort:
|
||||
result = outcome.capture(raise_cancel)
|
||||
_core.reschedule(task, result)
|
||||
return _core.Abort.FAILED
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
assert await _core.wait_task_rescheduled(abort)
|
||||
await _core.checkpoint()
|
||||
|
||||
_core.run(main_5)
|
||||
|
||||
# KI arrives just before main task exits, so the run_sync_soon machinery
|
||||
# is still functioning and will accept the callback to deliver the KI, but
|
||||
# by the time the callback is actually run, main has exited and can't be
|
||||
# aborted.
|
||||
print("check 9")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_6() -> None:
|
||||
ki_self()
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
_core.run(main_6)
|
||||
|
||||
print("check 10")
|
||||
# KI in unprotected code, with
|
||||
# restrict_keyboard_interrupt_to_checkpoints=True
|
||||
record_list = []
|
||||
|
||||
async def main_7() -> None:
|
||||
# We're not KI protected...
|
||||
assert not _core.currently_ki_protected()
|
||||
ki_self()
|
||||
# ...but even after the KI, we keep running uninterrupted...
|
||||
record_list.append("ok")
|
||||
# ...until we hit a checkpoint:
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await sleep(10)
|
||||
|
||||
_core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
|
||||
assert record_list == ["ok"]
|
||||
record_list = []
|
||||
# Exact same code raises KI early if we leave off the argument, doesn't
|
||||
# even reach the record.append call:
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
_core.run(main_7)
|
||||
assert record_list == []
|
||||
|
||||
# KI arrives while main task is inside a cancelled cancellation scope
|
||||
# the KeyboardInterrupt should take priority
|
||||
print("check 11")
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def main_8() -> None:
|
||||
assert _core.currently_ki_protected()
|
||||
with _core.CancelScope() as cancel_scope:
|
||||
cancel_scope.cancel()
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await _core.checkpoint()
|
||||
ki_self()
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await _core.checkpoint()
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await _core.checkpoint()
|
||||
|
||||
_core.run(main_8)
|
||||
|
||||
|
||||
def test_ki_is_good_neighbor() -> None:
|
||||
# in the unlikely event someone overwrites our signal handler, we leave
|
||||
# the overwritten one be
|
||||
try:
|
||||
orig = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def my_handler(signum: object, frame: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async def main() -> None:
|
||||
signal.signal(signal.SIGINT, my_handler)
|
||||
|
||||
_core.run(main)
|
||||
|
||||
assert signal.getsignal(signal.SIGINT) is my_handler
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, orig)
|
||||
|
||||
|
||||
# Regression test for #461
|
||||
# don't know if _active not being visible is a problem
|
||||
def test_ki_with_broken_threads() -> None:
|
||||
thread = threading.main_thread()
|
||||
|
||||
# scary!
|
||||
original = threading._active[thread.ident] # type: ignore[attr-defined]
|
||||
|
||||
# put this in a try finally so we don't have a chance of cascading a
|
||||
# breakage down to everything else
|
||||
try:
|
||||
del threading._active[thread.ident] # type: ignore[attr-defined]
|
||||
|
||||
@_core.enable_ki_protection
|
||||
async def inner() -> None:
|
||||
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
|
||||
|
||||
_core.run(inner)
|
||||
finally:
|
||||
threading._active[thread.ident] = original # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
|
||||
from trio import run
|
||||
from trio.lowlevel import RunVar, RunVarToken
|
||||
|
||||
from ... import _core
|
||||
|
||||
|
||||
# scary runvar tests
|
||||
def test_runvar_smoketest() -> None:
|
||||
t1 = RunVar[str]("test1")
|
||||
t2 = RunVar[str]("test2", default="catfish")
|
||||
|
||||
assert repr(t1) == "<RunVar name='test1'>"
|
||||
|
||||
async def first_check() -> None:
|
||||
with pytest.raises(LookupError):
|
||||
t1.get()
|
||||
|
||||
t1.set("swordfish")
|
||||
assert t1.get() == "swordfish"
|
||||
assert t2.get() == "catfish"
|
||||
assert t2.get(default="eel") == "eel"
|
||||
|
||||
t2.set("goldfish")
|
||||
assert t2.get() == "goldfish"
|
||||
assert t2.get(default="tuna") == "goldfish"
|
||||
|
||||
async def second_check() -> None:
|
||||
with pytest.raises(LookupError):
|
||||
t1.get()
|
||||
|
||||
assert t2.get() == "catfish"
|
||||
|
||||
run(first_check)
|
||||
run(second_check)
|
||||
|
||||
|
||||
def test_runvar_resetting() -> None:
|
||||
t1 = RunVar[str]("test1")
|
||||
t2 = RunVar[str]("test2", default="dogfish")
|
||||
t3 = RunVar[str]("test3")
|
||||
|
||||
async def reset_check() -> None:
|
||||
token = t1.set("moonfish")
|
||||
assert t1.get() == "moonfish"
|
||||
t1.reset(token)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
t1.reset(None) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(LookupError):
|
||||
t1.get()
|
||||
|
||||
token2 = t2.set("catdogfish")
|
||||
assert t2.get() == "catdogfish"
|
||||
t2.reset(token2)
|
||||
assert t2.get() == "dogfish"
|
||||
|
||||
with pytest.raises(ValueError, match="^token has already been used$"):
|
||||
t2.reset(token2)
|
||||
|
||||
token3 = t3.set("basculin")
|
||||
assert t3.get() == "basculin"
|
||||
|
||||
with pytest.raises(ValueError, match="^token is not for us$"):
|
||||
t1.reset(token3)
|
||||
|
||||
run(reset_check)
|
||||
|
||||
|
||||
def test_runvar_sync() -> None:
|
||||
t1 = RunVar[str]("test1")
|
||||
|
||||
async def sync_check() -> None:
|
||||
async def task1() -> None:
|
||||
t1.set("plaice")
|
||||
assert t1.get() == "plaice"
|
||||
|
||||
async def task2(tok: RunVarToken[str]) -> None:
|
||||
t1.reset(tok)
|
||||
|
||||
with pytest.raises(LookupError):
|
||||
t1.get()
|
||||
|
||||
t1.set("haddock")
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
token = t1.set("cod")
|
||||
assert t1.get() == "cod"
|
||||
|
||||
n.start_soon(task1)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert t1.get() == "plaice"
|
||||
|
||||
n.start_soon(task2, token)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert t1.get() == "haddock"
|
||||
|
||||
run(sync_check)
|
||||
|
||||
|
||||
def test_accessing_runvar_outside_run_call_fails() -> None:
|
||||
t1 = RunVar[str]("test1")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
t1.set("asdf")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
t1.get()
|
||||
|
||||
async def get_token() -> RunVarToken[str]:
|
||||
return t1.set("ok")
|
||||
|
||||
token = run(get_token)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
t1.reset(token)
|
||||
@@ -0,0 +1,175 @@
|
||||
import time
|
||||
from math import inf
|
||||
|
||||
import pytest
|
||||
|
||||
from trio import sleep
|
||||
|
||||
from ... import _core
|
||||
from .. import wait_all_tasks_blocked
|
||||
from .._mock_clock import MockClock
|
||||
from .tutil import slow
|
||||
|
||||
|
||||
def test_mock_clock() -> None:
|
||||
REAL_NOW = 123.0
|
||||
c = MockClock()
|
||||
c._real_clock = lambda: REAL_NOW
|
||||
repr(c) # smoke test
|
||||
assert c.rate == 0
|
||||
assert c.current_time() == 0
|
||||
c.jump(1.2)
|
||||
assert c.current_time() == 1.2
|
||||
with pytest.raises(ValueError, match="^time can't go backwards$"):
|
||||
c.jump(-1)
|
||||
assert c.current_time() == 1.2
|
||||
assert c.deadline_to_sleep_time(1.1) == 0
|
||||
assert c.deadline_to_sleep_time(1.2) == 0
|
||||
assert c.deadline_to_sleep_time(1.3) > 999999
|
||||
|
||||
with pytest.raises(ValueError, match="^rate must be >= 0$"):
|
||||
c.rate = -1
|
||||
assert c.rate == 0
|
||||
|
||||
c.rate = 2
|
||||
assert c.current_time() == 1.2
|
||||
REAL_NOW += 1
|
||||
assert c.current_time() == 3.2
|
||||
assert c.deadline_to_sleep_time(3.1) == 0
|
||||
assert c.deadline_to_sleep_time(3.2) == 0
|
||||
assert c.deadline_to_sleep_time(4.2) == 0.5
|
||||
|
||||
c.rate = 0.5
|
||||
assert c.current_time() == 3.2
|
||||
assert c.deadline_to_sleep_time(3.1) == 0
|
||||
assert c.deadline_to_sleep_time(3.2) == 0
|
||||
assert c.deadline_to_sleep_time(4.2) == 2.0
|
||||
|
||||
c.jump(0.8)
|
||||
assert c.current_time() == 4.0
|
||||
REAL_NOW += 1
|
||||
assert c.current_time() == 4.5
|
||||
|
||||
c2 = MockClock(rate=3)
|
||||
assert c2.rate == 3
|
||||
assert c2.current_time() < 10
|
||||
|
||||
|
||||
async def test_mock_clock_autojump(mock_clock: MockClock) -> None:
|
||||
assert mock_clock.autojump_threshold == inf
|
||||
|
||||
mock_clock.autojump_threshold = 0
|
||||
assert mock_clock.autojump_threshold == 0
|
||||
|
||||
real_start = time.perf_counter()
|
||||
|
||||
virtual_start = _core.current_time()
|
||||
for i in range(10):
|
||||
print(f"sleeping {10 * i} seconds")
|
||||
await sleep(10 * i)
|
||||
print("woke up!")
|
||||
assert virtual_start + 10 * i == _core.current_time()
|
||||
virtual_start = _core.current_time()
|
||||
|
||||
real_duration = time.perf_counter() - real_start
|
||||
print(f"Slept {10 * sum(range(10))} seconds in {real_duration} seconds")
|
||||
assert real_duration < 1
|
||||
|
||||
mock_clock.autojump_threshold = 0.02
|
||||
t = _core.current_time()
|
||||
# this should wake up before the autojump threshold triggers, so time
|
||||
# shouldn't change
|
||||
await wait_all_tasks_blocked()
|
||||
assert t == _core.current_time()
|
||||
# this should too
|
||||
await wait_all_tasks_blocked(0.01)
|
||||
assert t == _core.current_time()
|
||||
|
||||
# set up a situation where the autojump task is blocked for a long long
|
||||
# time, to make sure that cancel-and-adjust-threshold logic is working
|
||||
mock_clock.autojump_threshold = 10000
|
||||
await wait_all_tasks_blocked()
|
||||
mock_clock.autojump_threshold = 0
|
||||
# if the above line didn't take affect immediately, then this would be
|
||||
# bad:
|
||||
await sleep(100000)
|
||||
|
||||
|
||||
async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None:
|
||||
mock_clock.autojump_threshold = 0.02
|
||||
|
||||
mock_clock2 = MockClock()
|
||||
# messing with the autojump threshold of a clock that isn't actually
|
||||
# installed in the run loop shouldn't do anything.
|
||||
mock_clock2.autojump_threshold = 0.01
|
||||
|
||||
# if the autojump_threshold of 0.01 were in effect, then the next line
|
||||
# would block forever, as the autojump task kept waking up to try to
|
||||
# jump the clock.
|
||||
await wait_all_tasks_blocked(0.015)
|
||||
|
||||
# but the 0.02 limit does apply
|
||||
await sleep(100000)
|
||||
|
||||
|
||||
def test_mock_clock_autojump_preset() -> None:
|
||||
# Check that we can set the autojump_threshold before the clock is
|
||||
# actually in use, and it gets picked up
|
||||
mock_clock = MockClock(autojump_threshold=0.1)
|
||||
mock_clock.autojump_threshold = 0.01
|
||||
real_start = time.perf_counter()
|
||||
_core.run(sleep, 10000, clock=mock_clock)
|
||||
assert time.perf_counter() - real_start < 1
|
||||
|
||||
|
||||
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(
|
||||
mock_clock: MockClock,
|
||||
) -> None:
|
||||
# Checks that autojump_threshold=0 doesn't interfere with
|
||||
# calling wait_all_tasks_blocked with the default cushion=0.
|
||||
|
||||
mock_clock.autojump_threshold = 0
|
||||
|
||||
record = []
|
||||
|
||||
async def sleeper() -> None:
|
||||
await sleep(100)
|
||||
record.append("yawn")
|
||||
|
||||
async def waiter() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("waiter woke")
|
||||
await sleep(1000)
|
||||
record.append("waiter done")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sleeper)
|
||||
nursery.start_soon(waiter)
|
||||
|
||||
assert record == ["waiter woke", "yawn", "waiter done"]
|
||||
|
||||
|
||||
@slow
|
||||
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(
|
||||
mock_clock: MockClock,
|
||||
) -> None:
|
||||
# Checks that autojump_threshold=0 doesn't interfere with
|
||||
# calling wait_all_tasks_blocked with a non-zero cushion.
|
||||
|
||||
mock_clock.autojump_threshold = 0
|
||||
|
||||
record = []
|
||||
|
||||
async def sleeper() -> None:
|
||||
await sleep(100)
|
||||
record.append("yawn")
|
||||
|
||||
async def waiter() -> None:
|
||||
await wait_all_tasks_blocked(1)
|
||||
record.append("waiter done")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sleeper)
|
||||
nursery.start_soon(waiter)
|
||||
|
||||
assert record == ["waiter done", "yawn"]
|
||||
@@ -0,0 +1,384 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.lowlevel import (
|
||||
add_parking_lot_breaker,
|
||||
current_task,
|
||||
remove_parking_lot_breaker,
|
||||
)
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from ... import _core
|
||||
from ...testing import wait_all_tasks_blocked
|
||||
from .._parking_lot import ParkingLot
|
||||
from .tutil import check_sequence_matches
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def test_parking_lot_basic() -> None:
|
||||
record = []
|
||||
|
||||
async def waiter(i: int, lot: ParkingLot) -> None:
|
||||
record.append(f"sleep {i}")
|
||||
await lot.park()
|
||||
record.append(f"wake {i}")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lot = ParkingLot()
|
||||
assert not lot
|
||||
assert len(lot) == 0
|
||||
assert lot.statistics().tasks_waiting == 0
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i, lot)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 3
|
||||
assert bool(lot)
|
||||
assert len(lot) == 3
|
||||
assert lot.statistics().tasks_waiting == 3
|
||||
lot.unpark_all()
|
||||
assert lot.statistics().tasks_waiting == 0
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 6
|
||||
|
||||
check_sequence_matches(
|
||||
record,
|
||||
[{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}],
|
||||
)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
record = []
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i, lot)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 3
|
||||
for _ in range(3):
|
||||
lot.unpark()
|
||||
await wait_all_tasks_blocked()
|
||||
# 1-by-1 wakeups are strict FIFO
|
||||
assert record == [
|
||||
"sleep 0",
|
||||
"sleep 1",
|
||||
"sleep 2",
|
||||
"wake 0",
|
||||
"wake 1",
|
||||
"wake 2",
|
||||
]
|
||||
|
||||
# It's legal (but a no-op) to try and unpark while there's nothing parked
|
||||
lot.unpark()
|
||||
lot.unpark(count=1)
|
||||
lot.unpark(count=100)
|
||||
|
||||
# Check unpark with count
|
||||
async with _core.open_nursery() as nursery:
|
||||
record = []
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i, lot)
|
||||
await wait_all_tasks_blocked()
|
||||
lot.unpark(count=2)
|
||||
await wait_all_tasks_blocked()
|
||||
check_sequence_matches(
|
||||
record,
|
||||
["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}],
|
||||
)
|
||||
lot.unpark_all()
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"^Cannot pop a non-integer number of tasks\.$",
|
||||
):
|
||||
lot.unpark(count=1.5)
|
||||
|
||||
|
||||
async def cancellable_waiter(
|
||||
name: T,
|
||||
lot: ParkingLot,
|
||||
scopes: dict[T, _core.CancelScope],
|
||||
record: list[str],
|
||||
) -> None:
|
||||
with _core.CancelScope() as scope:
|
||||
scopes[name] = scope
|
||||
record.append(f"sleep {name}")
|
||||
try:
|
||||
await lot.park()
|
||||
except _core.Cancelled:
|
||||
record.append(f"cancelled {name}")
|
||||
else:
|
||||
record.append(f"wake {name}")
|
||||
|
||||
|
||||
async def test_parking_lot_cancel() -> None:
|
||||
record: list[str] = []
|
||||
scopes: dict[int, _core.CancelScope] = {}
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lot = ParkingLot()
|
||||
nursery.start_soon(cancellable_waiter, 1, lot, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 2, lot, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 3, lot, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 3
|
||||
|
||||
scopes[2].cancel()
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 4
|
||||
lot.unpark_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 6
|
||||
|
||||
check_sequence_matches(
|
||||
record,
|
||||
["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}],
|
||||
)
|
||||
|
||||
|
||||
async def test_parking_lot_repark() -> None:
|
||||
record: list[str] = []
|
||||
scopes: dict[int, _core.CancelScope] = {}
|
||||
lot1 = ParkingLot()
|
||||
lot2 = ParkingLot()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
lot1.repark([]) # type: ignore[arg-type]
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 3
|
||||
|
||||
assert len(lot1) == 3
|
||||
lot1.repark(lot2)
|
||||
assert len(lot1) == 2
|
||||
assert len(lot2) == 1
|
||||
lot2.unpark_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 4
|
||||
assert record == ["sleep 1", "sleep 2", "sleep 3", "wake 1"]
|
||||
|
||||
lot1.repark_all(lot2)
|
||||
assert len(lot1) == 0
|
||||
assert len(lot2) == 2
|
||||
|
||||
scopes[2].cancel()
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(lot2) == 1
|
||||
assert record == [
|
||||
"sleep 1",
|
||||
"sleep 2",
|
||||
"sleep 3",
|
||||
"wake 1",
|
||||
"cancelled 2",
|
||||
]
|
||||
|
||||
lot2.unpark_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [
|
||||
"sleep 1",
|
||||
"sleep 2",
|
||||
"sleep 3",
|
||||
"wake 1",
|
||||
"cancelled 2",
|
||||
"wake 3",
|
||||
]
|
||||
|
||||
|
||||
async def test_parking_lot_repark_with_count() -> None:
|
||||
record: list[str] = []
|
||||
scopes: dict[int, _core.CancelScope] = {}
|
||||
lot1 = ParkingLot()
|
||||
lot2 = ParkingLot()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(record) == 3
|
||||
|
||||
assert len(lot1) == 3
|
||||
assert len(lot2) == 0
|
||||
lot1.repark(lot2, count=2)
|
||||
assert len(lot1) == 1
|
||||
assert len(lot2) == 2
|
||||
while lot2:
|
||||
lot2.unpark()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [
|
||||
"sleep 1",
|
||||
"sleep 2",
|
||||
"sleep 3",
|
||||
"wake 1",
|
||||
"wake 2",
|
||||
]
|
||||
lot1.unpark_all()
|
||||
|
||||
|
||||
async def dummy_task(
|
||||
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
|
||||
) -> None:
|
||||
task_status.started(_core.current_task())
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def test_parking_lot_breaker_basic() -> None:
|
||||
"""Test basic functionality for breaking lots."""
|
||||
lot = ParkingLot()
|
||||
task = current_task()
|
||||
|
||||
# defaults to current task
|
||||
lot.break_lot()
|
||||
assert lot.broken_by == [task]
|
||||
|
||||
# breaking the lot again with the same task appends another copy in `broken_by`
|
||||
lot.break_lot()
|
||||
assert lot.broken_by == [task, task]
|
||||
|
||||
# trying to park in broken lot errors
|
||||
broken_by_str = re.escape(str([task, task]))
|
||||
with pytest.raises(
|
||||
_core.BrokenResourceError,
|
||||
match=f"^Attempted to park in parking lot broken by {broken_by_str}$",
|
||||
):
|
||||
await lot.park()
|
||||
|
||||
|
||||
async def test_parking_lot_break_parking_tasks() -> None:
|
||||
"""Checks that tasks currently waiting to park raise an error when the breaker exits."""
|
||||
|
||||
async def bad_parker(lot: ParkingLot, scope: _core.CancelScope) -> None:
|
||||
add_parking_lot_breaker(current_task(), lot)
|
||||
with scope:
|
||||
await trio.sleep_forever()
|
||||
|
||||
lot = ParkingLot()
|
||||
cs = _core.CancelScope()
|
||||
|
||||
# check that parked task errors
|
||||
with RaisesGroup(
|
||||
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
|
||||
):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(bad_parker, lot, cs)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
nursery.start_soon(lot.park)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
cs.cancel()
|
||||
|
||||
|
||||
async def test_parking_lot_breaker_registration() -> None:
|
||||
lot = ParkingLot()
|
||||
task = current_task()
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Attempted to remove task as breaker for a lot it is not registered for",
|
||||
):
|
||||
remove_parking_lot_breaker(task, lot)
|
||||
|
||||
# check that a task can be registered as breaker for the same lot multiple times
|
||||
add_parking_lot_breaker(task, lot)
|
||||
add_parking_lot_breaker(task, lot)
|
||||
remove_parking_lot_breaker(task, lot)
|
||||
remove_parking_lot_breaker(task, lot)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Attempted to remove task as breaker for a lot it is not registered for",
|
||||
):
|
||||
remove_parking_lot_breaker(task, lot)
|
||||
|
||||
# registering a task as breaker on an already broken lot is fine
|
||||
lot.break_lot()
|
||||
child_task = None
|
||||
async with trio.open_nursery() as nursery:
|
||||
child_task = await nursery.start(dummy_task)
|
||||
add_parking_lot_breaker(child_task, lot)
|
||||
nursery.cancel_scope.cancel()
|
||||
assert lot.broken_by == [task, child_task]
|
||||
|
||||
# manually breaking a lot with an already exited task is fine
|
||||
lot = ParkingLot()
|
||||
lot.break_lot(child_task)
|
||||
assert lot.broken_by == [child_task]
|
||||
|
||||
|
||||
async def test_parking_lot_breaker_rebreak() -> None:
|
||||
lot = ParkingLot()
|
||||
task = current_task()
|
||||
lot.break_lot()
|
||||
|
||||
# breaking an already broken lot with a different task is allowed
|
||||
# The nursery is only to create a task we can pass to lot.break_lot
|
||||
async with trio.open_nursery() as nursery:
|
||||
child_task = await nursery.start(dummy_task)
|
||||
lot.break_lot(child_task)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert lot.broken_by == [task, child_task]
|
||||
|
||||
|
||||
async def test_parking_lot_multiple_breakers_exit() -> None:
|
||||
# register multiple tasks as lot breakers, then have them all exit
|
||||
lot = ParkingLot()
|
||||
async with trio.open_nursery() as nursery:
|
||||
child_task1 = await nursery.start(dummy_task)
|
||||
child_task2 = await nursery.start(dummy_task)
|
||||
child_task3 = await nursery.start(dummy_task)
|
||||
add_parking_lot_breaker(child_task1, lot)
|
||||
add_parking_lot_breaker(child_task2, lot)
|
||||
add_parking_lot_breaker(child_task3, lot)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# I think the order is guaranteed currently, but doesn't hurt to be safe.
|
||||
assert set(lot.broken_by) == {child_task1, child_task2, child_task3}
|
||||
|
||||
|
||||
async def test_parking_lot_breaker_register_exited_task() -> None:
|
||||
lot = ParkingLot()
|
||||
child_task = None
|
||||
async with trio.open_nursery() as nursery:
|
||||
child_task = await nursery.start(dummy_task)
|
||||
nursery.cancel_scope.cancel()
|
||||
# trying to register an exited task as lot breaker errors
|
||||
with pytest.raises(
|
||||
trio.BrokenResourceError,
|
||||
match="^Attempted to add already exited task as lot breaker.$",
|
||||
):
|
||||
add_parking_lot_breaker(child_task, lot)
|
||||
|
||||
|
||||
async def test_parking_lot_break_itself() -> None:
|
||||
"""Break a parking lot, where the breakee is parked.
|
||||
Doing this is weird, but should probably be supported.
|
||||
"""
|
||||
|
||||
async def return_me_and_park(
|
||||
lot: ParkingLot,
|
||||
*,
|
||||
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
|
||||
) -> None:
|
||||
task_status.started(_core.current_task())
|
||||
await lot.park()
|
||||
|
||||
lot = ParkingLot()
|
||||
with RaisesGroup(
|
||||
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
|
||||
):
|
||||
async with _core.open_nursery() as nursery:
|
||||
child_task = await nursery.start(return_me_and_park, lot)
|
||||
lot.break_lot(child_task)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Iterator, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _thread_cache
|
||||
from .._thread_cache import ThreadCache, start_thread_soon
|
||||
from .tutil import gc_collect_harder, slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from outcome import Outcome
|
||||
|
||||
|
||||
def test_thread_cache_basics() -> None:
|
||||
q: Queue[Outcome[object]] = Queue()
|
||||
|
||||
def fn() -> NoReturn:
|
||||
raise RuntimeError("hi")
|
||||
|
||||
def deliver(outcome: Outcome[object]) -> None:
|
||||
q.put(outcome)
|
||||
|
||||
start_thread_soon(fn, deliver)
|
||||
|
||||
outcome = q.get()
|
||||
with pytest.raises(RuntimeError, match="hi"):
|
||||
outcome.unwrap()
|
||||
|
||||
|
||||
def test_thread_cache_deref() -> None:
|
||||
res = [False]
|
||||
|
||||
class del_me:
|
||||
def __call__(self) -> int:
|
||||
return 42
|
||||
|
||||
def __del__(self) -> None:
|
||||
res[0] = True
|
||||
|
||||
q: Queue[Outcome[int]] = Queue()
|
||||
|
||||
def deliver(outcome: Outcome[int]) -> None:
|
||||
q.put(outcome)
|
||||
|
||||
start_thread_soon(del_me(), deliver)
|
||||
outcome = q.get()
|
||||
assert outcome.unwrap() == 42
|
||||
|
||||
gc_collect_harder()
|
||||
assert res[0]
|
||||
|
||||
|
||||
@slow
|
||||
def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
|
||||
# We know that no-one else is using the thread cache, so if we keep
|
||||
# submitting new jobs the instant the previous one is finished, we should
|
||||
# keep getting the same thread over and over. This tests both that the
|
||||
# thread cache is LIFO, and that threads can be assigned new work *before*
|
||||
# deliver exits.
|
||||
|
||||
# Make sure there are a few threads running, so if we weren't LIFO then we
|
||||
# could grab the wrong one.
|
||||
q: Queue[Outcome[object]] = Queue()
|
||||
COUNT = 5
|
||||
for _ in range(COUNT):
|
||||
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
|
||||
for _ in range(COUNT):
|
||||
q.get().unwrap()
|
||||
|
||||
seen_threads = set()
|
||||
done = threading.Event()
|
||||
|
||||
def deliver(n: int, _: object) -> None:
|
||||
print(n)
|
||||
seen_threads.add(threading.current_thread())
|
||||
if n == 0:
|
||||
done.set()
|
||||
else:
|
||||
start_thread_soon(lambda: None, lambda _: deliver(n - 1, _))
|
||||
|
||||
start_thread_soon(lambda: None, lambda _: deliver(5, _))
|
||||
|
||||
done.wait()
|
||||
|
||||
assert len(seen_threads) == 1
|
||||
|
||||
|
||||
@slow
|
||||
def test_idle_threads_exit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Temporarily set the idle timeout to something tiny, to speed up the
|
||||
# test. (But non-zero, so that the worker loop will at least yield the
|
||||
# CPU.)
|
||||
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
|
||||
|
||||
q: Queue[threading.Thread] = Queue()
|
||||
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
|
||||
seen_thread = q.get()
|
||||
# Since the idle timeout is 0, after sleeping for 1 second, the thread
|
||||
# should have exited
|
||||
time.sleep(1)
|
||||
assert not seen_thread.is_alive()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _join_started_threads() -> Iterator[None]:
|
||||
before = frozenset(threading.enumerate())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for thread in threading.enumerate():
|
||||
if thread not in before:
|
||||
thread.join(timeout=1.0)
|
||||
assert not thread.is_alive()
|
||||
|
||||
|
||||
def test_race_between_idle_exit_and_job_assignment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# This is a lock where the first few times you try to acquire it with a
|
||||
# timeout, it waits until the lock is available and then pretends to time
|
||||
# out. Using this in our thread cache implementation causes the following
|
||||
# sequence:
|
||||
#
|
||||
# 1. start_thread_soon grabs the worker thread, assigns it a job, and
|
||||
# releases its lock.
|
||||
# 2. The worker thread wakes up (because the lock has been released), but
|
||||
# the JankyLock lies to it and tells it that the lock timed out. So the
|
||||
# worker thread tries to exit.
|
||||
# 3. The worker thread checks for the race between exiting and being
|
||||
# assigned a job, and discovers that it *is* in the process of being
|
||||
# assigned a job, so it loops around and tries to acquire the lock
|
||||
# again.
|
||||
# 4. Eventually the JankyLock admits that the lock is available, and
|
||||
# everything proceeds as normal.
|
||||
|
||||
class JankyLock:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._counter = 3
|
||||
|
||||
def acquire(self, timeout: int = -1) -> bool:
|
||||
got_it = self._lock.acquire(timeout=timeout)
|
||||
if timeout == -1:
|
||||
return True
|
||||
elif got_it:
|
||||
if self._counter > 0:
|
||||
self._counter -= 1
|
||||
self._lock.release()
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
|
||||
|
||||
with _join_started_threads():
|
||||
tc = ThreadCache()
|
||||
done = threading.Event()
|
||||
tc.start_thread_soon(lambda: None, lambda _: done.set())
|
||||
done.wait()
|
||||
# Let's kill the thread we started, so it doesn't hang around until the
|
||||
# test suite finishes. Doesn't really do any harm, but it can be confusing
|
||||
# to see it in debug output.
|
||||
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
|
||||
tc.start_thread_soon(lambda: None, lambda _: None)
|
||||
|
||||
|
||||
def test_raise_in_deliver(capfd: pytest.CaptureFixture[str]) -> None:
|
||||
seen_threads = set()
|
||||
|
||||
def track_threads() -> None:
|
||||
seen_threads.add(threading.current_thread())
|
||||
|
||||
def deliver(_: object) -> NoReturn:
|
||||
done.set()
|
||||
raise RuntimeError("don't do this")
|
||||
|
||||
done = threading.Event()
|
||||
start_thread_soon(track_threads, deliver)
|
||||
done.wait()
|
||||
done = threading.Event()
|
||||
start_thread_soon(track_threads, lambda _: done.set())
|
||||
done.wait()
|
||||
assert len(seen_threads) == 1
|
||||
err = capfd.readouterr().err
|
||||
assert "don't do this" in err
|
||||
assert "delivering result" in err
|
||||
@@ -0,0 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from .tutil import check_sequence_matches
|
||||
|
||||
|
||||
def test_check_sequence_matches() -> None:
|
||||
check_sequence_matches([1, 2, 3], [1, 2, 3])
|
||||
with pytest.raises(AssertionError):
|
||||
check_sequence_matches([1, 3, 2], [1, 2, 3])
|
||||
check_sequence_matches([1, 2, 3, 4], [1, {2, 3}, 4])
|
||||
check_sequence_matches([1, 3, 2, 4], [1, {2, 3}, 4])
|
||||
with pytest.raises(AssertionError):
|
||||
check_sequence_matches([1, 2, 4, 3], [1, {2, 3}, 4])
|
||||
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
|
||||
from ... import _core
|
||||
from ...testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:.*UnboundedQueue:trio.TrioDeprecationWarning",
|
||||
)
|
||||
|
||||
|
||||
async def test_UnboundedQueue_basic() -> None:
|
||||
q: _core.UnboundedQueue[str | int | None] = _core.UnboundedQueue()
|
||||
q.put_nowait("hi")
|
||||
assert await q.get_batch() == ["hi"]
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
q.get_batch_nowait()
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(2)
|
||||
q.put_nowait(3)
|
||||
assert q.get_batch_nowait() == [1, 2, 3]
|
||||
|
||||
assert q.empty()
|
||||
assert q.qsize() == 0
|
||||
q.put_nowait(None)
|
||||
assert not q.empty()
|
||||
assert q.qsize() == 1
|
||||
|
||||
stats = q.statistics()
|
||||
assert stats.qsize == 1
|
||||
assert stats.tasks_waiting == 0
|
||||
|
||||
# smoke test
|
||||
repr(q)
|
||||
|
||||
|
||||
async def test_UnboundedQueue_blocking() -> None:
|
||||
record = []
|
||||
q = _core.UnboundedQueue[int]()
|
||||
|
||||
async def get_batch_consumer() -> None:
|
||||
while True:
|
||||
batch = await q.get_batch()
|
||||
assert batch
|
||||
record.append(batch)
|
||||
|
||||
async def aiter_consumer() -> None:
|
||||
async for batch in q:
|
||||
assert batch
|
||||
record.append(batch)
|
||||
|
||||
for consumer in (get_batch_consumer, aiter_consumer):
|
||||
record.clear()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(consumer)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
stats = q.statistics()
|
||||
assert stats.qsize == 0
|
||||
assert stats.tasks_waiting == 1
|
||||
q.put_nowait(10)
|
||||
q.put_nowait(11)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
q.put_nowait(12)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
assert record == [[10, 11], [12]]
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_UnboundedQueue_fairness() -> None:
|
||||
q = _core.UnboundedQueue[int]()
|
||||
|
||||
# If there's no-one else around, we can put stuff in and take it out
|
||||
# again, no problem
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(2)
|
||||
assert q.get_batch_nowait() == [1, 2]
|
||||
|
||||
result = None
|
||||
|
||||
async def get_batch(q: _core.UnboundedQueue[int]) -> None:
|
||||
nonlocal result
|
||||
result = await q.get_batch()
|
||||
|
||||
# But if someone else is waiting to read, then they get dibs
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(get_batch, q)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
q.put_nowait(3)
|
||||
q.put_nowait(4)
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
q.get_batch_nowait()
|
||||
assert result == [3, 4]
|
||||
|
||||
# If two tasks are trying to read, they alternate
|
||||
record = []
|
||||
|
||||
async def reader(name: str) -> None:
|
||||
while True:
|
||||
record.append((name, await q.get_batch()))
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(reader, "a")
|
||||
await _core.wait_all_tasks_blocked()
|
||||
nursery.start_soon(reader, "b")
|
||||
await _core.wait_all_tasks_blocked()
|
||||
|
||||
for i in range(20):
|
||||
q.put_nowait(i)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
|
||||
|
||||
|
||||
async def test_UnboundedQueue_trivial_yields() -> None:
|
||||
q = _core.UnboundedQueue[None]()
|
||||
|
||||
q.put_nowait(None)
|
||||
with assert_checkpoints():
|
||||
await q.get_batch()
|
||||
|
||||
q.put_nowait(None)
|
||||
with assert_checkpoints():
|
||||
async for _ in q: # pragma: no branch
|
||||
break
|
||||
|
||||
|
||||
async def test_UnboundedQueue_no_spurious_wakeups() -> None:
|
||||
# If we have two tasks waiting, and put two items into the queue... then
|
||||
# only one task wakes up
|
||||
record = []
|
||||
|
||||
async def getter(q: _core.UnboundedQueue[int], i: int) -> None:
|
||||
got = await q.get_batch()
|
||||
record.append((i, got))
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
q = _core.UnboundedQueue[int]()
|
||||
nursery.start_soon(getter, q, 1)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.start_soon(getter, q, 2)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
for i in range(10):
|
||||
q.put_nowait(i)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
assert record == [(1, list(range(10)))]
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
@@ -0,0 +1,299 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import create_autospec
|
||||
|
||||
import pytest
|
||||
|
||||
on_windows = os.name == "nt"
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
|
||||
|
||||
assert (
|
||||
sys.platform == "win32" or not TYPE_CHECKING
|
||||
) # Skip type checking when not on Windows
|
||||
|
||||
from ... import _core, sleep
|
||||
from ...testing import wait_all_tasks_blocked
|
||||
from .tutil import gc_collect_harder, restore_unraisablehook, slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from io import BufferedWriter
|
||||
|
||||
if on_windows:
|
||||
from .._windows_cffi import (
|
||||
INVALID_HANDLE_VALUE,
|
||||
FileFlags,
|
||||
Handle,
|
||||
ffi,
|
||||
kernel32,
|
||||
raise_winerror,
|
||||
)
|
||||
|
||||
|
||||
def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
mock = create_autospec(ffi.getwinerror)
|
||||
monkeypatch.setattr(ffi, "getwinerror", mock)
|
||||
|
||||
# Returning none = no error, should not happen.
|
||||
mock.return_value = None
|
||||
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
|
||||
raise_winerror()
|
||||
mock.assert_called_once_with()
|
||||
mock.reset_mock()
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
|
||||
raise_winerror(38)
|
||||
mock.assert_called_once_with(38)
|
||||
mock.reset_mock()
|
||||
|
||||
mock.return_value = (12, "test error")
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$",
|
||||
) as exc:
|
||||
raise_winerror(filename="file_1", filename2="file_2")
|
||||
mock.assert_called_once_with()
|
||||
mock.reset_mock()
|
||||
assert exc.value.winerror == 12
|
||||
assert exc.value.strerror == "test error"
|
||||
assert exc.value.filename == "file_1"
|
||||
assert exc.value.filename2 == "file_2"
|
||||
|
||||
# With an explicit number passed in, it overrides what getwinerror() returns.
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$",
|
||||
) as exc:
|
||||
raise_winerror(18, filename="a/file", filename2="b/file")
|
||||
mock.assert_called_once_with(18)
|
||||
mock.reset_mock()
|
||||
assert exc.value.winerror == 18
|
||||
assert exc.value.strerror == "test error"
|
||||
assert exc.value.filename == "a/file"
|
||||
assert exc.value.filename2 == "b/file"
|
||||
|
||||
|
||||
# The undocumented API that this is testing should be changed to stop using
|
||||
# UnboundedQueue (or just removed until we have time to redo it), but until
|
||||
# then we filter out the warning.
|
||||
@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning")
|
||||
async def test_completion_key_listen() -> None:
|
||||
from .. import _io_windows
|
||||
|
||||
async def post(key: int) -> None:
|
||||
iocp = Handle(ffi.cast("HANDLE", _core.current_iocp()))
|
||||
for i in range(10):
|
||||
print("post", i)
|
||||
if i % 3 == 0:
|
||||
await _core.checkpoint()
|
||||
success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL)
|
||||
assert success
|
||||
|
||||
with _core.monitor_completion_key() as (key, queue):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(post, key)
|
||||
i = 0
|
||||
print("loop")
|
||||
async for batch in queue: # pragma: no branch
|
||||
print("got some", batch)
|
||||
for info in batch:
|
||||
assert isinstance(info, _io_windows.CompletionKeyEventInfo)
|
||||
assert info.lpOverlapped == 0
|
||||
assert info.dwNumberOfBytesTransferred == i
|
||||
i += 1
|
||||
if i == 10:
|
||||
break
|
||||
print("end loop")
|
||||
|
||||
|
||||
async def test_readinto_overlapped() -> None:
|
||||
data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024
|
||||
buffer = bytearray(len(data))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tdir:
|
||||
tfile = os.path.join(tdir, "numbers.txt")
|
||||
with open( # noqa: ASYNC230 # This is a test, synchronous is ok
|
||||
tfile,
|
||||
"wb",
|
||||
) as fp:
|
||||
fp.write(data)
|
||||
fp.flush()
|
||||
|
||||
rawname = tfile.encode("utf-16le") + b"\0\0"
|
||||
rawname_buf = ffi.from_buffer(rawname)
|
||||
handle = kernel32.CreateFileW(
|
||||
ffi.cast("LPCWSTR", rawname_buf),
|
||||
FileFlags.GENERIC_READ,
|
||||
FileFlags.FILE_SHARE_READ,
|
||||
ffi.NULL, # no security attributes
|
||||
FileFlags.OPEN_EXISTING,
|
||||
FileFlags.FILE_FLAG_OVERLAPPED,
|
||||
ffi.NULL, # no template file
|
||||
)
|
||||
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
|
||||
raise_winerror()
|
||||
|
||||
try:
|
||||
with memoryview(buffer) as buffer_view:
|
||||
|
||||
async def read_region(start: int, end: int) -> None:
|
||||
await _core.readinto_overlapped(
|
||||
handle,
|
||||
buffer_view[start:end],
|
||||
start,
|
||||
)
|
||||
|
||||
_core.register_with_iocp(handle)
|
||||
async with _core.open_nursery() as nursery:
|
||||
for start in range(0, 4096, 512):
|
||||
nursery.start_soon(read_region, start, start + 512)
|
||||
|
||||
assert buffer == data
|
||||
|
||||
with pytest.raises((BufferError, TypeError)):
|
||||
await _core.readinto_overlapped(handle, b"immutable")
|
||||
finally:
|
||||
kernel32.CloseHandle(handle)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pipe_with_overlapped_read() -> Generator[tuple[BufferedWriter, int], None, None]:
|
||||
import msvcrt
|
||||
from asyncio.windows_utils import pipe
|
||||
|
||||
read_handle, write_handle = pipe(overlapped=(True, False))
|
||||
try:
|
||||
write_fd = msvcrt.open_osfhandle(write_handle, 0)
|
||||
yield os.fdopen(write_fd, "wb", closefd=False), read_handle
|
||||
finally:
|
||||
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", read_handle)))
|
||||
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", write_handle)))
|
||||
|
||||
|
||||
@restore_unraisablehook()
|
||||
def test_forgot_to_register_with_iocp() -> None:
|
||||
with pipe_with_overlapped_read() as (write_fp, read_handle):
|
||||
with write_fp:
|
||||
write_fp.write(b"test\n")
|
||||
|
||||
left_run_yet = False
|
||||
|
||||
async def main() -> None:
|
||||
target = bytearray(1)
|
||||
try:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
_core.readinto_overlapped,
|
||||
read_handle,
|
||||
target,
|
||||
name="xyz",
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
# Run loop is exited without unwinding running tasks, so
|
||||
# we don't get here until the main() coroutine is GC'ed
|
||||
assert left_run_yet
|
||||
|
||||
with pytest.raises(_core.TrioInternalError) as exc_info:
|
||||
_core.run(main)
|
||||
left_run_yet = True
|
||||
assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value)
|
||||
assert "forget to call register_with_iocp()?" in str(exc_info.value)
|
||||
|
||||
# Make sure the Nursery.__del__ assertion about dangling children
|
||||
# gets put with the correct test
|
||||
del exc_info
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_too_late_to_cancel() -> None:
|
||||
import time
|
||||
|
||||
with pipe_with_overlapped_read() as (write_fp, read_handle):
|
||||
_core.register_with_iocp(read_handle)
|
||||
target = bytearray(6)
|
||||
async with _core.open_nursery() as nursery:
|
||||
# Start an async read in the background
|
||||
nursery.start_soon(_core.readinto_overlapped, read_handle, target)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Synchronous write to the other end of the pipe
|
||||
with write_fp:
|
||||
write_fp.write(b"test1\ntest2\n")
|
||||
|
||||
# Note: not trio.sleep! We're making sure the OS level
|
||||
# ReadFile completes, before Trio has a chance to execute
|
||||
# another checkpoint and notice it completed.
|
||||
time.sleep(1) # noqa: ASYNC251
|
||||
nursery.cancel_scope.cancel()
|
||||
assert target[:6] == b"test1\n"
|
||||
|
||||
# Do another I/O to make sure we've actually processed the
|
||||
# fallback completion that was posted when CancelIoEx failed.
|
||||
assert await _core.readinto_overlapped(read_handle, target) == 6
|
||||
assert target[:6] == b"test2\n"
|
||||
|
||||
|
||||
def test_lsp_that_hooks_select_gives_good_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from .. import _io_windows
|
||||
from .._windows_cffi import CData, WSAIoctls, _handle
|
||||
|
||||
def patched_get_underlying(
|
||||
sock: int | CData,
|
||||
*,
|
||||
which: int = WSAIoctls.SIO_BASE_HANDLE,
|
||||
) -> CData:
|
||||
if hasattr(sock, "fileno"): # pragma: no branch
|
||||
sock = sock.fileno()
|
||||
if which == WSAIoctls.SIO_BSP_HANDLE_SELECT:
|
||||
return _handle(sock + 1)
|
||||
else:
|
||||
return _handle(sock)
|
||||
|
||||
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ",
|
||||
):
|
||||
_core.run(sleep, 0)
|
||||
|
||||
|
||||
def test_lsp_that_completely_hides_base_socket_gives_good_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns
|
||||
# self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns
|
||||
# self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to
|
||||
# make sure we get an error rather than an infinite loop.
|
||||
|
||||
from .. import _io_windows
|
||||
from .._windows_cffi import CData, WSAIoctls, _handle
|
||||
|
||||
def patched_get_underlying(
|
||||
sock: int | CData,
|
||||
*,
|
||||
which: int = WSAIoctls.SIO_BASE_HANDLE,
|
||||
) -> CData:
|
||||
if hasattr(sock, "fileno"): # pragma: no branch
|
||||
sock = sock.fileno()
|
||||
if which == WSAIoctls.SIO_BASE_HANDLE:
|
||||
raise OSError("nope")
|
||||
else:
|
||||
return _handle(sock)
|
||||
|
||||
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff",
|
||||
):
|
||||
_core.run(sleep, 0)
|
||||
@@ -0,0 +1,117 @@
|
||||
# Utilities for testing
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import os
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import closing, contextmanager
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
# See trio/_tests/conftest.py for the other half of this
|
||||
from trio._tests.pytest_plugin import RUN_SLOW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
|
||||
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
try:
|
||||
s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0)
|
||||
except OSError: # pragma: no cover
|
||||
# Some systems don't even support creating an IPv6 socket, let alone
|
||||
# binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line)
|
||||
# We don't have any of those in our CI, and there's nothing that gets
|
||||
# tested _only_ if can_create_ipv6 = False, so we'll just no-cover this.
|
||||
can_create_ipv6 = False
|
||||
can_bind_ipv6 = False
|
||||
else:
|
||||
can_create_ipv6 = True
|
||||
with s:
|
||||
try:
|
||||
s.bind(("::1", 0))
|
||||
except OSError: # pragma: no cover # since support for 3.7 was removed
|
||||
can_bind_ipv6 = False
|
||||
else:
|
||||
can_bind_ipv6 = True
|
||||
|
||||
creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6")
|
||||
binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6")
|
||||
|
||||
|
||||
def gc_collect_harder() -> None:
|
||||
# In the test suite we sometimes want to call gc.collect() to make sure
|
||||
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
|
||||
# get collected before we continue, so their noise doesn't leak into
|
||||
# unrelated tests.
|
||||
#
|
||||
# On PyPy, coroutine objects (for example) can survive at least 1 round of
|
||||
# garbage collection, because executing their __del__ method to print the
|
||||
# warning can cause them to be resurrected. So we call collect a few times
|
||||
# to make sure.
|
||||
for _ in range(5):
|
||||
gc.collect()
|
||||
|
||||
|
||||
# Some of our tests need to leak coroutines, and thus trigger the
|
||||
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
|
||||
# manager should be used anywhere this happens to hide those messages, because
|
||||
# when expected they're clutter.
|
||||
@contextmanager
|
||||
def ignore_coroutine_never_awaited_warnings() -> Generator[None, None, None]:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Make sure to trigger any coroutine __del__ methods now, before
|
||||
# we leave the context manager.
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
def _noop(*args: object, **kwargs: object) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
@contextmanager
|
||||
def restore_unraisablehook() -> Generator[None, None, None]:
|
||||
sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.unraisablehook = prev
|
||||
|
||||
|
||||
# Used to check sequences that might have some elements out of order.
|
||||
# Example usage:
|
||||
# The sequences [1, 2.1, 2.2, 3] and [1, 2.2, 2.1, 3] are both
|
||||
# matched by the template [1, {2.1, 2.2}, 3]
|
||||
def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None:
|
||||
i = 0
|
||||
for pattern in template:
|
||||
if not isinstance(pattern, set):
|
||||
pattern = {pattern}
|
||||
got = set(seq[i : i + len(pattern)])
|
||||
assert got == pattern
|
||||
i += len(got)
|
||||
|
||||
|
||||
# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350
|
||||
skip_if_fbsd_pipes_broken = pytest.mark.skipif(
|
||||
sys.platform != "win32" # prevent mypy from complaining about missing uname
|
||||
and hasattr(os, "uname")
|
||||
and os.uname().sysname == "FreeBSD"
|
||||
and os.uname().release[:4] < "12.2",
|
||||
reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350",
|
||||
)
|
||||
|
||||
|
||||
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
|
||||
with closing(asyncio.new_event_loop()) as loop:
|
||||
return loop.create_future()
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Test variadic generic typing for Nursery.start[_soon]()."""
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus
|
||||
|
||||
|
||||
async def task_0() -> None: ...
|
||||
|
||||
|
||||
async def task_1a(value: int) -> None: ...
|
||||
|
||||
|
||||
async def task_1b(value: str) -> None: ...
|
||||
|
||||
|
||||
async def task_2a(a: int, b: str) -> None: ...
|
||||
|
||||
|
||||
async def task_2b(a: str, b: int) -> None: ...
|
||||
|
||||
|
||||
async def task_2c(a: str, b: int, optional: bool = False) -> None: ...
|
||||
|
||||
|
||||
async def task_requires_kw(a: int, *, b: bool) -> None: ...
|
||||
|
||||
|
||||
async def task_startable_1(
|
||||
a: str,
|
||||
*,
|
||||
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
async def task_startable_2(
|
||||
a: str,
|
||||
b: float,
|
||||
*,
|
||||
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
async def task_requires_start(*, task_status: TaskStatus[str]) -> None:
|
||||
"""Check a function requiring start() to be used."""
|
||||
|
||||
|
||||
async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None:
|
||||
"""Check a function which doesn't use the *-syntax works."""
|
||||
|
||||
|
||||
def check_start_soon(nursery: Nursery) -> None:
|
||||
"""start_soon() functionality."""
|
||||
nursery.start_soon(task_0)
|
||||
nursery.start_soon(task_1a) # type: ignore
|
||||
nursery.start_soon(task_2b) # type: ignore
|
||||
|
||||
nursery.start_soon(task_0, 45) # type: ignore
|
||||
nursery.start_soon(task_1a, 32)
|
||||
nursery.start_soon(task_1b, 32) # type: ignore
|
||||
nursery.start_soon(task_1a, "abc") # type: ignore
|
||||
nursery.start_soon(task_1b, "abc")
|
||||
|
||||
nursery.start_soon(task_2b, "abc") # type: ignore
|
||||
nursery.start_soon(task_2a, 38, "46")
|
||||
nursery.start_soon(task_2c, "abc", 12, True)
|
||||
|
||||
nursery.start_soon(task_2c, "abc", 12)
|
||||
task_2c_cast: Callable[[str, int], Awaitable[object]] = (
|
||||
task_2c # The assignment makes it work.
|
||||
)
|
||||
nursery.start_soon(task_2c_cast, "abc", 12)
|
||||
|
||||
nursery.start_soon(task_requires_kw, 12, True) # type: ignore
|
||||
# Tasks following the start() API can be made to work.
|
||||
nursery.start_soon(task_startable_1, "cdf")
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, overload
|
||||
|
||||
import trio
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
async def sleep_sort(values: Sequence[float]) -> list[float]:
|
||||
return [1]
|
||||
|
||||
|
||||
async def has_optional(arg: int | None = None) -> int:
|
||||
return 5
|
||||
|
||||
|
||||
@overload
|
||||
async def foo_overloaded(arg: int) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def foo_overloaded(arg: str) -> int: ...
|
||||
|
||||
|
||||
async def foo_overloaded(arg: int | str) -> int | str:
|
||||
if isinstance(arg, str):
|
||||
return 5
|
||||
return "hello"
|
||||
|
||||
|
||||
v = trio.run(
|
||||
sleep_sort,
|
||||
(1, 3, 5, 2, 4),
|
||||
clock=trio.testing.MockClock(autojump_threshold=0),
|
||||
)
|
||||
assert_type(v, "list[float]")
|
||||
trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type]
|
||||
trio.run(sleep_sort) # type: ignore[arg-type]
|
||||
|
||||
r = trio.run(has_optional)
|
||||
assert_type(r, int)
|
||||
r = trio.run(has_optional, 5)
|
||||
trio.run(has_optional, 7, 8) # type: ignore[arg-type]
|
||||
trio.run(has_optional, "hello") # type: ignore[arg-type]
|
||||
|
||||
|
||||
assert_type(trio.run(foo_overloaded, 5), str)
|
||||
assert_type(trio.run(foo_overloaded, ""), int)
|
||||
@@ -0,0 +1,295 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import sys
|
||||
import traceback
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
import outcome
|
||||
|
||||
RetT = TypeVar("RetT")
|
||||
|
||||
|
||||
def _to_os_thread_name(name: str) -> bytes:
|
||||
# ctypes handles the trailing \00
|
||||
return name.encode("ascii", errors="replace")[:15]
|
||||
|
||||
|
||||
# used to construct the method used to set os thread name, or None, depending on platform.
|
||||
# called once on import
|
||||
def get_os_thread_name_func() -> Callable[[int | None, str], None] | None:
|
||||
def namefunc(
|
||||
setname: Callable[[int, bytes], int],
|
||||
ident: int | None,
|
||||
name: str,
|
||||
) -> None:
|
||||
# Thread.ident is None "if it has not been started". Unclear if that can happen
|
||||
# with current usage.
|
||||
if ident is not None: # pragma: no cover
|
||||
setname(ident, _to_os_thread_name(name))
|
||||
|
||||
# namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it
|
||||
# so the caller don't need to care about platform.
|
||||
def darwin_namefunc(
|
||||
setname: Callable[[bytes], int],
|
||||
ident: int | None,
|
||||
name: str,
|
||||
) -> None:
|
||||
# I don't know if Mac can rename threads that hasn't been started, but default
|
||||
# to no to be on the safe side.
|
||||
if ident is not None: # pragma: no cover
|
||||
setname(_to_os_thread_name(name))
|
||||
|
||||
# find the pthread library
|
||||
# this will fail on windows and musl
|
||||
libpthread_path = ctypes.util.find_library("pthread")
|
||||
if not libpthread_path:
|
||||
# musl includes pthread functions directly in libc.so
|
||||
# (but note that find_library("c") does not work on musl,
|
||||
# see: https://github.com/python/cpython/issues/65821)
|
||||
# so try that library instead
|
||||
# if it doesn't exist, CDLL() will fail below
|
||||
libpthread_path = "libc.so"
|
||||
|
||||
# Sometimes windows can find the path, but gives a permission error when
|
||||
# accessing it. Catching a wider exception in case of more esoteric errors.
|
||||
# https://github.com/python-trio/trio/issues/2688
|
||||
try:
|
||||
libpthread = ctypes.CDLL(libpthread_path)
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
|
||||
# get the setname method from it
|
||||
# afaik this should never fail
|
||||
pthread_setname_np = getattr(libpthread, "pthread_setname_np", None)
|
||||
if pthread_setname_np is None: # pragma: no cover
|
||||
return None
|
||||
|
||||
# specify function prototype
|
||||
pthread_setname_np.restype = ctypes.c_int
|
||||
|
||||
# on mac OSX pthread_setname_np does not take a thread id,
|
||||
# it only lets threads name themselves, which is not a problem for us.
|
||||
# Just need to make sure to call it correctly
|
||||
if sys.platform == "darwin":
|
||||
pthread_setname_np.argtypes = [ctypes.c_char_p]
|
||||
return partial(darwin_namefunc, pthread_setname_np)
|
||||
|
||||
# otherwise assume linux parameter conventions. Should also work on *BSD
|
||||
pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
||||
return partial(namefunc, pthread_setname_np)
|
||||
|
||||
|
||||
# construct os thread name method
|
||||
set_os_thread_name = get_os_thread_name_func()
|
||||
|
||||
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
|
||||
# spawns as many threads as needed to handle all the requests its given. Its
|
||||
# only purpose is to cache worker threads so that they don't have to be
|
||||
# started from scratch every time we want to delegate some work to a thread.
|
||||
# It's expected that some higher-level code will track how many threads are in
|
||||
# use to avoid overwhelming the system (e.g. the limiter= argument to
|
||||
# trio.to_thread.run_sync).
|
||||
#
|
||||
# To maximize sharing, there's only one thread cache per process, even if you
|
||||
# have multiple calls to trio.run.
|
||||
#
|
||||
# Guarantees:
|
||||
#
|
||||
# It's safe to call start_thread_soon simultaneously from
|
||||
# multiple threads.
|
||||
#
|
||||
# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly
|
||||
# over all threads. Instead we try to let some threads do most of the work
|
||||
# while others sit idle as much as possible. Compared to FIFO, this has better
|
||||
# memory cache behavior, and it makes it easier to detect when we have too
|
||||
# many threads, so idle ones can exit.
|
||||
#
|
||||
# This code assumes that 'dict' has the following properties:
|
||||
#
|
||||
# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with
|
||||
# respect to each other. This is guaranteed by the GIL.
|
||||
#
|
||||
# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem
|
||||
# give you a LIFO queue). This relies on dicts being insertion-ordered, like
|
||||
# they are in py36+.
|
||||
|
||||
# How long a thread will idle waiting for new work before gives up and exits.
|
||||
# This value is pretty arbitrary; I don't think it matters too much.
|
||||
IDLE_TIMEOUT = 10 # seconds
|
||||
|
||||
name_counter = count()
|
||||
|
||||
|
||||
class WorkerThread(Generic[RetT]):
|
||||
def __init__(self, thread_cache: ThreadCache) -> None:
|
||||
self._job: (
|
||||
tuple[
|
||||
Callable[[], RetT],
|
||||
Callable[[outcome.Outcome[RetT]], object],
|
||||
str | None,
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
self._thread_cache = thread_cache
|
||||
# This Lock is used in an unconventional way.
|
||||
#
|
||||
# "Unlocked" means we have a pending job that's been assigned to us;
|
||||
# "locked" means that we don't.
|
||||
#
|
||||
# Initially we have no job, so it starts out in locked state.
|
||||
self._worker_lock = Lock()
|
||||
self._worker_lock.acquire()
|
||||
self._default_name = f"Trio thread {next(name_counter)}"
|
||||
|
||||
self._thread = Thread(target=self._work, name=self._default_name, daemon=True)
|
||||
|
||||
if set_os_thread_name:
|
||||
set_os_thread_name(self._thread.ident, self._default_name)
|
||||
self._thread.start()
|
||||
|
||||
def _handle_job(self) -> None:
|
||||
# Handle job in a separate method to ensure user-created
|
||||
# objects are cleaned up in a consistent manner.
|
||||
assert self._job is not None
|
||||
fn, deliver, name = self._job
|
||||
self._job = None
|
||||
|
||||
# set name
|
||||
if name is not None:
|
||||
self._thread.name = name
|
||||
if set_os_thread_name:
|
||||
set_os_thread_name(self._thread.ident, name)
|
||||
result = outcome.capture(fn)
|
||||
|
||||
# reset name if it was changed
|
||||
if name is not None:
|
||||
self._thread.name = self._default_name
|
||||
if set_os_thread_name:
|
||||
set_os_thread_name(self._thread.ident, self._default_name)
|
||||
|
||||
# Tell the cache that we're available to be assigned a new
|
||||
# job. We do this *before* calling 'deliver', so that if
|
||||
# 'deliver' triggers a new job, it can be assigned to us
|
||||
# instead of spawning a new thread.
|
||||
self._thread_cache._idle_workers[self] = None
|
||||
try:
|
||||
deliver(result)
|
||||
except BaseException as e:
|
||||
print("Exception while delivering result of thread", file=sys.stderr)
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
|
||||
def _work(self) -> None:
|
||||
while True:
|
||||
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
|
||||
# We got a job
|
||||
self._handle_job()
|
||||
else:
|
||||
# Timeout acquiring lock, so we can probably exit. But,
|
||||
# there's a race condition: we might be assigned a job *just*
|
||||
# as we're about to exit. So we have to check.
|
||||
try:
|
||||
del self._thread_cache._idle_workers[self]
|
||||
except KeyError:
|
||||
# Someone else removed us from the idle worker queue, so
|
||||
# they must be in the process of assigning us a job - loop
|
||||
# around and wait for it.
|
||||
continue
|
||||
else:
|
||||
# We successfully removed ourselves from the idle
|
||||
# worker queue, so no more jobs are incoming; it's safe to
|
||||
# exit.
|
||||
return
|
||||
|
||||
|
||||
class ThreadCache:
|
||||
def __init__(self) -> None:
|
||||
self._idle_workers: dict[WorkerThread[Any], None] = {}
|
||||
|
||||
def start_thread_soon(
|
||||
self,
|
||||
fn: Callable[[], RetT],
|
||||
deliver: Callable[[outcome.Outcome[RetT]], object],
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
worker: WorkerThread[RetT]
|
||||
try:
|
||||
worker, _ = self._idle_workers.popitem()
|
||||
except KeyError:
|
||||
worker = WorkerThread(self)
|
||||
worker._job = (fn, deliver, name)
|
||||
worker._worker_lock.release()
|
||||
|
||||
|
||||
THREAD_CACHE = ThreadCache()
|
||||
|
||||
|
||||
def start_thread_soon(
|
||||
fn: Callable[[], RetT],
|
||||
deliver: Callable[[outcome.Outcome[RetT]], object],
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
|
||||
|
||||
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
|
||||
result back to whoever is interested.
|
||||
|
||||
This is a low-level, no-frills interface, very similar to using
|
||||
`threading.Thread` to spawn a thread directly. The main difference is
|
||||
that this function tries to reuse threads when possible, so it can be
|
||||
a bit faster than `threading.Thread`.
|
||||
|
||||
Worker threads have the `~threading.Thread.daemon` flag set, which means
|
||||
that if your main thread exits, worker threads will automatically be
|
||||
killed. If you want to make sure that your ``fn`` runs to completion, then
|
||||
you should make sure that the main thread remains alive until ``deliver``
|
||||
is called.
|
||||
|
||||
It is safe to call this function simultaneously from multiple threads.
|
||||
|
||||
Args:
|
||||
|
||||
fn (sync function): Performs arbitrary blocking work.
|
||||
|
||||
deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and
|
||||
delivers it. *Must not block.*
|
||||
|
||||
Because worker threads are cached and reused for multiple calls, neither
|
||||
function should mutate thread-level state, like `threading.local` objects
|
||||
– or if they do, they should be careful to revert their changes before
|
||||
returning.
|
||||
|
||||
Note:
|
||||
|
||||
The split between ``fn`` and ``deliver`` serves two purposes. First,
|
||||
it's convenient, since most callers need something like this anyway.
|
||||
|
||||
Second, it avoids a small race condition that could cause too many
|
||||
threads to be spawned. Consider a program that wants to run several
|
||||
jobs sequentially on a thread, so the main thread submits a job, waits
|
||||
for it to finish, submits another job, etc. In theory, this program
|
||||
should only need one worker thread. But what could happen is:
|
||||
|
||||
1. Worker thread: First job finishes, and calls ``deliver``.
|
||||
|
||||
2. Main thread: receives notification that the job finished, and calls
|
||||
``start_thread_soon``.
|
||||
|
||||
3. Main thread: sees that no worker threads are marked idle, so spawns
|
||||
a second worker thread.
|
||||
|
||||
4. Original worker thread: marks itself as idle.
|
||||
|
||||
To avoid this, threads mark themselves as idle *before* calling
|
||||
``deliver``.
|
||||
|
||||
Is this potential extra thread a major problem? Maybe not, but it's
|
||||
easy enough to avoid, and we figure that if the user is trying to
|
||||
limit how many threads they're using then it's polite to respect that.
|
||||
|
||||
"""
|
||||
THREAD_CACHE.start_thread_soon(fn, deliver, name)
|
||||
@@ -0,0 +1,287 @@
|
||||
"""These are the only functions that ever yield back to the task runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Callable, NoReturn
|
||||
|
||||
import attrs
|
||||
import outcome
|
||||
|
||||
from . import _run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ._run import Task
|
||||
|
||||
|
||||
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
|
||||
# function, but you can inside a generator, and if you decorate your generator
|
||||
# with @types.coroutine, then it's even awaitable. However, it's still not a
|
||||
# real async function: in particular, it isn't recognized by
|
||||
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
|
||||
# tracking machinery. Since our traps are public APIs, we make them real async
|
||||
# functions, and then this helper takes care of the actual yield:
|
||||
@types.coroutine
|
||||
def _async_yield(obj: Any) -> Any: # type: ignore[misc]
|
||||
return (yield obj)
|
||||
|
||||
|
||||
# This class object is used as a singleton.
|
||||
# Not exported in the trio._core namespace, but imported directly by _run.
|
||||
class CancelShieldedCheckpoint:
|
||||
pass
|
||||
|
||||
|
||||
async def cancel_shielded_checkpoint() -> None:
|
||||
"""Introduce a schedule point, but not a cancel point.
|
||||
|
||||
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
|
||||
checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can
|
||||
make a full checkpoint.
|
||||
|
||||
Equivalent to (but potentially more efficient than)::
|
||||
|
||||
with trio.CancelScope(shield=True):
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
"""
|
||||
(await _async_yield(CancelShieldedCheckpoint)).unwrap()
|
||||
|
||||
|
||||
# Return values for abort functions
|
||||
class Abort(enum.Enum):
|
||||
""":class:`enum.Enum` used as the return value from abort functions.
|
||||
|
||||
See :func:`wait_task_rescheduled` for details.
|
||||
|
||||
.. data:: SUCCEEDED
|
||||
FAILED
|
||||
|
||||
"""
|
||||
|
||||
SUCCEEDED = 1
|
||||
FAILED = 2
|
||||
|
||||
|
||||
# Not exported in the trio._core namespace, but imported directly by _run.
|
||||
@attrs.frozen(slots=False)
|
||||
class WaitTaskRescheduled:
|
||||
abort_func: Callable[[RaiseCancelT], Abort]
|
||||
|
||||
|
||||
RaiseCancelT: TypeAlias = Callable[[], NoReturn]
|
||||
|
||||
|
||||
# Should always return the type a Task "expects", unless you willfully reschedule it
|
||||
# with a bad value.
|
||||
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
|
||||
"""Put the current task to sleep, with cancellation support.
|
||||
|
||||
This is the lowest-level API for blocking in Trio. Every time a
|
||||
:class:`~trio.lowlevel.Task` blocks, it does so by calling this function
|
||||
(usually indirectly via some higher-level API).
|
||||
|
||||
This is a tricky interface with no guard rails. If you can use
|
||||
:class:`ParkingLot` or the built-in I/O wait functions instead, then you
|
||||
should.
|
||||
|
||||
Generally the way it works is that before calling this function, you make
|
||||
arrangements for "someone" to call :func:`reschedule` on the current task
|
||||
at some later point.
|
||||
|
||||
Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an
|
||||
"abort callback".
|
||||
|
||||
(Terminology: in Trio, "aborting" is the process of attempting to
|
||||
interrupt a blocked task to deliver a cancellation.)
|
||||
|
||||
There are two possibilities for what happens next:
|
||||
|
||||
1. "Someone" calls :func:`reschedule` on the current task, and
|
||||
:func:`wait_task_rescheduled` returns or raises whatever value or error
|
||||
was passed to :func:`reschedule`.
|
||||
|
||||
2. The call's context transitions to a cancelled state (e.g. due to a
|
||||
timeout expiring). When this happens, the ``abort_func`` is called. Its
|
||||
interface looks like::
|
||||
|
||||
def abort_func(raise_cancel):
|
||||
...
|
||||
return trio.lowlevel.Abort.SUCCEEDED # or FAILED
|
||||
|
||||
It should attempt to clean up any state associated with this call, and
|
||||
in particular, arrange that :func:`reschedule` will *not* be called
|
||||
later. If (and only if!) it is successful, then it should return
|
||||
:data:`Abort.SUCCEEDED`, in which case the task will automatically be
|
||||
rescheduled with an appropriate :exc:`~trio.Cancelled` error.
|
||||
|
||||
Otherwise, it should return :data:`Abort.FAILED`. This means that the
|
||||
task can't be cancelled at this time, and still has to make sure that
|
||||
"someone" eventually calls :func:`reschedule`.
|
||||
|
||||
At that point there are again two possibilities. You can simply ignore
|
||||
the cancellation altogether: wait for the operation to complete and
|
||||
then reschedule and continue as normal. (For example, this is what
|
||||
:func:`trio.to_thread.run_sync` does if cancellation is disabled.)
|
||||
The other possibility is that the ``abort_func`` does succeed in
|
||||
cancelling the operation, but for some reason isn't able to report that
|
||||
right away. (Example: on Windows, it's possible to request that an
|
||||
async ("overlapped") I/O operation be cancelled, but this request is
|
||||
*also* asynchronous – you don't find out until later whether the
|
||||
operation was actually cancelled or not.) To report a delayed
|
||||
cancellation, then you should reschedule the task yourself, and call
|
||||
the ``raise_cancel`` callback passed to ``abort_func`` to raise a
|
||||
:exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception
|
||||
into this task. Either of the approaches sketched below can work::
|
||||
|
||||
# Option 1:
|
||||
# Catch the exception from raise_cancel and inject it into the task.
|
||||
# (This is what Trio does automatically for you if you return
|
||||
# Abort.SUCCEEDED.)
|
||||
trio.lowlevel.reschedule(task, outcome.capture(raise_cancel))
|
||||
|
||||
# Option 2:
|
||||
# wait to be woken by "someone", and then decide whether to raise
|
||||
# the error from inside the task.
|
||||
outer_raise_cancel = None
|
||||
def abort(inner_raise_cancel):
|
||||
nonlocal outer_raise_cancel
|
||||
outer_raise_cancel = inner_raise_cancel
|
||||
TRY_TO_CANCEL_OPERATION()
|
||||
return trio.lowlevel.Abort.FAILED
|
||||
await wait_task_rescheduled(abort)
|
||||
if OPERATION_WAS_SUCCESSFULLY_CANCELLED:
|
||||
# raises the error
|
||||
outer_raise_cancel()
|
||||
|
||||
In any case it's guaranteed that we only call the ``abort_func`` at most
|
||||
once per call to :func:`wait_task_rescheduled`.
|
||||
|
||||
Sometimes, it's useful to be able to share some mutable sleep-related data
|
||||
between the sleeping task, the abort function, and the waking task. You
|
||||
can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to
|
||||
store this data, and Trio won't touch it, except to make sure that it gets
|
||||
cleared when the task is rescheduled.
|
||||
|
||||
.. warning::
|
||||
|
||||
If your ``abort_func`` raises an error, or returns any value other than
|
||||
:data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash
|
||||
violently. Be careful! Similarly, it is entirely possible to deadlock a
|
||||
Trio program by failing to reschedule a blocked task, or cause havoc by
|
||||
calling :func:`reschedule` too many times. Remember what we said up
|
||||
above about how you should use a higher-level API if at all possible?
|
||||
|
||||
"""
|
||||
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
|
||||
|
||||
|
||||
# Not exported in the trio._core namespace, but imported directly by _run.
|
||||
@attrs.frozen(slots=False)
|
||||
class PermanentlyDetachCoroutineObject:
|
||||
final_outcome: outcome.Outcome[Any]
|
||||
|
||||
|
||||
async def permanently_detach_coroutine_object(
|
||||
final_outcome: outcome.Outcome[Any],
|
||||
) -> Any:
|
||||
"""Permanently detach the current task from the Trio scheduler.
|
||||
|
||||
Normally, a Trio task doesn't exit until its coroutine object exits. When
|
||||
you call this function, Trio acts like the coroutine object just exited
|
||||
and the task terminates with the given outcome. This is useful if you want
|
||||
to permanently switch the coroutine object over to a different coroutine
|
||||
runner.
|
||||
|
||||
When the calling coroutine enters this function it's running under Trio,
|
||||
and when the function returns it's running under the foreign coroutine
|
||||
runner.
|
||||
|
||||
You should make sure that the coroutine object has released any
|
||||
Trio-specific resources it has acquired (e.g. nurseries).
|
||||
|
||||
Args:
|
||||
final_outcome (outcome.Outcome): Trio acts as if the current task exited
|
||||
with the given return value or exception.
|
||||
|
||||
Returns or raises whatever value or exception the new coroutine runner
|
||||
uses to resume the coroutine.
|
||||
|
||||
"""
|
||||
if _run.current_task().child_nurseries:
|
||||
raise RuntimeError(
|
||||
"can't permanently detach a coroutine object with open nurseries",
|
||||
)
|
||||
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
|
||||
|
||||
|
||||
async def temporarily_detach_coroutine_object(
|
||||
abort_func: Callable[[RaiseCancelT], Abort],
|
||||
) -> Any:
|
||||
"""Temporarily detach the current coroutine object from the Trio
|
||||
scheduler.
|
||||
|
||||
When the calling coroutine enters this function it's running under Trio,
|
||||
and when the function returns it's running under the foreign coroutine
|
||||
runner.
|
||||
|
||||
The Trio :class:`Task` will continue to exist, but will be suspended until
|
||||
you use :func:`reattach_detached_coroutine_object` to resume it. In the
|
||||
mean time, you can use another coroutine runner to schedule the coroutine
|
||||
object. In fact, you have to – the function doesn't return until the
|
||||
coroutine is advanced from outside.
|
||||
|
||||
Note that you'll need to save the current :class:`Task` object to later
|
||||
resume; you can retrieve it with :func:`current_task`. You can also use
|
||||
this :class:`Task` object to retrieve the coroutine object – see
|
||||
:data:`Task.coro`.
|
||||
|
||||
Args:
|
||||
abort_func: Same as for :func:`wait_task_rescheduled`, except that it
|
||||
must return :data:`Abort.FAILED`. (If it returned
|
||||
:data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the
|
||||
detached task directly without going through
|
||||
:func:`reattach_detached_coroutine_object`, which would be bad.)
|
||||
Your ``abort_func`` should still arrange for whatever the coroutine
|
||||
object is doing to be cancelled, and then reattach to Trio and call
|
||||
the ``raise_cancel`` callback, if possible.
|
||||
|
||||
Returns or raises whatever value or exception the new coroutine runner
|
||||
uses to resume the coroutine.
|
||||
|
||||
"""
|
||||
return await _async_yield(WaitTaskRescheduled(abort_func))
|
||||
|
||||
|
||||
async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None:
|
||||
"""Reattach a coroutine object that was detached using
|
||||
:func:`temporarily_detach_coroutine_object`.
|
||||
|
||||
When the calling coroutine enters this function it's running under the
|
||||
foreign coroutine runner, and when the function returns it's running under
|
||||
Trio.
|
||||
|
||||
This must be called from inside the coroutine being resumed, and yields
|
||||
whatever value you pass in. (Presumably you'll pass a value that will
|
||||
cause the current coroutine runner to stop scheduling this task.) Then the
|
||||
coroutine is resumed by the Trio scheduler at the next opportunity.
|
||||
|
||||
Args:
|
||||
task (Task): The Trio task object that the current coroutine was
|
||||
detached from.
|
||||
yield_value (object): The object to yield to the current coroutine
|
||||
runner.
|
||||
|
||||
"""
|
||||
# This is a kind of crude check – in particular, it can fail if the
|
||||
# passed-in task is where the coroutine *runner* is running. But this is
|
||||
# an experts-only interface, and there's no easy way to do a more accurate
|
||||
# check, so I guess that's OK.
|
||||
if not task.coro.cr_running:
|
||||
raise RuntimeError("given task does not match calling coroutine")
|
||||
_run.reschedule(task, outcome.Value("reattaching"))
|
||||
value = await _async_yield(yield_value)
|
||||
assert value == outcome.Value("reattaching")
|
||||
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
from .. import _core
|
||||
from .._deprecate import deprecated
|
||||
from .._util import final
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class UnboundedQueueStatistics:
|
||||
"""An object containing debugging information.
|
||||
|
||||
Currently, the following fields are defined:
|
||||
|
||||
* ``qsize``: The number of items currently in the queue.
|
||||
* ``tasks_waiting``: The number of tasks blocked on this queue's
|
||||
:meth:`get_batch` method.
|
||||
|
||||
"""
|
||||
|
||||
qsize: int
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@final
|
||||
class UnboundedQueue(Generic[T]):
|
||||
"""An unbounded queue suitable for certain unusual forms of inter-task
|
||||
communication.
|
||||
|
||||
This class is designed for use as a queue in cases where the producer for
|
||||
some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait`
|
||||
has to always succeed. In order to prevent the queue backlog from actually
|
||||
growing without bound, the consumer API is modified to dequeue items in
|
||||
"batches". If a consumer task processes each batch without yielding, then
|
||||
this helps achieve (but does not guarantee) an effective bound on the
|
||||
queue's memory use, at the cost of potentially increasing system latencies
|
||||
in general. You should generally prefer to use a memory channel
|
||||
instead if you can.
|
||||
|
||||
Currently each batch completely empties the queue, but `this may change in
|
||||
the future <https://github.com/python-trio/trio/issues/51>`__.
|
||||
|
||||
A :class:`UnboundedQueue` object can be used as an asynchronous iterator,
|
||||
where each iteration returns a new batch of items. I.e., these two loops
|
||||
are equivalent::
|
||||
|
||||
async for batch in queue:
|
||||
...
|
||||
|
||||
while True:
|
||||
obj = await queue.get_batch()
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
@deprecated(
|
||||
"0.9.0",
|
||||
issue=497,
|
||||
thing="trio.lowlevel.UnboundedQueue",
|
||||
instead="trio.open_memory_channel(math.inf)",
|
||||
use_triodeprecationwarning=True,
|
||||
)
|
||||
def __init__(self) -> None:
|
||||
self._lot = _core.ParkingLot()
|
||||
self._data: list[T] = []
|
||||
# used to allow handoff from put to the first task in the lot
|
||||
self._can_get = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<UnboundedQueue holding {len(self._data)} items>"
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Returns the number of items currently in the queue."""
|
||||
return len(self._data)
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Returns True if the queue is empty, False otherwise.
|
||||
|
||||
There is some subtlety to interpreting this method's return value: see
|
||||
`issue #63 <https://github.com/python-trio/trio/issues/63>`__.
|
||||
|
||||
"""
|
||||
return not self._data
|
||||
|
||||
@_core.enable_ki_protection
|
||||
def put_nowait(self, obj: T) -> None:
|
||||
"""Put an object into the queue, without blocking.
|
||||
|
||||
This always succeeds, because the queue is unbounded. We don't provide
|
||||
a blocking ``put`` method, because it would never need to block.
|
||||
|
||||
Args:
|
||||
obj (object): The object to enqueue.
|
||||
|
||||
"""
|
||||
if not self._data:
|
||||
assert not self._can_get
|
||||
if self._lot:
|
||||
self._lot.unpark(count=1)
|
||||
else:
|
||||
self._can_get = True
|
||||
self._data.append(obj)
|
||||
|
||||
def _get_batch_protected(self) -> list[T]:
|
||||
data = self._data.copy()
|
||||
self._data.clear()
|
||||
self._can_get = False
|
||||
return data
|
||||
|
||||
def get_batch_nowait(self) -> list[T]:
|
||||
"""Attempt to get the next batch from the queue, without blocking.
|
||||
|
||||
Returns:
|
||||
list: A list of dequeued items, in order. On a successful call this
|
||||
list is always non-empty; if it would be empty we raise
|
||||
:exc:`~trio.WouldBlock` instead.
|
||||
|
||||
Raises:
|
||||
~trio.WouldBlock: if the queue is empty.
|
||||
|
||||
"""
|
||||
if not self._can_get:
|
||||
raise _core.WouldBlock
|
||||
return self._get_batch_protected()
|
||||
|
||||
async def get_batch(self) -> list[T]:
|
||||
"""Get the next batch from the queue, blocking as necessary.
|
||||
|
||||
Returns:
|
||||
list: A list of dequeued items, in order. This list is always
|
||||
non-empty.
|
||||
|
||||
"""
|
||||
await _core.checkpoint_if_cancelled()
|
||||
if not self._can_get:
|
||||
await self._lot.park()
|
||||
return self._get_batch_protected()
|
||||
else:
|
||||
try:
|
||||
return self._get_batch_protected()
|
||||
finally:
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
def statistics(self) -> UnboundedQueueStatistics:
|
||||
"""Return an :class:`UnboundedQueueStatistics` object containing debugging information."""
|
||||
return UnboundedQueueStatistics(
|
||||
qsize=len(self._data),
|
||||
tasks_waiting=self._lot.statistics().tasks_waiting,
|
||||
)
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> list[T]:
|
||||
return await self.get_batch()
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import signal
|
||||
import socket
|
||||
import warnings
|
||||
|
||||
from .. import _core
|
||||
from .._util import is_main_thread
|
||||
|
||||
|
||||
class WakeupSocketpair:
|
||||
def __init__(self) -> None:
|
||||
# explicitly typed to please `pyright --verifytypes` without `--ignoreexternal`
|
||||
self.wakeup_sock: socket.socket
|
||||
self.write_sock: socket.socket
|
||||
|
||||
self.wakeup_sock, self.write_sock = socket.socketpair()
|
||||
self.wakeup_sock.setblocking(False)
|
||||
self.write_sock.setblocking(False)
|
||||
# This somewhat reduces the amount of memory wasted queueing up data
|
||||
# for wakeups. With these settings, maximum number of 1-byte sends
|
||||
# before getting BlockingIOError:
|
||||
# Linux 4.8: 6
|
||||
# macOS (darwin 15.5): 1
|
||||
# Windows 10: 525347
|
||||
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
|
||||
# blocking, even on non-blocking sockets, so don't do that.)
|
||||
self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
|
||||
self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
|
||||
# On Windows this is a TCP socket so this might matter. On other
|
||||
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
|
||||
with contextlib.suppress(OSError):
|
||||
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
self.old_wakeup_fd: int | None = None
|
||||
|
||||
def wakeup_thread_and_signal_safe(self) -> None:
|
||||
with contextlib.suppress(BlockingIOError):
|
||||
self.write_sock.send(b"\x00")
|
||||
|
||||
async def wait_woken(self) -> None:
|
||||
await _core.wait_readable(self.wakeup_sock)
|
||||
self.drain()
|
||||
|
||||
def drain(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
self.wakeup_sock.recv(2**16)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
|
||||
def wakeup_on_signals(self) -> None:
|
||||
assert self.old_wakeup_fd is None
|
||||
if not is_main_thread():
|
||||
return
|
||||
fd = self.write_sock.fileno()
|
||||
self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False)
|
||||
if self.old_wakeup_fd != -1:
|
||||
warnings.warn(
|
||||
RuntimeWarning(
|
||||
"It looks like Trio's signal handling code might have "
|
||||
"collided with another library you're using. If you're "
|
||||
"running Trio in guest mode, then this might mean you "
|
||||
"should set host_uses_signal_set_wakeup_fd=True. "
|
||||
"Otherwise, file a bug on Trio and we'll help you figure "
|
||||
"out what's going on.",
|
||||
),
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self.wakeup_sock.close()
|
||||
self.write_sock.close()
|
||||
if self.old_wakeup_fd is not None:
|
||||
signal.set_wakeup_fd(self.old_wakeup_fd)
|
||||
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import re
|
||||
from typing import TYPE_CHECKING, NewType, NoReturn, Protocol, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import cffi
|
||||
|
||||
################################################################
|
||||
# Functions and types
|
||||
################################################################
|
||||
|
||||
LIB = """
|
||||
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx
|
||||
typedef int BOOL;
|
||||
typedef unsigned char BYTE;
|
||||
typedef BYTE BOOLEAN;
|
||||
typedef void* PVOID;
|
||||
typedef PVOID HANDLE;
|
||||
typedef unsigned long DWORD;
|
||||
typedef unsigned long ULONG;
|
||||
typedef unsigned int NTSTATUS;
|
||||
typedef unsigned long u_long;
|
||||
typedef ULONG *PULONG;
|
||||
typedef const void *LPCVOID;
|
||||
typedef void *LPVOID;
|
||||
typedef const wchar_t *LPCWSTR;
|
||||
|
||||
typedef uintptr_t ULONG_PTR;
|
||||
typedef uintptr_t UINT_PTR;
|
||||
|
||||
typedef UINT_PTR SOCKET;
|
||||
|
||||
typedef struct _OVERLAPPED {
|
||||
ULONG_PTR Internal;
|
||||
ULONG_PTR InternalHigh;
|
||||
union {
|
||||
struct {
|
||||
DWORD Offset;
|
||||
DWORD OffsetHigh;
|
||||
} DUMMYSTRUCTNAME;
|
||||
PVOID Pointer;
|
||||
} DUMMYUNIONNAME;
|
||||
|
||||
HANDLE hEvent;
|
||||
} OVERLAPPED, *LPOVERLAPPED;
|
||||
|
||||
typedef OVERLAPPED WSAOVERLAPPED;
|
||||
typedef LPOVERLAPPED LPWSAOVERLAPPED;
|
||||
typedef PVOID LPSECURITY_ATTRIBUTES;
|
||||
typedef PVOID LPCSTR;
|
||||
|
||||
typedef struct _OVERLAPPED_ENTRY {
|
||||
ULONG_PTR lpCompletionKey;
|
||||
LPOVERLAPPED lpOverlapped;
|
||||
ULONG_PTR Internal;
|
||||
DWORD dwNumberOfBytesTransferred;
|
||||
} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY;
|
||||
|
||||
// kernel32.dll
|
||||
HANDLE WINAPI CreateIoCompletionPort(
|
||||
_In_ HANDLE FileHandle,
|
||||
_In_opt_ HANDLE ExistingCompletionPort,
|
||||
_In_ ULONG_PTR CompletionKey,
|
||||
_In_ DWORD NumberOfConcurrentThreads
|
||||
);
|
||||
|
||||
BOOL SetFileCompletionNotificationModes(
|
||||
HANDLE FileHandle,
|
||||
UCHAR Flags
|
||||
);
|
||||
|
||||
HANDLE CreateFileW(
|
||||
LPCWSTR lpFileName,
|
||||
DWORD dwDesiredAccess,
|
||||
DWORD dwShareMode,
|
||||
LPSECURITY_ATTRIBUTES lpSecurityAttributes,
|
||||
DWORD dwCreationDisposition,
|
||||
DWORD dwFlagsAndAttributes,
|
||||
HANDLE hTemplateFile
|
||||
);
|
||||
|
||||
BOOL WINAPI CloseHandle(
|
||||
_In_ HANDLE hObject
|
||||
);
|
||||
|
||||
BOOL WINAPI PostQueuedCompletionStatus(
|
||||
_In_ HANDLE CompletionPort,
|
||||
_In_ DWORD dwNumberOfBytesTransferred,
|
||||
_In_ ULONG_PTR dwCompletionKey,
|
||||
_In_opt_ LPOVERLAPPED lpOverlapped
|
||||
);
|
||||
|
||||
BOOL WINAPI GetQueuedCompletionStatusEx(
|
||||
_In_ HANDLE CompletionPort,
|
||||
_Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries,
|
||||
_In_ ULONG ulCount,
|
||||
_Out_ PULONG ulNumEntriesRemoved,
|
||||
_In_ DWORD dwMilliseconds,
|
||||
_In_ BOOL fAlertable
|
||||
);
|
||||
|
||||
BOOL WINAPI CancelIoEx(
|
||||
_In_ HANDLE hFile,
|
||||
_In_opt_ LPOVERLAPPED lpOverlapped
|
||||
);
|
||||
|
||||
BOOL WriteFile(
|
||||
HANDLE hFile,
|
||||
LPCVOID lpBuffer,
|
||||
DWORD nNumberOfBytesToWrite,
|
||||
LPDWORD lpNumberOfBytesWritten,
|
||||
LPOVERLAPPED lpOverlapped
|
||||
);
|
||||
|
||||
BOOL ReadFile(
|
||||
HANDLE hFile,
|
||||
LPVOID lpBuffer,
|
||||
DWORD nNumberOfBytesToRead,
|
||||
LPDWORD lpNumberOfBytesRead,
|
||||
LPOVERLAPPED lpOverlapped
|
||||
);
|
||||
|
||||
BOOL WINAPI SetConsoleCtrlHandler(
|
||||
_In_opt_ void* HandlerRoutine,
|
||||
_In_ BOOL Add
|
||||
);
|
||||
|
||||
HANDLE CreateEventA(
|
||||
LPSECURITY_ATTRIBUTES lpEventAttributes,
|
||||
BOOL bManualReset,
|
||||
BOOL bInitialState,
|
||||
LPCSTR lpName
|
||||
);
|
||||
|
||||
BOOL SetEvent(
|
||||
HANDLE hEvent
|
||||
);
|
||||
|
||||
BOOL ResetEvent(
|
||||
HANDLE hEvent
|
||||
);
|
||||
|
||||
DWORD WaitForSingleObject(
|
||||
HANDLE hHandle,
|
||||
DWORD dwMilliseconds
|
||||
);
|
||||
|
||||
DWORD WaitForMultipleObjects(
|
||||
DWORD nCount,
|
||||
HANDLE *lpHandles,
|
||||
BOOL bWaitAll,
|
||||
DWORD dwMilliseconds
|
||||
);
|
||||
|
||||
ULONG RtlNtStatusToDosError(
|
||||
NTSTATUS Status
|
||||
);
|
||||
|
||||
int WSAIoctl(
|
||||
SOCKET s,
|
||||
DWORD dwIoControlCode,
|
||||
LPVOID lpvInBuffer,
|
||||
DWORD cbInBuffer,
|
||||
LPVOID lpvOutBuffer,
|
||||
DWORD cbOutBuffer,
|
||||
LPDWORD lpcbBytesReturned,
|
||||
LPWSAOVERLAPPED lpOverlapped,
|
||||
// actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
|
||||
void* lpCompletionRoutine
|
||||
);
|
||||
|
||||
int WSAGetLastError();
|
||||
|
||||
BOOL DeviceIoControl(
|
||||
HANDLE hDevice,
|
||||
DWORD dwIoControlCode,
|
||||
LPVOID lpInBuffer,
|
||||
DWORD nInBufferSize,
|
||||
LPVOID lpOutBuffer,
|
||||
DWORD nOutBufferSize,
|
||||
LPDWORD lpBytesReturned,
|
||||
LPOVERLAPPED lpOverlapped
|
||||
);
|
||||
|
||||
// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
|
||||
typedef struct _AFD_POLL_HANDLE_INFO {
|
||||
HANDLE Handle;
|
||||
ULONG Events;
|
||||
NTSTATUS Status;
|
||||
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
|
||||
|
||||
// This is really defined as a messy union to allow stuff like
|
||||
// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications.
|
||||
// Under all that it's just an int64.
|
||||
typedef int64_t LARGE_INTEGER;
|
||||
|
||||
typedef struct _AFD_POLL_INFO {
|
||||
LARGE_INTEGER Timeout;
|
||||
ULONG NumberOfHandles;
|
||||
ULONG Exclusive;
|
||||
AFD_POLL_HANDLE_INFO Handles[1];
|
||||
} AFD_POLL_INFO, *PAFD_POLL_INFO;
|
||||
|
||||
"""
|
||||
|
||||
# cribbed from pywincffi
|
||||
# programmatically strips out those annotations MSDN likes, like _In_
|
||||
REGEX_SAL_ANNOTATION = re.compile(
|
||||
r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b",
|
||||
)
|
||||
LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB)
|
||||
|
||||
# Other fixups:
|
||||
# - get rid of FAR, cffi doesn't like it
|
||||
LIB = re.sub(r"\bFAR\b", " ", LIB)
|
||||
# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern
|
||||
# being _MSC_VER >= 800)
|
||||
LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB)
|
||||
|
||||
ffi = cffi.api.FFI()
|
||||
ffi.cdef(LIB)
|
||||
|
||||
CData: TypeAlias = cffi.api.FFI.CData
|
||||
CType: TypeAlias = cffi.api.FFI.CType
|
||||
AlwaysNull: TypeAlias = CType # We currently always pass ffi.NULL here.
|
||||
Handle = NewType("Handle", CData)
|
||||
HandleArray = NewType("HandleArray", CData)
|
||||
|
||||
|
||||
class _Kernel32(Protocol):
|
||||
"""Statically typed version of the kernel32.dll functions we use."""
|
||||
|
||||
def CreateIoCompletionPort(
|
||||
self,
|
||||
FileHandle: Handle,
|
||||
ExistingCompletionPort: CData | AlwaysNull,
|
||||
CompletionKey: int,
|
||||
NumberOfConcurrentThreads: int,
|
||||
/,
|
||||
) -> Handle: ...
|
||||
|
||||
def CreateEventA(
|
||||
self,
|
||||
lpEventAttributes: AlwaysNull,
|
||||
bManualReset: bool,
|
||||
bInitialState: bool,
|
||||
lpName: AlwaysNull,
|
||||
/,
|
||||
) -> Handle: ...
|
||||
|
||||
def SetFileCompletionNotificationModes(
|
||||
self,
|
||||
handle: Handle,
|
||||
flags: CompletionModes,
|
||||
/,
|
||||
) -> int: ...
|
||||
|
||||
def PostQueuedCompletionStatus(
|
||||
self,
|
||||
CompletionPort: Handle,
|
||||
dwNumberOfBytesTransferred: int,
|
||||
dwCompletionKey: int,
|
||||
lpOverlapped: CData | AlwaysNull,
|
||||
/,
|
||||
) -> bool: ...
|
||||
|
||||
def CancelIoEx(
|
||||
self,
|
||||
hFile: Handle,
|
||||
lpOverlapped: CData | AlwaysNull,
|
||||
/,
|
||||
) -> bool: ...
|
||||
|
||||
def WriteFile(
|
||||
self,
|
||||
hFile: Handle,
|
||||
# not sure about this type
|
||||
lpBuffer: CData,
|
||||
nNumberOfBytesToWrite: int,
|
||||
lpNumberOfBytesWritten: AlwaysNull,
|
||||
lpOverlapped: _Overlapped,
|
||||
/,
|
||||
) -> bool: ...
|
||||
|
||||
def ReadFile(
|
||||
self,
|
||||
hFile: Handle,
|
||||
# not sure about this type
|
||||
lpBuffer: CData,
|
||||
nNumberOfBytesToRead: int,
|
||||
lpNumberOfBytesRead: AlwaysNull,
|
||||
lpOverlapped: _Overlapped,
|
||||
/,
|
||||
) -> bool: ...
|
||||
|
||||
def GetQueuedCompletionStatusEx(
|
||||
self,
|
||||
CompletionPort: Handle,
|
||||
lpCompletionPortEntries: CData,
|
||||
ulCount: int,
|
||||
ulNumEntriesRemoved: CData,
|
||||
dwMilliseconds: int,
|
||||
fAlertable: bool | int,
|
||||
/,
|
||||
) -> CData: ...
|
||||
|
||||
def CreateFileW(
|
||||
self,
|
||||
lpFileName: CData,
|
||||
dwDesiredAccess: FileFlags,
|
||||
dwShareMode: FileFlags,
|
||||
lpSecurityAttributes: AlwaysNull,
|
||||
dwCreationDisposition: FileFlags,
|
||||
dwFlagsAndAttributes: FileFlags,
|
||||
hTemplateFile: AlwaysNull,
|
||||
/,
|
||||
) -> Handle: ...
|
||||
|
||||
def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: ...
|
||||
|
||||
def WaitForMultipleObjects(
|
||||
self,
|
||||
nCount: int,
|
||||
lpHandles: HandleArray,
|
||||
bWaitAll: bool,
|
||||
dwMilliseconds: int,
|
||||
/,
|
||||
) -> ErrorCodes: ...
|
||||
|
||||
def SetEvent(self, handle: Handle, /) -> None: ...
|
||||
|
||||
def CloseHandle(self, handle: Handle, /) -> bool: ...
|
||||
|
||||
def DeviceIoControl(
|
||||
self,
|
||||
hDevice: Handle,
|
||||
dwIoControlCode: int,
|
||||
# this is wrong (it's not always null)
|
||||
lpInBuffer: AlwaysNull,
|
||||
nInBufferSize: int,
|
||||
# this is also wrong
|
||||
lpOutBuffer: AlwaysNull,
|
||||
nOutBufferSize: int,
|
||||
lpBytesReturned: AlwaysNull,
|
||||
lpOverlapped: CData,
|
||||
/,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
class _Nt(Protocol):
|
||||
"""Statically typed version of the dtdll.dll functions we use."""
|
||||
|
||||
def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: ...
|
||||
|
||||
|
||||
class _Ws2(Protocol):
|
||||
"""Statically typed version of the ws2_32.dll functions we use."""
|
||||
|
||||
def WSAGetLastError(self) -> int: ...
|
||||
|
||||
def WSAIoctl(
|
||||
self,
|
||||
socket: CData,
|
||||
dwIoControlCode: WSAIoctls,
|
||||
lpvInBuffer: AlwaysNull,
|
||||
cbInBuffer: int,
|
||||
lpvOutBuffer: CData,
|
||||
cbOutBuffer: int,
|
||||
lpcbBytesReturned: CData, # int*
|
||||
lpOverlapped: AlwaysNull,
|
||||
# actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
|
||||
lpCompletionRoutine: AlwaysNull,
|
||||
/,
|
||||
) -> int: ...
|
||||
|
||||
|
||||
class _DummyStruct(Protocol):
|
||||
Offset: int
|
||||
OffsetHigh: int
|
||||
|
||||
|
||||
class _DummyUnion(Protocol):
|
||||
DUMMYSTRUCTNAME: _DummyStruct
|
||||
Pointer: object
|
||||
|
||||
|
||||
class _Overlapped(Protocol):
|
||||
Internal: int
|
||||
InternalHigh: int
|
||||
DUMMYUNIONNAME: _DummyUnion
|
||||
hEvent: Handle
|
||||
|
||||
|
||||
kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll"))
|
||||
ntdll = cast(_Nt, ffi.dlopen("ntdll.dll"))
|
||||
ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll"))
|
||||
|
||||
################################################################
|
||||
# Magic numbers
|
||||
################################################################
|
||||
|
||||
# Here's a great resource for looking these up:
|
||||
# https://www.magnumdb.com
|
||||
# (Tip: check the box to see "Hex value")
|
||||
|
||||
INVALID_HANDLE_VALUE = Handle(ffi.cast("HANDLE", -1))
|
||||
|
||||
|
||||
class ErrorCodes(enum.IntEnum):
|
||||
STATUS_TIMEOUT = 0x102
|
||||
WAIT_TIMEOUT = 0x102
|
||||
WAIT_ABANDONED = 0x80
|
||||
WAIT_OBJECT_0 = 0x00 # object is signaled
|
||||
WAIT_FAILED = 0xFFFFFFFF
|
||||
ERROR_IO_PENDING = 997
|
||||
ERROR_OPERATION_ABORTED = 995
|
||||
ERROR_ABANDONED_WAIT_0 = 735
|
||||
ERROR_INVALID_HANDLE = 6
|
||||
ERROR_INVALID_PARMETER = 87
|
||||
ERROR_NOT_FOUND = 1168
|
||||
ERROR_NOT_SOCKET = 10038
|
||||
|
||||
|
||||
class FileFlags(enum.IntFlag):
|
||||
GENERIC_READ = 0x80000000
|
||||
SYNCHRONIZE = 0x00100000
|
||||
FILE_FLAG_OVERLAPPED = 0x40000000
|
||||
FILE_SHARE_READ = 1
|
||||
FILE_SHARE_WRITE = 2
|
||||
FILE_SHARE_DELETE = 4
|
||||
CREATE_NEW = 1
|
||||
CREATE_ALWAYS = 2
|
||||
OPEN_EXISTING = 3
|
||||
OPEN_ALWAYS = 4
|
||||
TRUNCATE_EXISTING = 5
|
||||
|
||||
|
||||
class AFDPollFlags(enum.IntFlag):
|
||||
# These are drawn from a combination of:
|
||||
# https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
|
||||
# https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h
|
||||
AFD_POLL_RECEIVE = 0x0001
|
||||
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data
|
||||
AFD_POLL_SEND = 0x0004
|
||||
AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN)
|
||||
AFD_POLL_ABORT = 0x0010 # received RST
|
||||
AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed
|
||||
AFD_POLL_CONNECT = 0x0040 # socket is successfully connected
|
||||
AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket
|
||||
AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully
|
||||
# See WSAEventSelect docs for more details on these four:
|
||||
AFD_POLL_QOS = 0x0200
|
||||
AFD_POLL_GROUP_QOS = 0x0400
|
||||
AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800
|
||||
AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000
|
||||
|
||||
|
||||
class WSAIoctls(enum.IntEnum):
|
||||
SIO_BASE_HANDLE = 0x48000022
|
||||
SIO_BSP_HANDLE_SELECT = 0x4800001C
|
||||
SIO_BSP_HANDLE_POLL = 0x4800001D
|
||||
|
||||
|
||||
class CompletionModes(enum.IntFlag):
|
||||
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1
|
||||
FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2
|
||||
|
||||
|
||||
class IoControlCodes(enum.IntEnum):
|
||||
IOCTL_AFD_POLL = 0x00012024
|
||||
|
||||
|
||||
################################################################
|
||||
# Generic helpers
|
||||
################################################################
|
||||
|
||||
|
||||
def _handle(obj: int | CData) -> Handle:
|
||||
# For now, represent handles as either cffi HANDLEs or as ints. If you
|
||||
# try to pass in a file descriptor instead, it's not going to work
|
||||
# out. (For that msvcrt.get_osfhandle does the trick, but I don't know if
|
||||
# we'll actually need that for anything...) For sockets this doesn't
|
||||
# matter, Python never allocates an fd. So let's wait until we actually
|
||||
# encounter the problem before worrying about it.
|
||||
if isinstance(obj, int):
|
||||
return Handle(ffi.cast("HANDLE", obj))
|
||||
return Handle(obj)
|
||||
|
||||
|
||||
def handle_array(count: int) -> HandleArray:
|
||||
"""Make an array of handles."""
|
||||
return HandleArray(ffi.new(f"HANDLE[{count}]"))
|
||||
|
||||
|
||||
def raise_winerror(
|
||||
winerror: int | None = None,
|
||||
*,
|
||||
filename: str | None = None,
|
||||
filename2: str | None = None,
|
||||
) -> NoReturn:
|
||||
# assert sys.platform == "win32" # TODO: make this work in MyPy
|
||||
# ... in the meanwhile, ffi.getwinerror() is undefined on non-Windows, necessitating the type
|
||||
# ignores.
|
||||
|
||||
if winerror is None:
|
||||
err = ffi.getwinerror() # type: ignore[attr-defined,unused-ignore]
|
||||
if err is None:
|
||||
raise RuntimeError("No error set?")
|
||||
winerror, msg = err
|
||||
else:
|
||||
err = ffi.getwinerror(winerror) # type: ignore[attr-defined,unused-ignore]
|
||||
if err is None:
|
||||
raise RuntimeError("No error set?")
|
||||
_, msg = err
|
||||
# https://docs.python.org/3/library/exceptions.html#OSError
|
||||
raise OSError(0, msg, filename, winerror, filename2)
|
||||
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, ClassVar, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
ArgsT = ParamSpec("ArgsT")
|
||||
|
||||
RetT = TypeVar("RetT")
|
||||
|
||||
|
||||
# We want our warnings to be visible by default (at least for now), but we
|
||||
# also want it to be possible to override that using the -W switch. AFAICT
|
||||
# this means we cannot inherit from DeprecationWarning, because the only way
|
||||
# to make it visible by default then would be to add our own filter at import
|
||||
# time, but that would override -W switches...
|
||||
class TrioDeprecationWarning(FutureWarning):
|
||||
"""Warning emitted if you use deprecated Trio functionality.
|
||||
|
||||
As a young project, Trio is currently quite aggressive about deprecating
|
||||
and/or removing functionality that we realize was a bad idea. If you use
|
||||
Trio, you should subscribe to `issue #1
|
||||
<https://github.com/python-trio/trio/issues/1>`__ to get information about
|
||||
upcoming deprecations and other backwards compatibility breaking changes.
|
||||
|
||||
Despite the name, this class currently inherits from
|
||||
:class:`FutureWarning`, not :class:`DeprecationWarning`, because while
|
||||
we're in young-and-aggressive mode we want these warnings to be visible by
|
||||
default. You can hide them by installing a filter or with the ``-W``
|
||||
switch: see the :mod:`warnings` documentation for details.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _url_for_issue(issue: int) -> str:
|
||||
return f"https://github.com/python-trio/trio/issues/{issue}"
|
||||
|
||||
|
||||
def _stringify(thing: object) -> str:
|
||||
if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"):
|
||||
return f"{thing.__module__}.{thing.__qualname__}"
|
||||
return str(thing)
|
||||
|
||||
|
||||
def warn_deprecated(
|
||||
thing: object,
|
||||
version: str,
|
||||
*,
|
||||
issue: int | None,
|
||||
instead: object,
|
||||
stacklevel: int = 2,
|
||||
use_triodeprecationwarning: bool = False,
|
||||
) -> None:
|
||||
stacklevel += 1
|
||||
msg = f"{_stringify(thing)} is deprecated since Trio {version}"
|
||||
if instead is None:
|
||||
msg += " with no replacement"
|
||||
else:
|
||||
msg += f"; use {_stringify(instead)} instead"
|
||||
if issue is not None:
|
||||
msg += f" ({_url_for_issue(issue)})"
|
||||
if use_triodeprecationwarning:
|
||||
warning_class: type[Warning] = TrioDeprecationWarning
|
||||
else:
|
||||
warning_class = DeprecationWarning
|
||||
warnings.warn(warning_class(msg), stacklevel=stacklevel)
|
||||
|
||||
|
||||
# @deprecated("0.2.0", issue=..., instead=...)
|
||||
# def ...
|
||||
def deprecated(
|
||||
version: str,
|
||||
*,
|
||||
thing: object = None,
|
||||
issue: int | None,
|
||||
instead: object,
|
||||
use_triodeprecationwarning: bool = False,
|
||||
) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]:
|
||||
def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]:
|
||||
nonlocal thing
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
|
||||
warn_deprecated(
|
||||
thing,
|
||||
version,
|
||||
instead=instead,
|
||||
issue=issue,
|
||||
use_triodeprecationwarning=use_triodeprecationwarning,
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
# If our __module__ or __qualname__ get modified, we want to pick up
|
||||
# on that, so we read them off the wrapper object instead of the (now
|
||||
# hidden) fn object
|
||||
if thing is None:
|
||||
thing = wrapper
|
||||
|
||||
if wrapper.__doc__ is not None:
|
||||
doc = wrapper.__doc__
|
||||
doc = doc.rstrip()
|
||||
doc += "\n\n"
|
||||
doc += f".. deprecated:: {version}\n"
|
||||
if instead is not None:
|
||||
doc += f" Use {_stringify(instead)} instead.\n"
|
||||
if issue is not None:
|
||||
doc += f" For details, see `issue #{issue} <{_url_for_issue(issue)}>`__.\n"
|
||||
doc += "\n"
|
||||
wrapper.__doc__ = doc
|
||||
|
||||
return wrapper
|
||||
|
||||
return do_wrap
|
||||
|
||||
|
||||
def deprecated_alias(
|
||||
old_qualname: str,
|
||||
new_fn: Callable[ArgsT, RetT],
|
||||
version: str,
|
||||
*,
|
||||
issue: int | None,
|
||||
) -> Callable[ArgsT, RetT]:
|
||||
@deprecated(version, issue=issue, instead=new_fn)
|
||||
@wraps(new_fn, assigned=("__module__", "__annotations__"))
|
||||
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
|
||||
"""Deprecated alias."""
|
||||
return new_fn(*args, **kwargs)
|
||||
|
||||
wrapper.__qualname__ = old_qualname
|
||||
wrapper.__name__ = old_qualname.rpartition(".")[-1]
|
||||
return wrapper
|
||||
|
||||
|
||||
@attrs.frozen(slots=False)
|
||||
class DeprecatedAttribute:
|
||||
_not_set: ClassVar[object] = object()
|
||||
|
||||
value: object
|
||||
version: str
|
||||
issue: int | None
|
||||
instead: object = _not_set
|
||||
|
||||
|
||||
class _ModuleWithDeprecations(ModuleType):
|
||||
__deprecated_attributes__: dict[str, DeprecatedAttribute]
|
||||
|
||||
def __getattr__(self, name: str) -> object:
|
||||
if name in self.__deprecated_attributes__:
|
||||
info = self.__deprecated_attributes__[name]
|
||||
instead = info.instead
|
||||
if instead is DeprecatedAttribute._not_set:
|
||||
instead = info.value
|
||||
thing = f"{self.__name__}.{name}"
|
||||
warn_deprecated(thing, info.version, issue=info.issue, instead=instead)
|
||||
return info.value
|
||||
|
||||
msg = "module '{}' has no attribute '{}'"
|
||||
raise AttributeError(msg.format(self.__name__, name))
|
||||
|
||||
|
||||
def enable_attribute_deprecations(module_name: str) -> None:
|
||||
module = sys.modules[module_name]
|
||||
module.__class__ = _ModuleWithDeprecations
|
||||
assert isinstance(module, _ModuleWithDeprecations)
|
||||
# Make sure that this is always defined so that
|
||||
# _ModuleWithDeprecations.__getattr__ can access it without jumping
|
||||
# through hoops or risking infinite recursion.
|
||||
module.__deprecated_attributes__ = {}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,509 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from functools import partial
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from ._util import async_wraps
|
||||
from .abc import AsyncResource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import (
|
||||
OpenBinaryMode,
|
||||
OpenBinaryModeReading,
|
||||
OpenBinaryModeUpdating,
|
||||
OpenBinaryModeWriting,
|
||||
OpenTextMode,
|
||||
StrOrBytesPath,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
# This list is also in the docs, make sure to keep them in sync
|
||||
_FILE_SYNC_ATTRS: set[str] = {
|
||||
"closed",
|
||||
"encoding",
|
||||
"errors",
|
||||
"fileno",
|
||||
"isatty",
|
||||
"newlines",
|
||||
"readable",
|
||||
"seekable",
|
||||
"writable",
|
||||
# not defined in *IOBase:
|
||||
"buffer",
|
||||
"raw",
|
||||
"line_buffering",
|
||||
"closefd",
|
||||
"name",
|
||||
"mode",
|
||||
"getvalue",
|
||||
"getbuffer",
|
||||
}
|
||||
|
||||
# This list is also in the docs, make sure to keep them in sync
|
||||
_FILE_ASYNC_METHODS: set[str] = {
|
||||
"flush",
|
||||
"read",
|
||||
"read1",
|
||||
"readall",
|
||||
"readinto",
|
||||
"readline",
|
||||
"readlines",
|
||||
"seek",
|
||||
"tell",
|
||||
"truncate",
|
||||
"write",
|
||||
"writelines",
|
||||
# not defined in *IOBase:
|
||||
"readinto1",
|
||||
"peek",
|
||||
}
|
||||
|
||||
|
||||
FileT = TypeVar("FileT")
|
||||
FileT_co = TypeVar("FileT_co", covariant=True)
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
T_contra = TypeVar("T_contra", contravariant=True)
|
||||
AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)
|
||||
AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True)
|
||||
|
||||
# This is a little complicated. IO objects have a lot of methods, and which are available on
|
||||
# different types varies wildly. We want to match the interface of whatever file we're wrapping.
|
||||
# This pile of protocols each has one sync method/property, meaning they're going to be compatible
|
||||
# with a file class that supports that method/property. The ones parameterized with AnyStr take
|
||||
# either str or bytes depending.
|
||||
|
||||
# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're
|
||||
# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be
|
||||
# conditional - it's only valid to call them if the object you're accessing them on is compatible
|
||||
# with that type hint. By using the protocols, the type checker will be checking to see if the
|
||||
# wrapped type has that method, and only allow the methods that do to be called. We can then alter
|
||||
# the signature however it needs to match runtime behaviour.
|
||||
# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Buffer, Protocol
|
||||
|
||||
# fmt: off
|
||||
|
||||
class _HasClosed(Protocol):
|
||||
@property
|
||||
def closed(self) -> bool: ...
|
||||
|
||||
class _HasEncoding(Protocol):
|
||||
@property
|
||||
def encoding(self) -> str: ...
|
||||
|
||||
class _HasErrors(Protocol):
|
||||
@property
|
||||
def errors(self) -> str | None: ...
|
||||
|
||||
class _HasFileNo(Protocol):
|
||||
def fileno(self) -> int: ...
|
||||
|
||||
class _HasIsATTY(Protocol):
|
||||
def isatty(self) -> bool: ...
|
||||
|
||||
class _HasNewlines(Protocol[T_co]):
|
||||
# Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any.
|
||||
@property
|
||||
def newlines(self) -> T_co: ...
|
||||
|
||||
class _HasReadable(Protocol):
|
||||
def readable(self) -> bool: ...
|
||||
|
||||
class _HasSeekable(Protocol):
|
||||
def seekable(self) -> bool: ...
|
||||
|
||||
class _HasWritable(Protocol):
|
||||
def writable(self) -> bool: ...
|
||||
|
||||
class _HasBuffer(Protocol):
|
||||
@property
|
||||
def buffer(self) -> BinaryIO: ...
|
||||
|
||||
class _HasRaw(Protocol):
|
||||
@property
|
||||
def raw(self) -> io.RawIOBase: ...
|
||||
|
||||
class _HasLineBuffering(Protocol):
|
||||
@property
|
||||
def line_buffering(self) -> bool: ...
|
||||
|
||||
class _HasCloseFD(Protocol):
|
||||
@property
|
||||
def closefd(self) -> bool: ...
|
||||
|
||||
class _HasName(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
class _HasMode(Protocol):
|
||||
@property
|
||||
def mode(self) -> str: ...
|
||||
|
||||
class _CanGetValue(Protocol[AnyStr_co]):
|
||||
def getvalue(self) -> AnyStr_co: ...
|
||||
|
||||
class _CanGetBuffer(Protocol):
|
||||
def getbuffer(self) -> memoryview: ...
|
||||
|
||||
class _CanFlush(Protocol):
|
||||
def flush(self) -> None: ...
|
||||
|
||||
class _CanRead(Protocol[AnyStr_co]):
|
||||
def read(self, size: int | None = ..., /) -> AnyStr_co: ...
|
||||
|
||||
class _CanRead1(Protocol):
|
||||
def read1(self, size: int | None = ..., /) -> bytes: ...
|
||||
|
||||
class _CanReadAll(Protocol[AnyStr_co]):
|
||||
def readall(self) -> AnyStr_co: ...
|
||||
|
||||
class _CanReadInto(Protocol):
|
||||
def readinto(self, buf: Buffer, /) -> int | None: ...
|
||||
|
||||
class _CanReadInto1(Protocol):
|
||||
def readinto1(self, buffer: Buffer, /) -> int: ...
|
||||
|
||||
class _CanReadLine(Protocol[AnyStr_co]):
|
||||
def readline(self, size: int = ..., /) -> AnyStr_co: ...
|
||||
|
||||
class _CanReadLines(Protocol[AnyStr]):
|
||||
def readlines(self, hint: int = ..., /) -> list[AnyStr]: ...
|
||||
|
||||
class _CanSeek(Protocol):
|
||||
def seek(self, target: int, whence: int = 0, /) -> int: ...
|
||||
|
||||
class _CanTell(Protocol):
|
||||
def tell(self) -> int: ...
|
||||
|
||||
class _CanTruncate(Protocol):
|
||||
def truncate(self, size: int | None = ..., /) -> int: ...
|
||||
|
||||
class _CanWrite(Protocol[T_contra]):
|
||||
def write(self, data: T_contra, /) -> int: ...
|
||||
|
||||
class _CanWriteLines(Protocol[T_contra]):
|
||||
# The lines parameter varies for bytes/str, so use a typevar to make the async match.
|
||||
def writelines(self, lines: Iterable[T_contra], /) -> None: ...
|
||||
|
||||
class _CanPeek(Protocol[AnyStr_co]):
|
||||
def peek(self, size: int = 0, /) -> AnyStr_co: ...
|
||||
|
||||
class _CanDetach(Protocol[T_co]):
|
||||
# The T typevar will be the unbuffered/binary file this file wraps.
|
||||
def detach(self) -> T_co: ...
|
||||
|
||||
class _CanClose(Protocol):
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a
|
||||
# subtype of the protocols.
|
||||
class AsyncIOWrapper(AsyncResource, Generic[FileT_co]):
|
||||
"""A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
|
||||
file object` interface. Wrapped methods that could block are executed in
|
||||
:meth:`trio.to_thread.run_sync`.
|
||||
|
||||
All properties and methods defined in :mod:`~io` are exposed by this
|
||||
wrapper, if they exist in the wrapped file object.
|
||||
"""
|
||||
|
||||
def __init__(self, file: FileT_co) -> None:
|
||||
self._wrapped = file
|
||||
|
||||
@property
|
||||
def wrapped(self) -> FileT_co:
|
||||
"""object: A reference to the wrapped file object"""
|
||||
|
||||
return self._wrapped
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, name: str) -> object:
|
||||
if name in _FILE_SYNC_ATTRS:
|
||||
return getattr(self._wrapped, name)
|
||||
if name in _FILE_ASYNC_METHODS:
|
||||
meth = getattr(self._wrapped, name)
|
||||
|
||||
@async_wraps(self.__class__, self._wrapped.__class__, name)
|
||||
async def wrapper(*args, **kwargs):
|
||||
func = partial(meth, *args, **kwargs)
|
||||
return await trio.to_thread.run_sync(func)
|
||||
|
||||
# cache the generated method
|
||||
setattr(self, name, wrapper)
|
||||
return wrapper
|
||||
|
||||
raise AttributeError(name)
|
||||
|
||||
def __dir__(self) -> Iterable[str]:
|
||||
attrs = set(super().__dir__())
|
||||
attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
|
||||
attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
|
||||
return attrs
|
||||
|
||||
def __aiter__(self) -> AsyncIOWrapper[FileT_co]:
|
||||
return self
|
||||
|
||||
async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr:
|
||||
line = await self.readline()
|
||||
if line:
|
||||
return line
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]:
|
||||
"""Like :meth:`io.BufferedIOBase.detach`, but async.
|
||||
|
||||
This also re-wraps the result in a new :term:`asynchronous file object`
|
||||
wrapper.
|
||||
|
||||
"""
|
||||
|
||||
raw = await trio.to_thread.run_sync(self._wrapped.detach)
|
||||
return wrap_file(raw)
|
||||
|
||||
async def aclose(self: AsyncIOWrapper[_CanClose]) -> None:
|
||||
"""Like :meth:`io.IOBase.close`, but async.
|
||||
|
||||
This is also shielded from cancellation; if a cancellation scope is
|
||||
cancelled, the wrapped file object will still be safely closed.
|
||||
|
||||
"""
|
||||
|
||||
# ensure the underling file is closed during cancellation
|
||||
with trio.CancelScope(shield=True):
|
||||
await trio.to_thread.run_sync(self._wrapped.close)
|
||||
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# fmt: off
|
||||
# Based on typing.IO and io stubs.
|
||||
@property
|
||||
def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ...
|
||||
@property
|
||||
def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ...
|
||||
@property
|
||||
def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ...
|
||||
@property
|
||||
def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ...
|
||||
@property
|
||||
def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ...
|
||||
@property
|
||||
def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ...
|
||||
@property
|
||||
def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ...
|
||||
@property
|
||||
def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ...
|
||||
@property
|
||||
def name(self: AsyncIOWrapper[_HasName]) -> str: ...
|
||||
@property
|
||||
def mode(self: AsyncIOWrapper[_HasMode]) -> str: ...
|
||||
|
||||
def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ...
|
||||
def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ...
|
||||
def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ...
|
||||
def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ...
|
||||
def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ...
|
||||
def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ...
|
||||
def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ...
|
||||
async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ...
|
||||
async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ...
|
||||
async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ...
|
||||
async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ...
|
||||
async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ...
|
||||
async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ...
|
||||
async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]]) -> list[AnyStr]: ...
|
||||
async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ...
|
||||
async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ...
|
||||
async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ...
|
||||
async def write(self: AsyncIOWrapper[_CanWrite[T]], data: T, /) -> int: ...
|
||||
async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ...
|
||||
async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ...
|
||||
async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ...
|
||||
|
||||
|
||||
# Type hints are copied from builtin open.
|
||||
_OpenFile = Union["StrOrBytesPath", int]
|
||||
_Opener = Callable[[str, int], int]
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenTextMode = "r",
|
||||
buffering: int = -1,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[io.TextIOWrapper]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenBinaryMode,
|
||||
buffering: Literal[0],
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[io.FileIO]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenBinaryModeUpdating,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[io.BufferedRandom]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenBinaryModeWriting,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[io.BufferedWriter]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenBinaryModeReading,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[io.BufferedReader]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: OpenBinaryMode,
|
||||
buffering: int,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[BinaryIO]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def open_file( # type: ignore[misc] # Any usage matches builtins.open().
|
||||
file: _OpenFile,
|
||||
mode: str,
|
||||
buffering: int = -1,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[IO[Any]]: ...
|
||||
|
||||
|
||||
async def open_file(
|
||||
file: _OpenFile,
|
||||
mode: str = "r",
|
||||
buffering: int = -1,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
closefd: bool = True,
|
||||
opener: _Opener | None = None,
|
||||
) -> AsyncIOWrapper[Any]:
|
||||
"""Asynchronous version of :func:`open`.
|
||||
|
||||
Returns:
|
||||
An :term:`asynchronous file object`
|
||||
|
||||
Example::
|
||||
|
||||
async with await trio.open_file(filename) as f:
|
||||
async for line in f:
|
||||
pass
|
||||
|
||||
assert f.closed
|
||||
|
||||
See also:
|
||||
:func:`trio.Path.open`
|
||||
|
||||
"""
|
||||
_file = wrap_file(
|
||||
await trio.to_thread.run_sync(
|
||||
io.open,
|
||||
file,
|
||||
mode,
|
||||
buffering,
|
||||
encoding,
|
||||
errors,
|
||||
newline,
|
||||
closefd,
|
||||
opener,
|
||||
),
|
||||
)
|
||||
return _file
|
||||
|
||||
|
||||
def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]:
|
||||
"""This wraps any file object in a wrapper that provides an asynchronous
|
||||
file object interface.
|
||||
|
||||
Args:
|
||||
file: a :term:`file object`
|
||||
|
||||
Returns:
|
||||
An :term:`asynchronous file object` that wraps ``file``
|
||||
|
||||
Example::
|
||||
|
||||
async_file = trio.wrap_file(StringIO('asdf'))
|
||||
|
||||
assert await async_file.read() == 'asdf'
|
||||
|
||||
"""
|
||||
|
||||
def has(attr: str) -> bool:
|
||||
return hasattr(file, attr) and callable(getattr(file, attr))
|
||||
|
||||
if not (has("close") and (has("read") or has("write"))):
|
||||
raise TypeError(
|
||||
f"{file} does not implement required duck-file methods: "
|
||||
"close and (read or write)",
|
||||
)
|
||||
|
||||
return AsyncIOWrapper(file)
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
import trio
|
||||
from trio._util import final
|
||||
|
||||
from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
|
||||
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
|
||||
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
|
||||
|
||||
|
||||
async def aclose_forcefully(resource: AsyncResource) -> None:
|
||||
"""Close an async resource or async generator immediately, without
|
||||
blocking to do any graceful cleanup.
|
||||
|
||||
:class:`~trio.abc.AsyncResource` objects guarantee that if their
|
||||
:meth:`~trio.abc.AsyncResource.aclose` method is cancelled, then they will
|
||||
still close the resource (albeit in a potentially ungraceful
|
||||
fashion). :func:`aclose_forcefully` is a convenience function that
|
||||
exploits this behavior to let you force a resource to be closed without
|
||||
blocking: it works by calling ``await resource.aclose()`` and then
|
||||
cancelling it immediately.
|
||||
|
||||
Most users won't need this, but it may be useful on cleanup paths where
|
||||
you can't afford to block, or if you want to close a resource and don't
|
||||
care about handling it gracefully. For example, if
|
||||
:class:`~trio.SSLStream` encounters an error and cannot perform its
|
||||
own graceful close, then there's no point in waiting to gracefully shut
|
||||
down the underlying transport either, so it calls ``await
|
||||
aclose_forcefully(self.transport_stream)``.
|
||||
|
||||
Note that this function is async, and that it acts as a checkpoint, but
|
||||
unlike most async functions it cannot block indefinitely (at least,
|
||||
assuming the underlying resource object is correctly implemented).
|
||||
|
||||
"""
|
||||
with trio.CancelScope() as cs:
|
||||
cs.cancel()
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]:
|
||||
"""Check if the stream has a send_eof() method."""
|
||||
return hasattr(stream, "send_eof")
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class StapledStream(
|
||||
HalfCloseableStream,
|
||||
Generic[SendStreamT, ReceiveStreamT],
|
||||
):
|
||||
"""This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
|
||||
together two unidirectional streams to make single bidirectional stream.
|
||||
|
||||
Args:
|
||||
send_stream (~trio.abc.SendStream): The stream to use for sending.
|
||||
receive_stream (~trio.abc.ReceiveStream): The stream to use for
|
||||
receiving.
|
||||
|
||||
Example:
|
||||
|
||||
A silly way to make a stream that echoes back whatever you write to
|
||||
it::
|
||||
|
||||
left, right = trio.testing.memory_stream_pair()
|
||||
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
|
||||
await echo_stream.send_all(b"x")
|
||||
assert await echo_stream.receive_some() == b"x"
|
||||
|
||||
:class:`StapledStream` objects implement the methods in the
|
||||
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
|
||||
additional public attributes:
|
||||
|
||||
.. attribute:: send_stream
|
||||
|
||||
The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and
|
||||
:meth:`wait_send_all_might_not_block` are delegated to this object.
|
||||
|
||||
.. attribute:: receive_stream
|
||||
|
||||
The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some`
|
||||
is delegated to this object.
|
||||
|
||||
"""
|
||||
|
||||
send_stream: SendStreamT
|
||||
receive_stream: ReceiveStreamT
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Calls ``self.send_stream.send_all``."""
|
||||
return await self.send_stream.send_all(data)
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
"""Calls ``self.send_stream.wait_send_all_might_not_block``."""
|
||||
return await self.send_stream.wait_send_all_might_not_block()
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
"""Shuts down the send side of the stream.
|
||||
|
||||
If :meth:`self.send_stream.send_eof() <trio.abc.HalfCloseableStream.send_eof>` exists,
|
||||
then this calls it. Otherwise, this calls
|
||||
:meth:`self.send_stream.aclose() <trio.abc.AsyncResource.aclose>`.
|
||||
"""
|
||||
stream = self.send_stream
|
||||
if _is_halfclosable(stream):
|
||||
return await stream.send_eof()
|
||||
else:
|
||||
return await stream.aclose()
|
||||
|
||||
# we intentionally accept more types from the caller than we support returning
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
"""Calls ``self.receive_stream.receive_some``."""
|
||||
return await self.receive_stream.receive_some(max_bytes)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Calls ``aclose`` on both underlying streams."""
|
||||
try:
|
||||
await self.send_stream.aclose()
|
||||
finally:
|
||||
await self.receive_stream.aclose()
|
||||
@@ -0,0 +1,251 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
from trio import TaskStatus
|
||||
|
||||
from . import socket as tsocket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
|
||||
# Default backlog size:
|
||||
#
|
||||
# Having the backlog too low can cause practical problems (a perfectly healthy
|
||||
# service that starts failing to accept connections if they arrive in a
|
||||
# burst).
|
||||
#
|
||||
# Having it too high doesn't really cause any problems. Like any buffer, you
|
||||
# want backlog queue to be zero usually, and it won't save you if you're
|
||||
# getting connection attempts faster than you can call accept() on an ongoing
|
||||
# basis. But unlike other buffers, this one doesn't really provide any
|
||||
# backpressure. If a connection gets stuck waiting in the backlog queue, then
|
||||
# from the peer's point of view the connection succeeded but then their
|
||||
# send/recv will stall until we get to it, possibly for a long time. OTOH if
|
||||
# there isn't room in the backlog queue, then their connect stalls, possibly
|
||||
# for a long time, which is pretty much the same thing.
|
||||
#
|
||||
# A large backlog can also use a bit more kernel memory, but this seems fairly
|
||||
# negligible these days.
|
||||
#
|
||||
# So this suggests we should make the backlog as large as possible. This also
|
||||
# matches what Golang does. However, they do it in a weird way, where they
|
||||
# have a bunch of code to sniff out the configured upper limit for backlog on
|
||||
# different operating systems. But on every system, passing in a too-large
|
||||
# backlog just causes it to be silently truncated to the configured maximum,
|
||||
# so this is unnecessary -- we can just pass in "infinity" and get the maximum
|
||||
# that way. (Verified on Windows, Linux, macOS using
|
||||
# notes-to-self/measure-listen-backlog.py)
|
||||
def _compute_backlog(backlog: int | None) -> int:
|
||||
# Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are
|
||||
# missing overflow protection, so we apply our own overflow protection.
|
||||
# https://github.com/golang/go/issues/5030
|
||||
if not isinstance(backlog, int) and backlog is not None:
|
||||
raise TypeError(f"backlog must be an int or None, not {backlog!r}")
|
||||
if backlog is None:
|
||||
return 0xFFFF
|
||||
return min(backlog, 0xFFFF)
|
||||
|
||||
|
||||
async def open_tcp_listeners(
|
||||
port: int,
|
||||
*,
|
||||
host: str | bytes | None = None,
|
||||
backlog: int | None = None,
|
||||
) -> list[trio.SocketListener]:
|
||||
"""Create :class:`SocketListener` objects to listen for TCP connections.
|
||||
|
||||
Args:
|
||||
|
||||
port (int): The port to listen on.
|
||||
|
||||
If you use 0 as your port, then the kernel will automatically pick
|
||||
an arbitrary open port. But be careful: if you use this feature when
|
||||
binding to multiple IP addresses, then each IP address will get its
|
||||
own random port, and the returned listeners will probably be
|
||||
listening on different ports. In particular, this will happen if you
|
||||
use ``host=None`` – which is the default – because in this case
|
||||
:func:`open_tcp_listeners` will bind to both the IPv4 wildcard
|
||||
address (``0.0.0.0``) and also the IPv6 wildcard address (``::``).
|
||||
|
||||
host (str, bytes, or None): The local interface to bind to. This is
|
||||
passed to :func:`~socket.getaddrinfo` with the ``AI_PASSIVE`` flag
|
||||
set.
|
||||
|
||||
If you want to bind to the wildcard address on both IPv4 and IPv6,
|
||||
in order to accept connections on all available interfaces, then
|
||||
pass ``None``. This is the default.
|
||||
|
||||
If you have a specific interface you want to bind to, pass its IP
|
||||
address or hostname here. If a hostname resolves to multiple IP
|
||||
addresses, this function will open one listener on each of them.
|
||||
|
||||
If you want to use only IPv4, or only IPv6, but want to accept on
|
||||
all interfaces, pass the family-specific wildcard address:
|
||||
``"0.0.0.0"`` for IPv4-only and ``"::"`` for IPv6-only.
|
||||
|
||||
backlog (int or None): The listen backlog to use. If you leave this as
|
||||
``None`` then Trio will pick a good default. (Currently: whatever
|
||||
your system has configured as the maximum backlog.)
|
||||
|
||||
Returns:
|
||||
list of :class:`SocketListener`
|
||||
|
||||
Raises:
|
||||
:class:`TypeError` if invalid arguments.
|
||||
|
||||
"""
|
||||
# getaddrinfo sometimes allows port=None, sometimes not (depending on
|
||||
# whether host=None). And on some systems it treats "" as 0, others it
|
||||
# doesn't:
|
||||
# http://klickverbot.at/blog/2012/01/getaddrinfo-edge-case-behavior-on-windows-linux-and-osx/
|
||||
if not isinstance(port, int):
|
||||
raise TypeError(f"port must be an int not {port!r}")
|
||||
|
||||
computed_backlog = _compute_backlog(backlog)
|
||||
|
||||
addresses = await tsocket.getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
type=tsocket.SOCK_STREAM,
|
||||
flags=tsocket.AI_PASSIVE,
|
||||
)
|
||||
|
||||
listeners = []
|
||||
unsupported_address_families = []
|
||||
try:
|
||||
for family, type_, proto, _, sockaddr in addresses:
|
||||
try:
|
||||
sock = tsocket.socket(family, type_, proto)
|
||||
except OSError as ex:
|
||||
if ex.errno == errno.EAFNOSUPPORT:
|
||||
# If a system only supports IPv4, or only IPv6, it
|
||||
# is still likely that getaddrinfo will return
|
||||
# both an IPv4 and an IPv6 address. As long as at
|
||||
# least one of the returned addresses can be
|
||||
# turned into a socket, we won't complain about a
|
||||
# failure to create the other.
|
||||
unsupported_address_families.append(ex)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
try:
|
||||
# See https://github.com/python-trio/trio/issues/39
|
||||
if sys.platform != "win32":
|
||||
sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1)
|
||||
|
||||
if family == tsocket.AF_INET6:
|
||||
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1)
|
||||
|
||||
await sock.bind(sockaddr)
|
||||
sock.listen(computed_backlog)
|
||||
|
||||
listeners.append(trio.SocketListener(sock))
|
||||
except:
|
||||
sock.close()
|
||||
raise
|
||||
except:
|
||||
for listener in listeners:
|
||||
listener.socket.close()
|
||||
raise
|
||||
|
||||
if unsupported_address_families and not listeners:
|
||||
msg = (
|
||||
"This system doesn't support any of the kinds of "
|
||||
"socket that that address could use"
|
||||
)
|
||||
raise OSError(errno.EAFNOSUPPORT, msg) from ExceptionGroup(
|
||||
msg,
|
||||
unsupported_address_families,
|
||||
)
|
||||
|
||||
return listeners
|
||||
|
||||
|
||||
async def serve_tcp(
|
||||
handler: Callable[[trio.SocketStream], Awaitable[object]],
|
||||
port: int,
|
||||
*,
|
||||
host: str | bytes | None = None,
|
||||
backlog: int | None = None,
|
||||
handler_nursery: trio.Nursery | None = None,
|
||||
task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED,
|
||||
) -> None:
|
||||
"""Listen for incoming TCP connections, and for each one start a task
|
||||
running ``handler(stream)``.
|
||||
|
||||
This is a thin convenience wrapper around :func:`open_tcp_listeners` and
|
||||
:func:`serve_listeners` – see them for full details.
|
||||
|
||||
.. warning::
|
||||
|
||||
If ``handler`` raises an exception, then this function doesn't do
|
||||
anything special to catch it – so by default the exception will
|
||||
propagate out and crash your server. If you don't want this, then catch
|
||||
exceptions inside your ``handler``, or use a ``handler_nursery`` object
|
||||
that responds to exceptions in some other way.
|
||||
|
||||
When used with ``nursery.start`` you get back the newly opened listeners.
|
||||
So, for example, if you want to start a server in your test suite and then
|
||||
connect to it to check that it's working properly, you can use something
|
||||
like::
|
||||
|
||||
from trio import SocketListener, SocketStream
|
||||
from trio.testing import open_stream_to_socket_listener
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
|
||||
client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0])
|
||||
|
||||
# Then send and receive data on 'client_stream', for example:
|
||||
await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n")
|
||||
|
||||
This avoids several common pitfalls:
|
||||
|
||||
1. It lets the kernel pick a random open port, so your test suite doesn't
|
||||
depend on any particular port being open.
|
||||
|
||||
2. It waits for the server to be accepting connections on that port before
|
||||
``start`` returns, so there's no race condition where the incoming
|
||||
connection arrives before the server is ready.
|
||||
|
||||
3. It uses the Listener object to find out which port was picked, so it
|
||||
can connect to the right place.
|
||||
|
||||
Args:
|
||||
handler: The handler to start for each incoming connection. Passed to
|
||||
:func:`serve_listeners`.
|
||||
|
||||
port: The port to listen on. Use 0 to let the kernel pick an open port.
|
||||
Passed to :func:`open_tcp_listeners`.
|
||||
|
||||
host (str, bytes, or None): The host interface to listen on; use
|
||||
``None`` to bind to the wildcard address. Passed to
|
||||
:func:`open_tcp_listeners`.
|
||||
|
||||
backlog: The listen backlog, or None to have a good default picked.
|
||||
Passed to :func:`open_tcp_listeners`.
|
||||
|
||||
handler_nursery: The nursery to start handlers in, or None to use an
|
||||
internal nursery. Passed to :func:`serve_listeners`.
|
||||
|
||||
task_status: This function can be used with ``nursery.start``.
|
||||
|
||||
Returns:
|
||||
This function only returns when cancelled.
|
||||
|
||||
"""
|
||||
listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
|
||||
await trio.serve_listeners(
|
||||
handler,
|
||||
listeners,
|
||||
handler_nursery=handler_nursery,
|
||||
task_status=task_status,
|
||||
)
|
||||
@@ -0,0 +1,411 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import trio
|
||||
from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from socket import AddressFamily, SocketKind
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
|
||||
|
||||
|
||||
# Implementation of RFC 6555 "Happy eyeballs"
|
||||
# https://tools.ietf.org/html/rfc6555
|
||||
#
|
||||
# Basically, the problem here is that if we want to connect to some host, and
|
||||
# DNS returns multiple IP addresses, then we don't know which of them will
|
||||
# actually work -- it can happen that some of them are reachable, and some of
|
||||
# them are not. One particularly common situation where this happens is on a
|
||||
# host that thinks it has ipv6 connectivity, but really doesn't. But in
|
||||
# principle this could happen for any kind of multi-home situation (e.g. the
|
||||
# route to one mirror is down but another is up).
|
||||
#
|
||||
# The naive algorithm (e.g. the stdlib's socket.create_connection) would be to
|
||||
# pick one of the IP addresses and try to connect; if that fails, try the
|
||||
# next; etc. The problem with this is that TCP is stubborn, and if the first
|
||||
# address is a blackhole then it might take a very long time (tens of seconds)
|
||||
# before that connection attempt fails.
|
||||
#
|
||||
# That's where RFC 6555 comes in. It tells us that what we do is:
|
||||
# - get the list of IPs from getaddrinfo, trusting the order it gives us (with
|
||||
# one exception noted in section 5.4)
|
||||
# - start a connection attempt to the first IP
|
||||
# - when this fails OR if it's still going after DELAY seconds, then start a
|
||||
# connection attempt to the second IP
|
||||
# - when this fails OR if it's still going after another DELAY seconds, then
|
||||
# start a connection attempt to the third IP
|
||||
# - ... repeat until we run out of IPs.
|
||||
#
|
||||
# Our implementation is similarly straightforward: we spawn a chain of tasks,
|
||||
# where each one (a) waits until the previous connection has failed or DELAY
|
||||
# seconds have passed, (b) spawns the next task, (c) attempts to connect. As
|
||||
# soon as any task crashes or succeeds, we cancel all the tasks and return.
|
||||
#
|
||||
# Note: this currently doesn't attempt to cache any results, so if you make
|
||||
# multiple connections to the same host it'll re-run the happy-eyeballs
|
||||
# algorithm each time. RFC 6555 is pretty confusing about whether this is
|
||||
# allowed. Section 4 describes an algorithm that attempts ipv4 and ipv6
|
||||
# simultaneously, and then says "The client MUST cache information regarding
|
||||
# the outcome of each connection attempt, and it uses that information to
|
||||
# avoid thrashing the network with subsequent attempts." Then section 4.2 says
|
||||
# "implementations MUST prefer the first IP address family returned by the
|
||||
# host's address preference policy, unless implementing a stateful
|
||||
# algorithm". Here "stateful" means "one that caches information about
|
||||
# previous attempts". So my reading of this is that IF you're starting ipv4
|
||||
# and ipv6 at the same time then you MUST cache the result for ~ten minutes,
|
||||
# but IF you're "preferring" one protocol by trying it first (like we are),
|
||||
# then you don't need to cache.
|
||||
#
|
||||
# Caching is quite tricky: to get it right you need to do things like detect
|
||||
# when the network interfaces are reconfigured, and if you get it wrong then
|
||||
# connection attempts basically just don't work. So we don't even try.
|
||||
|
||||
# "Firefox and Chrome use 300 ms"
|
||||
# https://tools.ietf.org/html/rfc6555#section-6
|
||||
# Though
|
||||
# https://www.researchgate.net/profile/Vaibhav_Bajpai3/publication/304568993_Measuring_the_Effects_of_Happy_Eyeballs/links/5773848e08ae6f328f6c284c/Measuring-the-Effects-of-Happy-Eyeballs.pdf
|
||||
# claims that Firefox actually uses 0 ms, unless an about:config option is
|
||||
# toggled and then it uses 250 ms.
|
||||
DEFAULT_DELAY = 0.250
|
||||
|
||||
# How should we call getaddrinfo? In particular, should we use AI_ADDRCONFIG?
|
||||
#
|
||||
# The idea of AI_ADDRCONFIG is that it only returns addresses that might
|
||||
# work. E.g., if getaddrinfo knows that you don't have any IPv6 connectivity,
|
||||
# then it doesn't return any IPv6 addresses. And this is kinda nice, because
|
||||
# it means maybe you can skip sending AAAA requests entirely. But in practice,
|
||||
# it doesn't really work right.
|
||||
#
|
||||
# - on Linux/glibc, empirically, the default is to return all addresses, and
|
||||
# with AI_ADDRCONFIG then it only returns IPv6 addresses if there is at least
|
||||
# one non-loopback IPv6 address configured... but this can be a link-local
|
||||
# address, so in practice I guess this is basically always configured if IPv6
|
||||
# is enabled at all. OTOH if you pass in "::1" as the target address with
|
||||
# AI_ADDRCONFIG and there's no *external* IPv6 address configured, you get an
|
||||
# error. So AI_ADDRCONFIG mostly doesn't do anything, even when you would want
|
||||
# it to, and when it does do something it might break things that would have
|
||||
# worked.
|
||||
#
|
||||
# - on Windows 10, empirically, if no IPv6 address is configured then by
|
||||
# default they are also suppressed from getaddrinfo (flags=0 and
|
||||
# flags=AI_ADDRCONFIG seem to do the same thing). If you pass AI_ALL, then you
|
||||
# get the full list.
|
||||
# ...except for localhost! getaddrinfo("localhost", "80") gives me ::1, even
|
||||
# though there's no ipv6 and other queries only return ipv4.
|
||||
# If you pass in and IPv6 IP address as the target address, then that's always
|
||||
# returned OK, even with AI_ADDRCONFIG set and no IPv6 configured.
|
||||
#
|
||||
# But I guess other versions of windows messed this up, judging from these bug
|
||||
# reports:
|
||||
# https://bugs.chromium.org/p/chromium/issues/detail?id=5234
|
||||
# https://bugs.chromium.org/p/chromium/issues/detail?id=32522#c50
|
||||
#
|
||||
# So basically the options are either to use AI_ADDRCONFIG and then add some
|
||||
# complicated special cases to work around its brokenness, or else don't use
|
||||
# AI_ADDRCONFIG and accept that sometimes on legacy/misconfigured networks
|
||||
# we'll waste 300 ms trying to connect to a blackholed destination.
|
||||
#
|
||||
# Twisted and Tornado always uses default flags. I think we'll do the same.
|
||||
|
||||
|
||||
@contextmanager
|
||||
def close_all() -> Generator[set[SocketType], None, None]:
|
||||
sockets_to_close: set[SocketType] = set()
|
||||
try:
|
||||
yield sockets_to_close
|
||||
finally:
|
||||
errs = []
|
||||
for sock in sockets_to_close:
|
||||
try:
|
||||
sock.close()
|
||||
except BaseException as exc:
|
||||
errs.append(exc)
|
||||
if len(errs) == 1:
|
||||
raise errs[0]
|
||||
elif errs:
|
||||
raise BaseExceptionGroup("", errs)
|
||||
|
||||
|
||||
def reorder_for_rfc_6555_section_5_4(
|
||||
targets: list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
Any,
|
||||
]
|
||||
],
|
||||
) -> None:
|
||||
# RFC 6555 section 5.4 says that if getaddrinfo returns multiple address
|
||||
# families (e.g. IPv4 and IPv6), then you should make sure that your first
|
||||
# and second attempts use different families:
|
||||
#
|
||||
# https://tools.ietf.org/html/rfc6555#section-5.4
|
||||
#
|
||||
# This function post-processes the results from getaddrinfo, in-place, to
|
||||
# satisfy this requirement.
|
||||
for i in range(1, len(targets)):
|
||||
if targets[i][0] != targets[0][0]:
|
||||
# Found the first entry with a different address family; move it
|
||||
# so that it becomes the second item on the list.
|
||||
if i != 1:
|
||||
targets.insert(1, targets.pop(i))
|
||||
break
|
||||
|
||||
|
||||
def format_host_port(host: str | bytes, port: int | str) -> str:
|
||||
host = host.decode("ascii") if isinstance(host, bytes) else host
|
||||
if ":" in host:
|
||||
return f"[{host}]:{port}"
|
||||
else:
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
# Twisted's HostnameEndpoint has a good set of configurables:
|
||||
# https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.HostnameEndpoint.html
|
||||
#
|
||||
# - per-connection timeout
|
||||
# this doesn't seem useful -- we let you set a timeout on the whole thing
|
||||
# using Trio's normal mechanisms, and that seems like enough
|
||||
# - delay between attempts
|
||||
# - bind address (but not port!)
|
||||
# they *don't* support multiple address bindings, like giving the ipv4 and
|
||||
# ipv6 addresses of the host.
|
||||
# I think maybe our semantics should be: we accept a list of bind addresses,
|
||||
# and we bind to the first one that is compatible with the
|
||||
# connection attempt we want to make, and if none are compatible then we
|
||||
# don't try to connect to that target.
|
||||
#
|
||||
# XX TODO: implement bind address support
|
||||
#
|
||||
# Actually, the best option is probably to be explicit: {AF_INET: "...",
|
||||
# AF_INET6: "..."}
|
||||
# this might be simpler after
|
||||
async def open_tcp_stream(
|
||||
host: str | bytes,
|
||||
port: int,
|
||||
*,
|
||||
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
|
||||
local_address: str | None = None,
|
||||
) -> trio.SocketStream:
|
||||
"""Connect to the given host and port over TCP.
|
||||
|
||||
If the given ``host`` has multiple IP addresses associated with it, then
|
||||
we have a problem: which one do we use?
|
||||
|
||||
One approach would be to attempt to connect to the first one, and then if
|
||||
that fails, attempt to connect to the second one ... until we've tried all
|
||||
of them. But the problem with this is that if the first IP address is
|
||||
unreachable (for example, because it's an IPv6 address and our network
|
||||
discards IPv6 packets), then we might end up waiting tens of seconds for
|
||||
the first connection attempt to timeout before we try the second address.
|
||||
|
||||
Another approach would be to attempt to connect to all of the addresses at
|
||||
the same time, in parallel, and then use whichever connection succeeds
|
||||
first, abandoning the others. This would be fast, but create a lot of
|
||||
unnecessary load on the network and the remote server.
|
||||
|
||||
This function strikes a balance between these two extremes: it works its
|
||||
way through the available addresses one at a time, like the first
|
||||
approach; but, if ``happy_eyeballs_delay`` seconds have passed and it's
|
||||
still waiting for an attempt to succeed or fail, then it gets impatient
|
||||
and starts the next connection attempt in parallel. As soon as any one
|
||||
connection attempt succeeds, all the other attempts are cancelled. This
|
||||
avoids unnecessary load because most connections will succeed after just
|
||||
one or two attempts, but if one of the addresses is unreachable then it
|
||||
doesn't slow us down too much.
|
||||
|
||||
This is known as a "happy eyeballs" algorithm, and our particular variant
|
||||
is modelled after how Chrome connects to webservers; see `RFC 6555
|
||||
<https://tools.ietf.org/html/rfc6555>`__ for more details.
|
||||
|
||||
Args:
|
||||
host (str or bytes): The host to connect to. Can be an IPv4 address,
|
||||
IPv6 address, or a hostname.
|
||||
|
||||
port (int): The port to connect to.
|
||||
|
||||
happy_eyeballs_delay (float or None): How many seconds to wait for each
|
||||
connection attempt to succeed or fail before getting impatient and
|
||||
starting another one in parallel. Set to `None` if you want
|
||||
to limit to only one connection attempt at a time (like
|
||||
:func:`socket.create_connection`). Default: 0.25 (250 ms).
|
||||
|
||||
local_address (None or str): The local IP address or hostname to use as
|
||||
the source for outgoing connections. If ``None``, we let the OS pick
|
||||
the source IP.
|
||||
|
||||
This is useful in some exotic networking configurations where your
|
||||
host has multiple IP addresses, and you want to force the use of a
|
||||
specific one.
|
||||
|
||||
Note that if you pass an IPv4 ``local_address``, then you won't be
|
||||
able to connect to IPv6 hosts, and vice-versa. If you want to take
|
||||
advantage of this to force the use of IPv4 or IPv6 without
|
||||
specifying an exact source address, you can use the IPv4 wildcard
|
||||
address ``local_address="0.0.0.0"``, or the IPv6 wildcard address
|
||||
``local_address="::"``.
|
||||
|
||||
Returns:
|
||||
SocketStream: a :class:`~trio.abc.Stream` connected to the given server.
|
||||
|
||||
Raises:
|
||||
OSError: if the connection fails.
|
||||
|
||||
See also:
|
||||
open_ssl_over_tcp_stream
|
||||
|
||||
"""
|
||||
|
||||
# To keep our public API surface smaller, rule out some cases that
|
||||
# getaddrinfo will accept in some circumstances, but that act weird or
|
||||
# have non-portable behavior or are just plain not useful.
|
||||
if not isinstance(host, (str, bytes)):
|
||||
raise ValueError(f"host must be str or bytes, not {host!r}")
|
||||
if not isinstance(port, int):
|
||||
raise TypeError(f"port must be int, not {port!r}")
|
||||
|
||||
if happy_eyeballs_delay is None:
|
||||
happy_eyeballs_delay = DEFAULT_DELAY
|
||||
|
||||
targets = await getaddrinfo(host, port, type=SOCK_STREAM)
|
||||
|
||||
# I don't think this can actually happen -- if there are no results,
|
||||
# getaddrinfo should have raised OSError instead of returning an empty
|
||||
# list. But let's be paranoid and handle it anyway:
|
||||
if not targets:
|
||||
msg = f"no results found for hostname lookup: {format_host_port(host, port)}"
|
||||
raise OSError(msg)
|
||||
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
|
||||
# This list records all the connection failures that we ignored.
|
||||
oserrors: list[OSError] = []
|
||||
|
||||
# Keeps track of the socket that we're going to complete with,
|
||||
# need to make sure this isn't automatically closed
|
||||
winning_socket: SocketType | None = None
|
||||
|
||||
# Try connecting to the specified address. Possible outcomes:
|
||||
# - success: record connected socket in winning_socket and cancel
|
||||
# concurrent attempts
|
||||
# - failure: record exception in oserrors, set attempt_failed allowing
|
||||
# the next connection attempt to start early
|
||||
# code needs to ensure sockets can be closed appropriately in the
|
||||
# face of crash or cancellation
|
||||
async def attempt_connect(
|
||||
socket_args: tuple[AddressFamily, SocketKind, int],
|
||||
sockaddr: Any,
|
||||
attempt_failed: trio.Event,
|
||||
) -> None:
|
||||
nonlocal winning_socket
|
||||
|
||||
try:
|
||||
sock = socket(*socket_args)
|
||||
open_sockets.add(sock)
|
||||
|
||||
if local_address is not None:
|
||||
# TCP connections are identified by a 4-tuple:
|
||||
#
|
||||
# (local IP, local port, remote IP, remote port)
|
||||
#
|
||||
# So if a single local IP wants to make multiple connections
|
||||
# to the same (remote IP, remote port) pair, then those
|
||||
# connections have to use different local ports, or else TCP
|
||||
# won't be able to tell them apart. OTOH, if you have multiple
|
||||
# connections to different remote IP/ports, then those
|
||||
# connections can share a local port.
|
||||
#
|
||||
# Normally, when you call bind(), the kernel will immediately
|
||||
# assign a specific local port to your socket. At this point
|
||||
# the kernel doesn't know which (remote IP, remote port)
|
||||
# you're going to use, so it has to pick a local port that
|
||||
# *no* other connection is using. That's the only way to
|
||||
# guarantee that this local port will be usable later when we
|
||||
# call connect(). (Alternatively, you can set SO_REUSEADDR to
|
||||
# allow multiple nascent connections to share the same port,
|
||||
# but then connect() might fail with EADDRNOTAVAIL if we get
|
||||
# unlucky and our TCP 4-tuple ends up colliding with another
|
||||
# unrelated connection.)
|
||||
#
|
||||
# So calling bind() before connect() works, but it disables
|
||||
# sharing of local ports. This is inefficient: it makes you
|
||||
# more likely to run out of local ports.
|
||||
#
|
||||
# But on some versions of Linux, we can re-enable sharing of
|
||||
# local ports by setting a special flag. This flag tells
|
||||
# bind() to only bind the IP, and not the port. That way,
|
||||
# connect() is allowed to pick the the port, and it can do a
|
||||
# better job of it because it knows the remote IP/port.
|
||||
with suppress(OSError, AttributeError):
|
||||
sock.setsockopt(
|
||||
trio.socket.IPPROTO_IP,
|
||||
trio.socket.IP_BIND_ADDRESS_NO_PORT,
|
||||
1,
|
||||
)
|
||||
try:
|
||||
await sock.bind((local_address, 0))
|
||||
except OSError:
|
||||
raise OSError(
|
||||
f"local_address={local_address!r} is incompatible "
|
||||
f"with remote address {sockaddr!r}",
|
||||
) from None
|
||||
|
||||
await sock.connect(sockaddr)
|
||||
|
||||
# Success! Save the winning socket and cancel all outstanding
|
||||
# connection attempts.
|
||||
winning_socket = sock
|
||||
nursery.cancel_scope.cancel()
|
||||
except OSError as exc:
|
||||
# This connection attempt failed, but the next one might
|
||||
# succeed. Save the error for later so we can report it if
|
||||
# everything fails, and tell the next attempt that it should go
|
||||
# ahead (if it hasn't already).
|
||||
oserrors.append(exc)
|
||||
attempt_failed.set()
|
||||
|
||||
with close_all() as open_sockets:
|
||||
# nursery spawns a task for each connection attempt, will be
|
||||
# cancelled by the task that gets a successful connection
|
||||
async with trio.open_nursery() as nursery:
|
||||
for address_family, socket_type, proto, _, addr in targets:
|
||||
# create an event to indicate connection failure,
|
||||
# allowing the next target to be tried early
|
||||
attempt_failed = trio.Event()
|
||||
|
||||
# workaround to check types until typing of nursery.start_soon improved
|
||||
if TYPE_CHECKING:
|
||||
await attempt_connect(
|
||||
(address_family, socket_type, proto),
|
||||
addr,
|
||||
attempt_failed,
|
||||
)
|
||||
|
||||
nursery.start_soon(
|
||||
attempt_connect,
|
||||
(address_family, socket_type, proto),
|
||||
addr,
|
||||
attempt_failed,
|
||||
)
|
||||
|
||||
# give this attempt at most this time before moving on
|
||||
with trio.move_on_after(happy_eyeballs_delay):
|
||||
await attempt_failed.wait()
|
||||
|
||||
# nothing succeeded
|
||||
if winning_socket is None:
|
||||
assert len(oserrors) == len(targets)
|
||||
msg = f"all attempts to connect to {format_host_port(host, port)} failed"
|
||||
raise OSError(msg) from ExceptionGroup(msg, oserrors)
|
||||
else:
|
||||
stream = trio.SocketStream(winning_socket)
|
||||
open_sockets.remove(winning_socket)
|
||||
return stream
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Protocol, TypeVar
|
||||
|
||||
import trio
|
||||
from trio.socket import SOCK_STREAM, socket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
class Closable(Protocol):
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
CloseT = TypeVar("CloseT", bound=Closable)
|
||||
|
||||
|
||||
try:
|
||||
from trio.socket import AF_UNIX
|
||||
|
||||
has_unix = True
|
||||
except ImportError:
|
||||
has_unix = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def close_on_error(obj: CloseT) -> Generator[CloseT, None, None]:
|
||||
try:
|
||||
yield obj
|
||||
except:
|
||||
obj.close()
|
||||
raise
|
||||
|
||||
|
||||
async def open_unix_socket(
|
||||
filename: str | bytes | os.PathLike[str] | os.PathLike[bytes],
|
||||
) -> trio.SocketStream:
|
||||
"""Opens a connection to the specified
|
||||
`Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.
|
||||
|
||||
You must have read/write permission on the specified file to connect.
|
||||
|
||||
Args:
|
||||
filename (str or bytes): The filename to open the connection to.
|
||||
|
||||
Returns:
|
||||
SocketStream: a :class:`~trio.abc.Stream` connected to the given file.
|
||||
|
||||
Raises:
|
||||
OSError: If the socket file could not be connected to.
|
||||
RuntimeError: If AF_UNIX sockets are not supported.
|
||||
"""
|
||||
if not has_unix:
|
||||
raise RuntimeError("Unix sockets are not supported on this platform")
|
||||
|
||||
# much more simplified logic vs tcp sockets - one socket type and only one
|
||||
# possible location to connect to
|
||||
sock = socket(AF_UNIX, SOCK_STREAM)
|
||||
with close_on_error(sock):
|
||||
await sock.connect(os.fspath(filename))
|
||||
|
||||
return trio.SocketStream(sock)
|
||||
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, NoReturn, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
# Errors that accept(2) can return, and which indicate that the system is
|
||||
# overloaded
|
||||
ACCEPT_CAPACITY_ERRNOS = {
|
||||
errno.EMFILE,
|
||||
errno.ENFILE,
|
||||
errno.ENOMEM,
|
||||
errno.ENOBUFS,
|
||||
}
|
||||
|
||||
# How long to sleep when we get one of those errors
|
||||
SLEEP_TIME = 0.100
|
||||
|
||||
# The logger we use to complain when this happens
|
||||
LOGGER = logging.getLogger("trio.serve_listeners")
|
||||
|
||||
|
||||
StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource)
|
||||
ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any])
|
||||
Handler = Callable[[StreamT], Awaitable[object]]
|
||||
|
||||
|
||||
async def _run_handler(stream: StreamT, handler: Handler[StreamT]) -> None:
|
||||
try:
|
||||
await handler(stream)
|
||||
finally:
|
||||
await trio.aclose_forcefully(stream)
|
||||
|
||||
|
||||
async def _serve_one_listener(
|
||||
listener: trio.abc.Listener[StreamT],
|
||||
handler_nursery: trio.Nursery,
|
||||
handler: Handler[StreamT],
|
||||
) -> NoReturn:
|
||||
async with listener:
|
||||
while True:
|
||||
try:
|
||||
stream = await listener.accept()
|
||||
except OSError as exc:
|
||||
if exc.errno in ACCEPT_CAPACITY_ERRNOS:
|
||||
LOGGER.error(
|
||||
"accept returned %s (%s); retrying in %s seconds",
|
||||
errno.errorcode[exc.errno],
|
||||
os.strerror(exc.errno),
|
||||
SLEEP_TIME,
|
||||
exc_info=True,
|
||||
)
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
handler_nursery.start_soon(_run_handler, stream, handler)
|
||||
|
||||
|
||||
# This cannot be typed correctly, we need generic typevar bounds / HKT to indicate the
|
||||
# relationship between StreamT & ListenerT.
|
||||
# https://github.com/python/typing/issues/1226
|
||||
# https://github.com/python/typing/issues/548
|
||||
|
||||
|
||||
async def serve_listeners(
|
||||
handler: Handler[StreamT],
|
||||
listeners: list[ListenerT],
|
||||
*,
|
||||
handler_nursery: trio.Nursery | None = None,
|
||||
task_status: trio.TaskStatus[list[ListenerT]] = trio.TASK_STATUS_IGNORED,
|
||||
) -> NoReturn:
|
||||
r"""Listen for incoming connections on ``listeners``, and for each one
|
||||
start a task running ``handler(stream)``.
|
||||
|
||||
.. warning::
|
||||
|
||||
If ``handler`` raises an exception, then this function doesn't do
|
||||
anything special to catch it – so by default the exception will
|
||||
propagate out and crash your server. If you don't want this, then catch
|
||||
exceptions inside your ``handler``, or use a ``handler_nursery`` object
|
||||
that responds to exceptions in some other way.
|
||||
|
||||
Args:
|
||||
|
||||
handler: An async callable, that will be invoked like
|
||||
``handler_nursery.start_soon(handler, stream)`` for each incoming
|
||||
connection.
|
||||
|
||||
listeners: A list of :class:`~trio.abc.Listener` objects.
|
||||
:func:`serve_listeners` takes responsibility for closing them.
|
||||
|
||||
handler_nursery: The nursery used to start handlers, or any object with
|
||||
a ``start_soon`` method. If ``None`` (the default), then
|
||||
:func:`serve_listeners` will create a new nursery internally and use
|
||||
that.
|
||||
|
||||
task_status: This function can be used with ``nursery.start``, which
|
||||
will return ``listeners``.
|
||||
|
||||
Returns:
|
||||
|
||||
This function never returns unless cancelled.
|
||||
|
||||
Resource handling:
|
||||
|
||||
If ``handler`` neglects to close the ``stream``, then it will be closed
|
||||
using :func:`trio.aclose_forcefully`.
|
||||
|
||||
Error handling:
|
||||
|
||||
Most errors coming from :meth:`~trio.abc.Listener.accept` are allowed to
|
||||
propagate out (crashing the server in the process). However, some errors –
|
||||
those which indicate that the server is temporarily overloaded – are
|
||||
handled specially. These are :class:`OSError`\s with one of the following
|
||||
errnos:
|
||||
|
||||
* ``EMFILE``: process is out of file descriptors
|
||||
* ``ENFILE``: system is out of file descriptors
|
||||
* ``ENOBUFS``, ``ENOMEM``: the kernel hit some sort of memory limitation
|
||||
when trying to create a socket object
|
||||
|
||||
When :func:`serve_listeners` gets one of these errors, then it:
|
||||
|
||||
* Logs the error to the standard library logger ``trio.serve_listeners``
|
||||
(level = ERROR, with exception information included). By default this
|
||||
causes it to be printed to stderr.
|
||||
* Waits 100 ms before calling ``accept`` again, in hopes that the
|
||||
system will recover.
|
||||
|
||||
"""
|
||||
async with trio.open_nursery() as nursery:
|
||||
if handler_nursery is None:
|
||||
handler_nursery = nursery
|
||||
for listener in listeners:
|
||||
nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler)
|
||||
# The listeners are already queueing connections when we're called,
|
||||
# but we wait until the end to call started() just in case we get an
|
||||
# error or whatever.
|
||||
task_status.started(listeners)
|
||||
|
||||
raise AssertionError(
|
||||
"_serve_one_listener should never complete",
|
||||
) # pragma: no cover
|
||||
@@ -0,0 +1,414 @@
|
||||
# "High-level" networking interface
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import TYPE_CHECKING, overload
|
||||
|
||||
import trio
|
||||
|
||||
from . import socket as tsocket
|
||||
from ._util import ConflictDetector, final
|
||||
from .abc import HalfCloseableStream, Listener
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from typing_extensions import Buffer
|
||||
|
||||
from ._socket import SocketType
|
||||
|
||||
# XX TODO: this number was picked arbitrarily. We should do experiments to
|
||||
# tune it. (Or make it dynamic -- one idea is to start small and increase it
|
||||
# if we observe single reads filling up the whole buffer, at least within some
|
||||
# limits.)
|
||||
DEFAULT_RECEIVE_SIZE = 65536
|
||||
|
||||
_closed_stream_errnos = {
|
||||
# Unix
|
||||
errno.EBADF,
|
||||
# Windows
|
||||
errno.ENOTSOCK,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except OSError as exc:
|
||||
if exc.errno in _closed_stream_errnos:
|
||||
raise trio.ClosedResourceError("this socket was already closed") from None
|
||||
else:
|
||||
raise trio.BrokenResourceError(f"socket connection broken: {exc}") from exc
|
||||
|
||||
|
||||
@final
|
||||
class SocketStream(HalfCloseableStream):
|
||||
"""An implementation of the :class:`trio.abc.HalfCloseableStream`
|
||||
interface based on a raw network socket.
|
||||
|
||||
Args:
|
||||
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
|
||||
and be connected.
|
||||
|
||||
By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``,
|
||||
and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with
|
||||
a reasonable buffer size (currently 16 KiB) – see `issue #72
|
||||
<https://github.com/python-trio/trio/issues/72>`__ for discussion. You can
|
||||
of course override these defaults by calling :meth:`setsockopt`.
|
||||
|
||||
Once a :class:`SocketStream` object is constructed, it implements the full
|
||||
:class:`trio.abc.HalfCloseableStream` interface. In addition, it provides
|
||||
a few extra features:
|
||||
|
||||
.. attribute:: socket
|
||||
|
||||
The Trio socket object that this stream wraps.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, socket: SocketType):
|
||||
if not isinstance(socket, tsocket.SocketType):
|
||||
raise TypeError("SocketStream requires a Trio socket object")
|
||||
if socket.type != tsocket.SOCK_STREAM:
|
||||
raise ValueError("SocketStream requires a SOCK_STREAM socket")
|
||||
|
||||
self.socket = socket
|
||||
self._send_conflict_detector = ConflictDetector(
|
||||
"another task is currently sending data on this SocketStream",
|
||||
)
|
||||
|
||||
# Socket defaults:
|
||||
|
||||
# Not supported on e.g. unix domain sockets
|
||||
with suppress(OSError):
|
||||
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
|
||||
|
||||
if hasattr(tsocket, "TCP_NOTSENT_LOWAT"):
|
||||
# 16 KiB is pretty arbitrary and could probably do with some
|
||||
# tuning. (Apple is also setting this by default in CFNetwork
|
||||
# apparently -- I'm curious what value they're using, though I
|
||||
# couldn't find it online trivially. CFNetwork-129.20 source
|
||||
# has no mentions of TCP_NOTSENT_LOWAT. This presentation says
|
||||
# "typically 8 kilobytes":
|
||||
# http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1
|
||||
# ). The theory is that you want it to be bandwidth *
|
||||
# rescheduling interval.
|
||||
with suppress(OSError):
|
||||
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14)
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
if self.socket.did_shutdown_SHUT_WR:
|
||||
raise trio.ClosedResourceError("can't send data after sending EOF")
|
||||
with self._send_conflict_detector:
|
||||
with _translate_socket_errors_to_stream_errors():
|
||||
with memoryview(data) as data:
|
||||
if not data:
|
||||
if self.socket.fileno() == -1:
|
||||
raise trio.ClosedResourceError("socket was already closed")
|
||||
await trio.lowlevel.checkpoint()
|
||||
return
|
||||
total_sent = 0
|
||||
while total_sent < len(data):
|
||||
with data[total_sent:] as remaining:
|
||||
sent = await self.socket.send(remaining)
|
||||
total_sent += sent
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
with self._send_conflict_detector:
|
||||
if self.socket.fileno() == -1:
|
||||
raise trio.ClosedResourceError
|
||||
with _translate_socket_errors_to_stream_errors():
|
||||
await self.socket.wait_writable()
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
with self._send_conflict_detector:
|
||||
await trio.lowlevel.checkpoint()
|
||||
# On macOS, calling shutdown a second time raises ENOTCONN, but
|
||||
# send_eof needs to be idempotent.
|
||||
if self.socket.did_shutdown_SHUT_WR:
|
||||
return
|
||||
with _translate_socket_errors_to_stream_errors():
|
||||
self.socket.shutdown(tsocket.SHUT_WR)
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
if max_bytes is None:
|
||||
max_bytes = DEFAULT_RECEIVE_SIZE
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
with _translate_socket_errors_to_stream_errors():
|
||||
return await self.socket.recv(max_bytes)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.socket.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
# __aenter__, __aexit__ inherited from HalfCloseableStream are OK
|
||||
|
||||
@overload
|
||||
def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt(self, level: int, option: int, value: None, length: int) -> None: ...
|
||||
|
||||
def setsockopt(
|
||||
self,
|
||||
level: int,
|
||||
option: int,
|
||||
value: int | Buffer | None,
|
||||
length: int | None = None,
|
||||
) -> None:
|
||||
"""Set an option on the underlying socket.
|
||||
|
||||
See :meth:`socket.socket.setsockopt` for details.
|
||||
|
||||
"""
|
||||
if length is None:
|
||||
if value is None:
|
||||
raise TypeError(
|
||||
"invalid value for argument 'value', must not be None when specifying length",
|
||||
)
|
||||
return self.socket.setsockopt(level, option, value)
|
||||
if value is not None:
|
||||
raise TypeError(
|
||||
f"invalid value for argument 'value': {value!r}, must be None when specifying optlen",
|
||||
)
|
||||
return self.socket.setsockopt(level, option, value, length)
|
||||
|
||||
@overload
|
||||
def getsockopt(self, level: int, option: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: ...
|
||||
|
||||
def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes:
|
||||
"""Check the current value of an option on the underlying socket.
|
||||
|
||||
See :meth:`socket.socket.getsockopt` for details.
|
||||
|
||||
"""
|
||||
# This is to work around
|
||||
# https://bitbucket.org/pypy/pypy/issues/2561
|
||||
# We should be able to drop it when the next PyPy3 beta is released.
|
||||
if buffersize == 0:
|
||||
return self.socket.getsockopt(level, option)
|
||||
else:
|
||||
return self.socket.getsockopt(level, option, buffersize)
|
||||
|
||||
|
||||
################################################################
|
||||
# SocketListener
|
||||
################################################################
|
||||
|
||||
# Accept error handling
|
||||
# =====================
|
||||
#
|
||||
# Literature review
|
||||
# -----------------
|
||||
#
|
||||
# Here's a list of all the possible errors that accept() can return, according
|
||||
# to the POSIX spec or the Linux, FreeBSD, macOS, and Windows docs:
|
||||
#
|
||||
# Can't happen with a Trio socket:
|
||||
# - EAGAIN/(WSA)EWOULDBLOCK
|
||||
# - EINTR
|
||||
# - WSANOTINITIALISED
|
||||
# - WSAEINPROGRESS: a blocking call is already in progress
|
||||
# - WSAEINTR: someone called WSACancelBlockingCall, but we don't make blocking
|
||||
# calls in the first place
|
||||
#
|
||||
# Something is wrong with our call:
|
||||
# - EBADF: not a file descriptor
|
||||
# - (WSA)EINVAL: socket isn't listening, or (Linux, BSD) bad flags
|
||||
# - (WSA)ENOTSOCK: not a socket
|
||||
# - (WSA)EOPNOTSUPP: this kind of socket doesn't support accept
|
||||
# - (Linux, FreeBSD, Windows) EFAULT: the sockaddr pointer points to readonly
|
||||
# memory
|
||||
#
|
||||
# Something is wrong with the environment:
|
||||
# - (WSA)EMFILE: this process hit its fd limit
|
||||
# - ENFILE: the system hit its fd limit
|
||||
# - (WSA)ENOBUFS, ENOMEM: unspecified memory problems
|
||||
#
|
||||
# Something is wrong with the connection we were going to accept. There's a
|
||||
# ton of variability between systems here:
|
||||
# - ECONNABORTED: documented everywhere, but apparently only the BSDs do this
|
||||
# (signals a connection was closed/reset before being accepted)
|
||||
# - EPROTO: unspecified protocol error
|
||||
# - (Linux) EPERM: firewall rule prevented connection
|
||||
# - (Linux) ENETDOWN, EPROTO, ENOPROTOOPT, EHOSTDOWN, ENONET, EHOSTUNREACH,
|
||||
# EOPNOTSUPP, ENETUNREACH, ENOSR, ESOCKTNOSUPPORT, EPROTONOSUPPORT,
|
||||
# ETIMEDOUT, ... or any other error that the socket could give, because
|
||||
# apparently if an error happens on a connection before it's accept()ed,
|
||||
# Linux will report that error from accept().
|
||||
# - (Windows) WSAECONNRESET, WSAENETDOWN
|
||||
#
|
||||
#
|
||||
# Code review
|
||||
# -----------
|
||||
#
|
||||
# What do other libraries do?
|
||||
#
|
||||
# Twisted on Unix or when using nonblocking I/O on Windows:
|
||||
# - ignores EPERM, with comment about Linux firewalls
|
||||
# - logs and ignores EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED
|
||||
# Comment notes that ECONNABORTED is a BSDism and that Linux returns the
|
||||
# socket before having it fail, and macOS just silently discards it.
|
||||
# - other errors are raised, which is logged + kills the socket
|
||||
# ref: src/twisted/internet/tcp.py, Port.doRead
|
||||
#
|
||||
# Twisted using IOCP on Windows:
|
||||
# - logs and ignores all errors
|
||||
# ref: src/twisted/internet/iocpreactor/tcp.py, Port.handleAccept
|
||||
#
|
||||
# Tornado:
|
||||
# - ignore ECONNABORTED (comments notes that it was observed on FreeBSD)
|
||||
# - everything else raised, but all this does (by default) is cause it to be
|
||||
# logged and then ignored
|
||||
# (ref: tornado/netutil.py, tornado/ioloop.py)
|
||||
#
|
||||
# libuv on Unix:
|
||||
# - ignores ECONNABORTED
|
||||
# - does a "trick" for EMFILE or ENFILE
|
||||
# - all other errors passed to the connection_cb to be handled
|
||||
# (ref: src/unix/stream.c:uv__server_io, uv__emfile_trick)
|
||||
#
|
||||
# libuv on Windows:
|
||||
# src/win/tcp.c:uv_tcp_queue_accept
|
||||
# this calls AcceptEx, and then arranges to call:
|
||||
# src/win/tcp.c:uv_process_tcp_accept_req
|
||||
# this gets the result from AcceptEx. If the original AcceptEx call failed,
|
||||
# then "we stop accepting connections and report this error to the
|
||||
# connection callback". I think this is for things like ENOTSOCK. If
|
||||
# AcceptEx successfully queues an overlapped operation, and then that
|
||||
# reports an error, it's just discarded.
|
||||
#
|
||||
# asyncio, selector mode:
|
||||
# - ignores EWOULDBLOCK, EINTR, ECONNABORTED
|
||||
# - on EMFILE, ENFILE, ENOBUFS, ENOMEM, logs an error and then disables the
|
||||
# listening loop for 1 second
|
||||
# - everything else raises, but then the event loop just logs and ignores it
|
||||
# (selector_events.py: BaseSelectorEventLoop._accept_connection)
|
||||
#
|
||||
#
|
||||
# What should we do?
|
||||
# ------------------
|
||||
#
|
||||
# When accept() returns an error, we can either ignore it or raise it.
|
||||
#
|
||||
# We have a long list of errors that should be ignored, and a long list of
|
||||
# errors that should be raised. The big question is what to do with an error
|
||||
# that isn't on either list. On Linux apparently you can get nearly arbitrary
|
||||
# errors from accept() and they should be ignored, because it just indicates a
|
||||
# socket that crashed before it began, and there isn't really anything to be
|
||||
# done about this, plus on other platforms you may not get any indication at
|
||||
# all, so programs have to tolerate not getting any indication too. OTOH if we
|
||||
# get an unexpected error then it could indicate something arbitrarily bad --
|
||||
# after all, it's unexpected.
|
||||
#
|
||||
# Given that we know that other libraries seem to be getting along fine with a
|
||||
# fairly minimal list of errors to ignore, I think we'll be OK if we write
|
||||
# down that list and then raise on everything else.
|
||||
#
|
||||
# The other question is what to do about the capacity problem errors: EMFILE,
|
||||
# ENFILE, ENOBUFS, ENOMEM. Just flat out ignoring these is clearly not optimal
|
||||
# -- at the very least you want to log them, and probably you want to take
|
||||
# some remedial action. And if we ignore them then it prevents higher levels
|
||||
# from doing anything clever with them. So we raise them.
|
||||
|
||||
_ignorable_accept_errno_names = [
|
||||
# Linux can do this when the a connection is denied by the firewall
|
||||
"EPERM",
|
||||
# BSDs with an early close/reset
|
||||
"ECONNABORTED",
|
||||
# All the other miscellany noted above -- may not happen in practice, but
|
||||
# whatever.
|
||||
"EPROTO",
|
||||
"ENETDOWN",
|
||||
"ENOPROTOOPT",
|
||||
"EHOSTDOWN",
|
||||
"ENONET",
|
||||
"EHOSTUNREACH",
|
||||
"EOPNOTSUPP",
|
||||
"ENETUNREACH",
|
||||
"ENOSR",
|
||||
"ESOCKTNOSUPPORT",
|
||||
"EPROTONOSUPPORT",
|
||||
"ETIMEDOUT",
|
||||
"ECONNRESET",
|
||||
]
|
||||
|
||||
# Not all errnos are defined on all platforms
|
||||
_ignorable_accept_errnos: set[int] = set()
|
||||
for name in _ignorable_accept_errno_names:
|
||||
with suppress(AttributeError):
|
||||
_ignorable_accept_errnos.add(getattr(errno, name))
|
||||
|
||||
|
||||
@final
|
||||
class SocketListener(Listener[SocketStream]):
|
||||
"""A :class:`~trio.abc.Listener` that uses a listening socket to accept
|
||||
incoming connections as :class:`SocketStream` objects.
|
||||
|
||||
Args:
|
||||
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
|
||||
and be listening.
|
||||
|
||||
Note that the :class:`SocketListener` "takes ownership" of the given
|
||||
socket; closing the :class:`SocketListener` will also close the socket.
|
||||
|
||||
.. attribute:: socket
|
||||
|
||||
The Trio socket object that this stream wraps.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, socket: SocketType):
|
||||
if not isinstance(socket, tsocket.SocketType):
|
||||
raise TypeError("SocketListener requires a Trio socket object")
|
||||
if socket.type != tsocket.SOCK_STREAM:
|
||||
raise ValueError("SocketListener requires a SOCK_STREAM socket")
|
||||
try:
|
||||
listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN)
|
||||
except OSError:
|
||||
# SO_ACCEPTCONN fails on macOS; we just have to trust the user.
|
||||
pass
|
||||
else:
|
||||
if not listening:
|
||||
raise ValueError("SocketListener requires a listening socket")
|
||||
|
||||
self.socket = socket
|
||||
|
||||
async def accept(self) -> SocketStream:
|
||||
"""Accept an incoming connection.
|
||||
|
||||
Returns:
|
||||
:class:`SocketStream`
|
||||
|
||||
Raises:
|
||||
OSError: if the underlying call to ``accept`` raises an unexpected
|
||||
error.
|
||||
ClosedResourceError: if you already closed the socket.
|
||||
|
||||
This method handles routine errors like ``ECONNABORTED``, but passes
|
||||
other errors on to its caller. In particular, it does *not* make any
|
||||
special effort to handle resource exhaustion errors like ``EMFILE``,
|
||||
``ENFILE``, ``ENOBUFS``, ``ENOMEM``.
|
||||
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
sock, _ = await self.socket.accept()
|
||||
except OSError as exc:
|
||||
if exc.errno in _closed_stream_errnos:
|
||||
raise trio.ClosedResourceError from None
|
||||
if exc.errno not in _ignorable_accept_errnos:
|
||||
raise
|
||||
else:
|
||||
return SocketStream(sock)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close this listener and its underlying socket."""
|
||||
self.socket.close()
|
||||
await trio.lowlevel.checkpoint()
|
||||
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, NoReturn, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
from ._highlevel_open_tcp_stream import DEFAULT_DELAY
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from ._highlevel_socket import SocketStream
|
||||
|
||||
|
||||
# It might have been nice to take a ssl_protocols= argument here to set up
|
||||
# NPN/ALPN, but to do this we have to mutate the context object, which is OK
|
||||
# if it's one we created, but not OK if it's one that was passed in... and
|
||||
# the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use
|
||||
# a specially configured SSLContext anyway! I also thought maybe we could copy
|
||||
# the given SSLContext and then mutate the copy, but it's no good as SSLContext
|
||||
# objects can't be copied: https://bugs.python.org/issue33023.
|
||||
# So... let's punt on that for now. Hopefully we'll be getting a new Python
|
||||
# TLS API soon and can revisit this then.
|
||||
async def open_ssl_over_tcp_stream(
|
||||
host: str | bytes,
|
||||
port: int,
|
||||
*,
|
||||
https_compatible: bool = False,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
|
||||
) -> trio.SSLStream[SocketStream]:
|
||||
"""Make a TLS-encrypted Connection to the given host and port over TCP.
|
||||
|
||||
This is a convenience wrapper that calls :func:`open_tcp_stream` and
|
||||
wraps the result in an :class:`~trio.SSLStream`.
|
||||
|
||||
This function does not perform the TLS handshake; you can do it
|
||||
manually by calling :meth:`~trio.SSLStream.do_handshake`, or else
|
||||
it will be performed automatically the first time you send or receive
|
||||
data.
|
||||
|
||||
Args:
|
||||
host (bytes or str): The host to connect to. We require the server
|
||||
to have a TLS certificate valid for this hostname.
|
||||
port (int): The port to connect to.
|
||||
https_compatible (bool): Set this to True if you're connecting to a web
|
||||
server. See :class:`~trio.SSLStream` for details. Default:
|
||||
False.
|
||||
ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to
|
||||
use. If None (the default), :func:`ssl.create_default_context`
|
||||
will be called to create a context.
|
||||
happy_eyeballs_delay (float): See :func:`open_tcp_stream`.
|
||||
|
||||
Returns:
|
||||
trio.SSLStream: the encrypted connection to the server.
|
||||
|
||||
"""
|
||||
tcp_stream = await trio.open_tcp_stream(
|
||||
host,
|
||||
port,
|
||||
happy_eyeballs_delay=happy_eyeballs_delay,
|
||||
)
|
||||
if ssl_context is None:
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
|
||||
ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
|
||||
|
||||
return trio.SSLStream(
|
||||
tcp_stream,
|
||||
ssl_context,
|
||||
server_hostname=host,
|
||||
https_compatible=https_compatible,
|
||||
)
|
||||
|
||||
|
||||
async def open_ssl_over_tcp_listeners(
|
||||
port: int,
|
||||
ssl_context: ssl.SSLContext,
|
||||
*,
|
||||
host: str | bytes | None = None,
|
||||
https_compatible: bool = False,
|
||||
backlog: int | None = None,
|
||||
) -> list[trio.SSLListener[SocketStream]]:
|
||||
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
|
||||
|
||||
Args:
|
||||
port (int): The port to listen on. See :func:`open_tcp_listeners`.
|
||||
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
|
||||
connections.
|
||||
host (str, bytes, or None): The address to bind to; use ``None`` to bind
|
||||
to the wildcard address. See :func:`open_tcp_listeners`.
|
||||
https_compatible (bool): See :class:`~trio.SSLStream` for details.
|
||||
backlog (int or None): See :func:`open_tcp_listeners` for details.
|
||||
|
||||
"""
|
||||
tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
|
||||
ssl_listeners = [
|
||||
trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible)
|
||||
for tcp_listener in tcp_listeners
|
||||
]
|
||||
return ssl_listeners
|
||||
|
||||
|
||||
async def serve_ssl_over_tcp(
|
||||
handler: Callable[[trio.SSLStream[SocketStream]], Awaitable[object]],
|
||||
port: int,
|
||||
ssl_context: ssl.SSLContext,
|
||||
*,
|
||||
host: str | bytes | None = None,
|
||||
https_compatible: bool = False,
|
||||
backlog: int | None = None,
|
||||
handler_nursery: trio.Nursery | None = None,
|
||||
task_status: trio.TaskStatus[
|
||||
list[trio.SSLListener[SocketStream]]
|
||||
] = trio.TASK_STATUS_IGNORED,
|
||||
) -> NoReturn:
|
||||
"""Listen for incoming TCP connections, and for each one start a task
|
||||
running ``handler(stream)``.
|
||||
|
||||
This is a thin convenience wrapper around
|
||||
:func:`open_ssl_over_tcp_listeners` and :func:`serve_listeners` – see them
|
||||
for full details.
|
||||
|
||||
.. warning::
|
||||
|
||||
If ``handler`` raises an exception, then this function doesn't do
|
||||
anything special to catch it – so by default the exception will
|
||||
propagate out and crash your server. If you don't want this, then catch
|
||||
exceptions inside your ``handler``, or use a ``handler_nursery`` object
|
||||
that responds to exceptions in some other way.
|
||||
|
||||
When used with ``nursery.start`` you get back the newly opened listeners.
|
||||
See the documentation for :func:`serve_tcp` for an example where this is
|
||||
useful.
|
||||
|
||||
Args:
|
||||
handler: The handler to start for each incoming connection. Passed to
|
||||
:func:`serve_listeners`.
|
||||
|
||||
port (int): The port to listen on. Use 0 to let the kernel pick
|
||||
an open port. Ultimately passed to :func:`open_tcp_listeners`.
|
||||
|
||||
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
|
||||
connections. Passed to :func:`open_ssl_over_tcp_listeners`.
|
||||
|
||||
host (str, bytes, or None): The address to bind to; use ``None`` to bind
|
||||
to the wildcard address. Ultimately passed to
|
||||
:func:`open_tcp_listeners`.
|
||||
|
||||
https_compatible (bool): Set this to True if you want to use
|
||||
"HTTPS-style" TLS. See :class:`~trio.SSLStream` for details.
|
||||
|
||||
backlog (int or None): See :class:`~trio.SSLStream` for details.
|
||||
|
||||
handler_nursery: The nursery to start handlers in, or None to use an
|
||||
internal nursery. Passed to :func:`serve_listeners`.
|
||||
|
||||
task_status: This function can be used with ``nursery.start``.
|
||||
|
||||
Returns:
|
||||
This function only returns when cancelled.
|
||||
|
||||
"""
|
||||
listeners = await trio.open_ssl_over_tcp_listeners(
|
||||
port,
|
||||
ssl_context,
|
||||
host=host,
|
||||
https_compatible=https_compatible,
|
||||
backlog=backlog,
|
||||
)
|
||||
await trio.serve_listeners(
|
||||
handler,
|
||||
listeners,
|
||||
handler_nursery=handler_nursery,
|
||||
task_status=task_status,
|
||||
)
|
||||
@@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from functools import partial, update_wrapper
|
||||
from inspect import cleandoc
|
||||
from typing import IO, TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, overload
|
||||
|
||||
from trio._file_io import AsyncIOWrapper, wrap_file
|
||||
from trio._util import final
|
||||
from trio.to_thread import run_sync
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper
|
||||
|
||||
from _typeshed import (
|
||||
OpenBinaryMode,
|
||||
OpenBinaryModeReading,
|
||||
OpenBinaryModeUpdating,
|
||||
OpenBinaryModeWriting,
|
||||
OpenTextMode,
|
||||
)
|
||||
from typing_extensions import Concatenate, Literal, ParamSpec, Self
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
PathT = TypeVar("PathT", bound="Path")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _wraps_async(
|
||||
wrapped: Callable[..., Any],
|
||||
) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]:
|
||||
def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return await run_sync(partial(fn, *args, **kwargs))
|
||||
|
||||
update_wrapper(wrapper, wrapped)
|
||||
if wrapped.__doc__:
|
||||
wrapper.__doc__ = (
|
||||
f"Like :meth:`~{wrapped.__module__}.{wrapped.__qualname__}`, but async.\n"
|
||||
f"\n"
|
||||
f"{cleandoc(wrapped.__doc__)}\n"
|
||||
)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _wrap_method(
|
||||
fn: Callable[Concatenate[pathlib.Path, P], T],
|
||||
) -> Callable[Concatenate[Path, P], Awaitable[T]]:
|
||||
@_wraps_async(fn)
|
||||
def wrapper(self: Path, /, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return fn(self._wrapped_cls(self), *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_method_path(
|
||||
fn: Callable[Concatenate[pathlib.Path, P], pathlib.Path],
|
||||
) -> Callable[Concatenate[PathT, P], Awaitable[PathT]]:
|
||||
@_wraps_async(fn)
|
||||
def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> PathT:
|
||||
return self.__class__(fn(self._wrapped_cls(self), *args, **kwargs))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_method_path_iterable(
|
||||
fn: Callable[Concatenate[pathlib.Path, P], Iterable[pathlib.Path]],
|
||||
) -> Callable[Concatenate[PathT, P], Awaitable[Iterable[PathT]]]:
|
||||
@_wraps_async(fn)
|
||||
def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> Iterable[PathT]:
|
||||
return map(self.__class__, [*fn(self._wrapped_cls(self), *args, **kwargs)])
|
||||
|
||||
if wrapper.__doc__:
|
||||
wrapper.__doc__ += (
|
||||
f"\n"
|
||||
f"This is an async method that returns a synchronous iterator, so you\n"
|
||||
f"use it like:\n"
|
||||
f"\n"
|
||||
f".. code:: python\n"
|
||||
f"\n"
|
||||
f" for subpath in await mypath.{fn.__name__}():\n"
|
||||
f" ...\n"
|
||||
f"\n"
|
||||
f".. note::\n"
|
||||
f"\n"
|
||||
f" The iterator is loaded into memory immediately during the initial\n"
|
||||
f" call (see `issue #501\n"
|
||||
f" <https://github.com/python-trio/trio/issues/501>`__ for discussion).\n"
|
||||
)
|
||||
return wrapper
|
||||
|
||||
|
||||
class Path(pathlib.PurePath):
|
||||
"""An async :class:`pathlib.Path` that executes blocking methods in :meth:`trio.to_thread.run_sync`.
|
||||
|
||||
Instantiating :class:`Path` returns a concrete platform-specific subclass, one of :class:`PosixPath` or
|
||||
:class:`WindowsPath`.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_wrapped_cls: ClassVar[type[pathlib.Path]]
|
||||
|
||||
def __new__(cls, *args: str | os.PathLike[str]) -> Self:
|
||||
if cls is Path:
|
||||
cls = WindowsPath if os.name == "nt" else PosixPath # type: ignore[assignment]
|
||||
return super().__new__(cls, *args)
|
||||
|
||||
@classmethod
|
||||
@_wraps_async(pathlib.Path.cwd)
|
||||
def cwd(cls) -> Self:
|
||||
return cls(pathlib.Path.cwd())
|
||||
|
||||
@classmethod
|
||||
@_wraps_async(pathlib.Path.home)
|
||||
def home(cls) -> Self:
|
||||
return cls(pathlib.Path.home())
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenTextMode = "r",
|
||||
buffering: int = -1,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
) -> AsyncIOWrapper[TextIOWrapper]: ...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryMode,
|
||||
buffering: Literal[0],
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
) -> AsyncIOWrapper[FileIO]: ...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryModeUpdating,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
) -> AsyncIOWrapper[BufferedRandom]: ...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryModeWriting,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
) -> AsyncIOWrapper[BufferedWriter]: ...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryModeReading,
|
||||
buffering: Literal[-1, 1] = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
) -> AsyncIOWrapper[BufferedReader]: ...
|
||||
|
||||
@overload
|
||||
async def open(
|
||||
self,
|
||||
mode: OpenBinaryMode,
|
||||
buffering: int = -1,
|
||||
encoding: None = None,
|
||||
errors: None = None,
|
||||
newline: None = None,
|
||||
) -> AsyncIOWrapper[BinaryIO]: ...
|
||||
|
||||
@overload
|
||||
async def open( # type: ignore[misc] # Any usage matches builtins.open().
|
||||
self,
|
||||
mode: str,
|
||||
buffering: int = -1,
|
||||
encoding: str | None = None,
|
||||
errors: str | None = None,
|
||||
newline: str | None = None,
|
||||
) -> AsyncIOWrapper[IO[Any]]: ...
|
||||
|
||||
@_wraps_async(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch.
|
||||
def open(self, *args: Any, **kwargs: Any) -> AsyncIOWrapper[IO[Any]]:
|
||||
return wrap_file(self._wrapped_cls(self).open(*args, **kwargs))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"trio.Path({str(self)!r})"
|
||||
|
||||
stat = _wrap_method(pathlib.Path.stat)
|
||||
chmod = _wrap_method(pathlib.Path.chmod)
|
||||
exists = _wrap_method(pathlib.Path.exists)
|
||||
glob = _wrap_method_path_iterable(pathlib.Path.glob)
|
||||
rglob = _wrap_method_path_iterable(pathlib.Path.rglob)
|
||||
is_dir = _wrap_method(pathlib.Path.is_dir)
|
||||
is_file = _wrap_method(pathlib.Path.is_file)
|
||||
is_symlink = _wrap_method(pathlib.Path.is_symlink)
|
||||
is_socket = _wrap_method(pathlib.Path.is_socket)
|
||||
is_fifo = _wrap_method(pathlib.Path.is_fifo)
|
||||
is_block_device = _wrap_method(pathlib.Path.is_block_device)
|
||||
is_char_device = _wrap_method(pathlib.Path.is_char_device)
|
||||
if sys.version_info >= (3, 12):
|
||||
is_junction = _wrap_method(pathlib.Path.is_junction)
|
||||
iterdir = _wrap_method_path_iterable(pathlib.Path.iterdir)
|
||||
lchmod = _wrap_method(pathlib.Path.lchmod)
|
||||
lstat = _wrap_method(pathlib.Path.lstat)
|
||||
mkdir = _wrap_method(pathlib.Path.mkdir)
|
||||
if sys.platform != "win32":
|
||||
owner = _wrap_method(pathlib.Path.owner)
|
||||
group = _wrap_method(pathlib.Path.group)
|
||||
if sys.platform != "win32" or sys.version_info >= (3, 12):
|
||||
is_mount = _wrap_method(pathlib.Path.is_mount)
|
||||
if sys.version_info >= (3, 9):
|
||||
readlink = _wrap_method_path(pathlib.Path.readlink)
|
||||
rename = _wrap_method_path(pathlib.Path.rename)
|
||||
replace = _wrap_method_path(pathlib.Path.replace)
|
||||
resolve = _wrap_method_path(pathlib.Path.resolve)
|
||||
rmdir = _wrap_method(pathlib.Path.rmdir)
|
||||
symlink_to = _wrap_method(pathlib.Path.symlink_to)
|
||||
if sys.version_info >= (3, 10):
|
||||
hardlink_to = _wrap_method(pathlib.Path.hardlink_to)
|
||||
touch = _wrap_method(pathlib.Path.touch)
|
||||
unlink = _wrap_method(pathlib.Path.unlink)
|
||||
absolute = _wrap_method_path(pathlib.Path.absolute)
|
||||
expanduser = _wrap_method_path(pathlib.Path.expanduser)
|
||||
read_bytes = _wrap_method(pathlib.Path.read_bytes)
|
||||
read_text = _wrap_method(pathlib.Path.read_text)
|
||||
samefile = _wrap_method(pathlib.Path.samefile)
|
||||
write_bytes = _wrap_method(pathlib.Path.write_bytes)
|
||||
write_text = _wrap_method(pathlib.Path.write_text)
|
||||
if sys.version_info < (3, 12):
|
||||
link_to = _wrap_method(pathlib.Path.link_to)
|
||||
if sys.version_info >= (3, 13):
|
||||
full_match = _wrap_method(pathlib.Path.full_match)
|
||||
|
||||
|
||||
@final
|
||||
class PosixPath(Path, pathlib.PurePosixPath):
|
||||
"""An async :class:`pathlib.PosixPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.PosixPath
|
||||
|
||||
|
||||
@final
|
||||
class WindowsPath(Path, pathlib.PureWindowsPath):
|
||||
"""An async :class:`pathlib.WindowsPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.WindowsPath
|
||||
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import warnings
|
||||
from code import InteractiveConsole
|
||||
|
||||
import outcome
|
||||
|
||||
import trio
|
||||
import trio.lowlevel
|
||||
from trio._util import final
|
||||
|
||||
|
||||
@final
|
||||
class TrioInteractiveConsole(InteractiveConsole):
|
||||
# code.InteractiveInterpreter defines locals as Mapping[str, Any]
|
||||
# but when we pass this to FunctionType it expects a dict. So
|
||||
# we make the type more specific on our subclass
|
||||
locals: dict[str, object]
|
||||
|
||||
def __init__(self, repl_locals: dict[str, object] | None = None):
|
||||
super().__init__(locals=repl_locals)
|
||||
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
|
||||
|
||||
def runcode(self, code: types.CodeType) -> None:
|
||||
func = types.FunctionType(code, self.locals)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = trio.from_thread.run(outcome.acapture, func)
|
||||
else:
|
||||
result = trio.from_thread.run_sync(outcome.capture, func)
|
||||
if isinstance(result, outcome.Error):
|
||||
# If it is SystemExit, quit the repl. Otherwise, print the traceback.
|
||||
# If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
|
||||
# the user trying to quit the repl, but rather an error in the code. So, we
|
||||
# don't try to inspect groups for SystemExit. Instead, we just print and
|
||||
# return to the REPL.
|
||||
if isinstance(result.error, SystemExit):
|
||||
raise result.error
|
||||
else:
|
||||
# Inline our own version of self.showtraceback that can use
|
||||
# outcome.Error.error directly to print clean tracebacks.
|
||||
# This also means overriding self.showtraceback does nothing.
|
||||
sys.last_type, sys.last_value = type(result.error), result.error
|
||||
sys.last_traceback = result.error.__traceback__
|
||||
# see https://docs.python.org/3/library/sys.html#sys.last_exc
|
||||
if sys.version_info >= (3, 12):
|
||||
sys.last_exc = result.error
|
||||
|
||||
# We always use sys.excepthook, unlike other implementations.
|
||||
# This means that overriding self.write also does nothing to tbs.
|
||||
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
|
||||
|
||||
|
||||
async def run_repl(console: TrioInteractiveConsole) -> None:
|
||||
banner = (
|
||||
f"trio REPL {sys.version} on {sys.platform}\n"
|
||||
f'Use "await" directly instead of "trio.run()".\n'
|
||||
f'Type "help", "copyright", "credits" or "license" '
|
||||
f"for more information.\n"
|
||||
f'{getattr(sys, "ps1", ">>> ")}import trio'
|
||||
)
|
||||
try:
|
||||
await trio.to_thread.run_sync(console.interact, banner)
|
||||
finally:
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"^coroutine .* was never awaited$",
|
||||
category=RuntimeWarning,
|
||||
)
|
||||
|
||||
|
||||
def main(original_locals: dict[str, object]) -> None:
|
||||
with contextlib.suppress(ImportError):
|
||||
import readline # noqa: F401
|
||||
|
||||
repl_locals: dict[str, object] = {"trio": trio}
|
||||
for key in {
|
||||
"__name__",
|
||||
"__package__",
|
||||
"__loader__",
|
||||
"__spec__",
|
||||
"__builtins__",
|
||||
"__file__",
|
||||
}:
|
||||
repl_locals[key] = original_locals[key]
|
||||
|
||||
console = TrioInteractiveConsole(repl_locals)
|
||||
trio.run(run_repl, console)
|
||||
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
from ._util import ConflictDetector, is_main_thread, signal_raise
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Callable, Generator, Iterable
|
||||
from types import FrameType
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
# Discussion of signal handling strategies:
|
||||
#
|
||||
# - On Windows signals barely exist. There are no options; signal handlers are
|
||||
# the only available API.
|
||||
#
|
||||
# - On Linux signalfd is arguably the natural way. Semantics: signalfd acts as
|
||||
# an *alternative* signal delivery mechanism. The way you use it is to mask
|
||||
# out the relevant signals process-wide (so that they don't get delivered
|
||||
# the normal way), and then when you read from signalfd that actually counts
|
||||
# as delivering it (despite the mask). The problem with this is that we
|
||||
# don't have any reliable way to mask out signals process-wide -- the only
|
||||
# way to do that in Python is to call pthread_sigmask from the main thread
|
||||
# *before starting any other threads*, and as a library we can't really
|
||||
# impose that, and the failure mode is annoying (signals get delivered via
|
||||
# signal handlers whether we want them to or not).
|
||||
#
|
||||
# - on macOS/*BSD, kqueue is the natural way. Semantics: kqueue acts as an
|
||||
# *extra* signal delivery mechanism. Signals are delivered the normal
|
||||
# way, *and* are delivered to kqueue. So you want to set them to SIG_IGN so
|
||||
# that they don't end up pending forever (I guess?). I can't find any actual
|
||||
# docs on how masking and EVFILT_SIGNAL interact. I did see someone note
|
||||
# that if a signal is pending when the kqueue filter is added then you
|
||||
# *don't* get notified of that, which makes sense. But still, we have to
|
||||
# manipulate signal state (e.g. setting SIG_IGN) which as far as Python is
|
||||
# concerned means we have to do this from the main thread.
|
||||
#
|
||||
# So in summary, there don't seem to be any compelling advantages to using the
|
||||
# platform-native signal notification systems; they're kinda nice, but it's
|
||||
# simpler to implement the naive signal-handler-based system once and be
|
||||
# done. (The big advantage would be if there were a reliable way to monitor
|
||||
# for SIGCHLD from outside the main thread and without interfering with other
|
||||
# libraries that also want to monitor for SIGCHLD. But there isn't. I guess
|
||||
# kqueue might give us that, but in kqueue we don't need it, because kqueue
|
||||
# can directly monitor for child process state changes.)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _signal_handler(
|
||||
signals: Iterable[int],
|
||||
handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None,
|
||||
) -> Generator[None, None, None]:
|
||||
original_handlers = {}
|
||||
try:
|
||||
for signum in set(signals):
|
||||
original_handlers[signum] = signal.signal(signum, handler)
|
||||
yield
|
||||
finally:
|
||||
for signum, original_handler in original_handlers.items():
|
||||
signal.signal(signum, original_handler)
|
||||
|
||||
|
||||
class SignalReceiver:
|
||||
def __init__(self) -> None:
|
||||
# {signal num: None}
|
||||
self._pending: OrderedDict[int, None] = OrderedDict()
|
||||
self._lot = trio.lowlevel.ParkingLot()
|
||||
self._conflict_detector = ConflictDetector(
|
||||
"only one task can iterate on a signal receiver at a time",
|
||||
)
|
||||
self._closed = False
|
||||
|
||||
def _add(self, signum: int) -> None:
|
||||
if self._closed:
|
||||
signal_raise(signum)
|
||||
else:
|
||||
self._pending[signum] = None
|
||||
self._lot.unpark()
|
||||
|
||||
def _redeliver_remaining(self) -> None:
|
||||
# First make sure that any signals still in the delivery pipeline will
|
||||
# get redelivered
|
||||
self._closed = True
|
||||
|
||||
# And then redeliver any that are sitting in pending. This is done
|
||||
# using a weird recursive construct to make sure we process everything
|
||||
# even if some of the handlers raise exceptions.
|
||||
def deliver_next() -> None:
|
||||
if self._pending:
|
||||
signum, _ = self._pending.popitem(last=False)
|
||||
try:
|
||||
signal_raise(signum)
|
||||
finally:
|
||||
deliver_next()
|
||||
|
||||
deliver_next()
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> int:
|
||||
if self._closed:
|
||||
raise RuntimeError("open_signal_receiver block already exited")
|
||||
# In principle it would be possible to support multiple concurrent
|
||||
# calls to __anext__, but doing it without race conditions is quite
|
||||
# tricky, and there doesn't seem to be any point in trying.
|
||||
with self._conflict_detector:
|
||||
if not self._pending:
|
||||
await self._lot.park()
|
||||
else:
|
||||
await trio.lowlevel.checkpoint()
|
||||
signum, _ = self._pending.popitem(last=False)
|
||||
return signum
|
||||
|
||||
|
||||
def get_pending_signal_count(rec: AsyncIterator[int]) -> int:
|
||||
"""Helper for tests, not public or otherwise used."""
|
||||
# open_signal_receiver() always produces SignalReceiver, this should not fail.
|
||||
assert isinstance(rec, SignalReceiver)
|
||||
return len(rec._pending)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_signal_receiver(
|
||||
*signals: signal.Signals | int,
|
||||
) -> Generator[AsyncIterator[int], None, None]:
|
||||
"""A context manager for catching signals.
|
||||
|
||||
Entering this context manager starts listening for the given signals and
|
||||
returns an async iterator; exiting the context manager stops listening.
|
||||
|
||||
The async iterator blocks until a signal arrives, and then yields it.
|
||||
|
||||
Note that if you leave the ``with`` block while the iterator has
|
||||
unextracted signals still pending inside it, then they will be
|
||||
re-delivered using Python's regular signal handling logic. This avoids a
|
||||
race condition when signals arrives just before we exit the ``with``
|
||||
block.
|
||||
|
||||
Args:
|
||||
signals: the signals to listen for.
|
||||
|
||||
Raises:
|
||||
TypeError: if no signals were provided.
|
||||
|
||||
RuntimeError: if you try to use this anywhere except Python's main
|
||||
thread. (This is a Python limitation.)
|
||||
|
||||
Example:
|
||||
|
||||
A common convention for Unix daemons is that they should reload their
|
||||
configuration when they receive a ``SIGHUP``. Here's a sketch of what
|
||||
that might look like using :func:`open_signal_receiver`::
|
||||
|
||||
with trio.open_signal_receiver(signal.SIGHUP) as signal_aiter:
|
||||
async for signum in signal_aiter:
|
||||
assert signum == signal.SIGHUP
|
||||
reload_configuration()
|
||||
|
||||
"""
|
||||
if not signals:
|
||||
raise TypeError("No signals were provided")
|
||||
|
||||
if not is_main_thread():
|
||||
raise RuntimeError(
|
||||
"Sorry, open_signal_receiver is only possible when running in "
|
||||
"Python interpreter's main thread",
|
||||
)
|
||||
token = trio.lowlevel.current_trio_token()
|
||||
queue = SignalReceiver()
|
||||
|
||||
def handler(signum: int, frame: FrameType | None) -> None:
|
||||
token.run_sync_soon(queue._add, signum, idempotent=True)
|
||||
|
||||
try:
|
||||
with _signal_handler(signals, handler):
|
||||
yield queue
|
||||
finally:
|
||||
queue._redeliver_remaining()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,951 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import operator as _operator
|
||||
import ssl as _stdlib_ssl
|
||||
from enum import Enum as _Enum
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final as TFinal, Generic, TypeVar
|
||||
|
||||
import trio
|
||||
|
||||
from . import _sync
|
||||
from ._highlevel_generic import aclose_forcefully
|
||||
from ._util import ConflictDetector, final
|
||||
from .abc import Listener, Stream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
# General theory of operation:
|
||||
#
|
||||
# We implement an API that closely mirrors the stdlib ssl module's blocking
|
||||
# API, and we do it using the stdlib ssl module's non-blocking in-memory API.
|
||||
# The stdlib non-blocking in-memory API is barely documented, and acts as a
|
||||
# thin wrapper around openssl, whose documentation also leaves something to be
|
||||
# desired. So here's the main things you need to know to understand the code
|
||||
# in this file:
|
||||
#
|
||||
# We use an ssl.SSLObject, which exposes the four main I/O operations:
|
||||
#
|
||||
# - do_handshake: performs the initial handshake. Must be called once at the
|
||||
# beginning of each connection; is a no-op once it's completed once.
|
||||
#
|
||||
# - write: takes some unencrypted data and attempts to send it to the remote
|
||||
# peer.
|
||||
|
||||
# - read: attempts to decrypt and return some data from the remote peer.
|
||||
#
|
||||
# - unwrap: this is weirdly named; maybe it helps to realize that the thing it
|
||||
# wraps is called SSL_shutdown. It sends a cryptographically signed message
|
||||
# saying "I'm closing this connection now", and then waits to receive the
|
||||
# same from the remote peer (unless we already received one, in which case
|
||||
# it returns immediately).
|
||||
#
|
||||
# All of these operations read and write from some in-memory buffers called
|
||||
# "BIOs", which are an opaque OpenSSL-specific object that's basically
|
||||
# semantically equivalent to a Python bytearray. When they want to send some
|
||||
# bytes to the remote peer, they append them to the outgoing BIO, and when
|
||||
# they want to receive some bytes from the remote peer, they try to pull them
|
||||
# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO
|
||||
# can always be extended to hold more data. "Receiving" acts sort of like a
|
||||
# non-blocking socket: it might manage to get some data immediately, or it
|
||||
# might fail and need to be tried again later. We can also directly add or
|
||||
# remove data from the BIOs whenever we want.
|
||||
#
|
||||
# Now the problem is that while these I/O operations are opaque atomic
|
||||
# operations from the point of view of us calling them, under the hood they
|
||||
# might require some arbitrary sequence of sends and receives from the remote
|
||||
# peer. This is particularly true for do_handshake, which generally requires a
|
||||
# few round trips, but it's also true for write and read, due to an evil thing
|
||||
# called "renegotiation".
|
||||
#
|
||||
# Renegotiation is the process by which one of the peers might arbitrarily
|
||||
# decide to redo the handshake at any time. Did I mention it's evil? It's
|
||||
# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use
|
||||
# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the
|
||||
# protocol entirely. It's impossible to trigger a renegotiation if using
|
||||
# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1].
|
||||
# Nonetheless, it does get used in real life, mostly in two cases:
|
||||
#
|
||||
# 1) Normally in TLS 1.2 and below, when the client side of a connection wants
|
||||
# to present a certificate to prove their identity, that certificate gets sent
|
||||
# in plaintext. This is bad, because it means that anyone eavesdropping can
|
||||
# see who's connecting – it's like sending your username in plain text. Not as
|
||||
# bad as sending your password in plain text, but still, pretty bad. However,
|
||||
# renegotiations *are* encrypted. So as a workaround, it's not uncommon for
|
||||
# systems that want to use client certificates to first do an anonymous
|
||||
# handshake, and then to turn around and do a second handshake (=
|
||||
# renegotiation) and this time ask for a client cert. Or sometimes this is
|
||||
# done on a case-by-case basis, e.g. a web server might accept a connection,
|
||||
# read the request, and then once it sees the page you're asking for it might
|
||||
# stop and ask you for a certificate.
|
||||
#
|
||||
# 2) In principle the same TLS connection can be used for an arbitrarily long
|
||||
# time, and might transmit arbitrarily large amounts of data. But this creates
|
||||
# a cryptographic problem: an attacker who has access to arbitrarily large
|
||||
# amounts of data that's all encrypted using the same key may eventually be
|
||||
# able to use this to figure out the key. Is this a real practical problem? I
|
||||
# have no idea, I'm not a cryptographer. In any case, some people worry that
|
||||
# it's a problem, so their TLS libraries are designed to automatically trigger
|
||||
# a renegotiation every once in a while on some sort of timer.
|
||||
#
|
||||
# The end result is that you might be going along, minding your own business,
|
||||
# and then *bam*! a wild renegotiation appears! And you just have to cope.
|
||||
#
|
||||
# The reason that coping with renegotiations is difficult is that some
|
||||
# unassuming "read" or "write" call might find itself unable to progress until
|
||||
# it does a handshake, which remember is a process with multiple round
|
||||
# trips. So read might have to send data, and write might have to receive
|
||||
# data, and this might happen multiple times. And some of those attempts might
|
||||
# fail because there isn't any data yet, and need to be retried. Managing all
|
||||
# this is pretty complicated.
|
||||
#
|
||||
# Here's how openssl (and thus the stdlib ssl module) handle this. All of the
|
||||
# I/O operations above follow the same rules. When you call one of them:
|
||||
#
|
||||
# - it might write some data to the outgoing BIO
|
||||
# - it might read some data from the incoming BIO
|
||||
# - it might raise SSLWantReadError if it can't complete without reading more
|
||||
# data from the incoming BIO. This is important: the "read" in ReadError
|
||||
# refers to reading from the *underlying* stream.
|
||||
# - (and in principle it might raise SSLWantWriteError too, but that never
|
||||
# happens when using memory BIOs, so never mind)
|
||||
#
|
||||
# If it doesn't raise an error, then the operation completed successfully
|
||||
# (though we still need to take any outgoing data out of the memory buffer and
|
||||
# put it onto the wire). If it *does* raise an error, then we need to retry
|
||||
# *exactly that method call* later – in particular, if a 'write' failed, we
|
||||
# need to try again later *with the same data*, because openssl might have
|
||||
# already committed some of the initial parts of our data to its output even
|
||||
# though it didn't tell us that, and has remembered that the next time we call
|
||||
# write it needs to skip the first 1024 bytes or whatever it is. (Well,
|
||||
# technically, we're actually allowed to call 'write' again with a data buffer
|
||||
# which is the same as our old one PLUS some extra stuff added onto the end,
|
||||
# but in Trio that never comes up so never mind.)
|
||||
#
|
||||
# There are some people online who claim that once you've gotten a Want*Error
|
||||
# then the *very next call* you make to openssl *must* be the same as the
|
||||
# previous one. I'm pretty sure those people are wrong. In particular, it's
|
||||
# okay to call write, get a WantReadError, and then call read a few times;
|
||||
# it's just that *the next time you call write*, it has to be with the same
|
||||
# data.
|
||||
#
|
||||
# One final wrinkle: we want our SSLStream to support full-duplex operation,
|
||||
# i.e. it should be possible for one task to be calling send_all while another
|
||||
# task is calling receive_some. But renegotiation makes this a big hassle, because
|
||||
# even if SSLStream's restricts themselves to one task calling send_all and one
|
||||
# task calling receive_some, those two tasks might end up both wanting to call
|
||||
# send_all, or both to call receive_some at the same time *on the underlying
|
||||
# stream*. So we have to do some careful locking to hide this problem from our
|
||||
# users.
|
||||
#
|
||||
# (Renegotiation is evil.)
|
||||
#
|
||||
# So our basic strategy is to define a single helper method called "_retry",
|
||||
# which has generic logic for dealing with SSLWantReadError, pushing data from
|
||||
# the outgoing BIO to the wire, reading data from the wire to the incoming
|
||||
# BIO, retrying an I/O call until it works, and synchronizing with other tasks
|
||||
# that might be calling _retry concurrently. Basically it takes an SSLObject
|
||||
# non-blocking in-memory method and converts it into a Trio async blocking
|
||||
# method. _retry is only about 30 lines of code, but all these cases
|
||||
# multiplied by concurrent calls make it extremely tricky, so there are lots
|
||||
# of comments down below on the details, and a really extensive test suite in
|
||||
# test_ssl.py. And now you know *why* it's so tricky, and can probably
|
||||
# understand how it works.
|
||||
#
|
||||
# [1] https://rt.openssl.org/Ticket/Display.html?id=3712
|
||||
|
||||
# XX how closely should we match the stdlib API?
|
||||
# - maybe suppress_ragged_eofs=False is a better default?
|
||||
# - maybe check crypto folks for advice?
|
||||
# - this is also interesting: https://bugs.python.org/issue8108#msg102867
|
||||
|
||||
# Definitely keep an eye on Cory's TLS API ideas on security-sig etc.
|
||||
|
||||
# XX document behavior on cancellation/error (i.e.: all is lost abandon
|
||||
# stream)
|
||||
# docs will need to make very clear that this is different from all the other
|
||||
# cancellations in core Trio
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
################################################################
|
||||
# SSLStream
|
||||
################################################################
|
||||
|
||||
# Ideally, when the user calls SSLStream.receive_some() with no argument, then
|
||||
# we should do exactly one call to self.transport_stream.receive_some(),
|
||||
# decrypt everything we got, and return it. Unfortunately, the way openssl's
|
||||
# API works, we have to pick how much data we want to allow when we call
|
||||
# read(), and then it (potentially) triggers a call to
|
||||
# transport_stream.receive_some(). So at the time we pick the amount of data
|
||||
# to decrypt, we don't know how much data we've read. As a simple heuristic,
|
||||
# we record the max amount of data returned by previous calls to
|
||||
# transport_stream.receive_some(), and we use that for future calls to read().
|
||||
# But what do we use for the very first call? That's what this constant sets.
|
||||
#
|
||||
# Note that the value passed to read() is a limit on the amount of
|
||||
# *decrypted* data, but we can only see the size of the *encrypted* data
|
||||
# returned by transport_stream.receive_some(). TLS adds a small amount of
|
||||
# framing overhead, and TLS compression is rarely used these days because it's
|
||||
# insecure. So the size of the encrypted data should be a slight over-estimate
|
||||
# of the size of the decrypted data, which is exactly what we want.
|
||||
#
|
||||
# The specific value is not really based on anything; it might be worth tuning
|
||||
# at some point. But, if you have an TCP connection with the typical 1500 byte
|
||||
# MTU and an initial window of 10 (see RFC 6928), then the initial burst of
|
||||
# data will be limited to ~15000 bytes (or a bit less due to IP-level framing
|
||||
# overhead), so this is chosen to be larger than that.
|
||||
STARTING_RECEIVE_SIZE: TFinal = 16384
|
||||
|
||||
|
||||
def _is_eof(exc: BaseException | None) -> bool:
|
||||
# There appears to be a bug on Python 3.10, where SSLErrors
|
||||
# aren't properly translated into SSLEOFErrors.
|
||||
# This stringly-typed error check is borrowed from the AnyIO
|
||||
# project.
|
||||
return isinstance(exc, _stdlib_ssl.SSLEOFError) or (
|
||||
"UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ())
|
||||
)
|
||||
|
||||
|
||||
class NeedHandshakeError(Exception):
|
||||
"""Some :class:`SSLStream` methods can't return any meaningful data until
|
||||
after the handshake. If you call them before the handshake, they raise
|
||||
this error.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class _Once:
|
||||
def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None:
|
||||
self._afn = afn
|
||||
self._args = args
|
||||
self.started = False
|
||||
self._done = _sync.Event()
|
||||
|
||||
async def ensure(self, *, checkpoint: bool) -> None:
|
||||
if not self.started:
|
||||
self.started = True
|
||||
await self._afn(*self._args)
|
||||
self._done.set()
|
||||
elif not checkpoint and self._done.is_set():
|
||||
return
|
||||
else:
|
||||
await self._done.wait()
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return bool(self._done.is_set())
|
||||
|
||||
|
||||
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
|
||||
|
||||
# invariant
|
||||
T_Stream = TypeVar("T_Stream", bound=Stream)
|
||||
|
||||
|
||||
@final
|
||||
class SSLStream(Stream, Generic[T_Stream]):
|
||||
r"""Encrypted communication using SSL/TLS.
|
||||
|
||||
:class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and
|
||||
allows you to perform encrypted communication over it using the usual
|
||||
:class:`~trio.abc.Stream` interface. You pass regular data to
|
||||
:meth:`send_all`, then it encrypts it and sends the encrypted data on the
|
||||
underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted
|
||||
data out of the underlying :class:`~trio.abc.Stream` and decrypts it
|
||||
before returning it.
|
||||
|
||||
You should read the standard library's :mod:`ssl` documentation carefully
|
||||
before attempting to use this class, and probably other general
|
||||
documentation on SSL/TLS as well. SSL/TLS is subtle and quick to
|
||||
anger. Really. I'm not kidding.
|
||||
|
||||
Args:
|
||||
transport_stream (~trio.abc.Stream): The stream used to transport
|
||||
encrypted data. Required.
|
||||
|
||||
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for
|
||||
this connection. Required. Usually created by calling
|
||||
:func:`ssl.create_default_context`.
|
||||
|
||||
server_hostname (str, bytes, or None): The name of the server being
|
||||
connected to. Used for `SNI
|
||||
<https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for
|
||||
validating the server's certificate (if hostname checking is
|
||||
enabled). This is effectively mandatory for clients, and actually
|
||||
mandatory if ``ssl_context.check_hostname`` is ``True``.
|
||||
|
||||
server_side (bool): Whether this stream is acting as a client or
|
||||
server. Defaults to False, i.e. client mode.
|
||||
|
||||
https_compatible (bool): There are two versions of SSL/TLS commonly
|
||||
encountered in the wild: the standard version, and the version used
|
||||
for HTTPS (HTTP-over-SSL/TLS).
|
||||
|
||||
Standard-compliant SSL/TLS implementations always send a
|
||||
cryptographically signed ``close_notify`` message before closing the
|
||||
connection. This is important because if the underlying transport
|
||||
were simply closed, then there wouldn't be any way for the other
|
||||
side to know whether the connection was intentionally closed by the
|
||||
peer that they negotiated a cryptographic connection to, or by some
|
||||
`man-in-the-middle
|
||||
<https://en.wikipedia.org/wiki/Man-in-the-middle_attack>`__ attacker
|
||||
who can't manipulate the cryptographic stream, but can manipulate
|
||||
the transport layer (a so-called "truncation attack").
|
||||
|
||||
However, this part of the standard is widely ignored by real-world
|
||||
HTTPS implementations, which means that if you want to interoperate
|
||||
with them, then you NEED to ignore it too.
|
||||
|
||||
Fortunately this isn't as bad as it sounds, because the HTTP
|
||||
protocol already includes its own equivalent of ``close_notify``, so
|
||||
doing this again at the SSL/TLS level is redundant. But not all
|
||||
protocols do! Therefore, by default Trio implements the safer
|
||||
standard-compliant version (``https_compatible=False``). But if
|
||||
you're speaking HTTPS or some other protocol where
|
||||
``close_notify``\s are commonly skipped, then you should set
|
||||
``https_compatible=True``; with this setting, Trio will neither
|
||||
expect nor send ``close_notify`` messages.
|
||||
|
||||
If you have code that was written to use :class:`ssl.SSLSocket` and
|
||||
now you're porting it to Trio, then it may be useful to know that a
|
||||
difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is
|
||||
that :class:`~ssl.SSLSocket` implements the
|
||||
``https_compatible=True`` behavior by default.
|
||||
|
||||
Attributes:
|
||||
transport_stream (trio.abc.Stream): The underlying transport stream
|
||||
that was passed to ``__init__``. An example of when this would be
|
||||
useful is if you're using :class:`SSLStream` over a
|
||||
:class:`~trio.SocketStream` and want to call the
|
||||
:class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt`
|
||||
method.
|
||||
|
||||
Internally, this class is implemented using an instance of
|
||||
:class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and
|
||||
attributes are re-exported as methods and attributes on this class.
|
||||
However, there is one difference: :class:`~ssl.SSLObject` has several
|
||||
methods that return information about the encrypted connection, like
|
||||
:meth:`~ssl.SSLSocket.cipher` or
|
||||
:meth:`~ssl.SSLSocket.selected_alpn_protocol`. If you call them before the
|
||||
handshake, when they can't possibly return useful data, then
|
||||
:class:`ssl.SSLObject` returns None, but :class:`trio.SSLStream`
|
||||
raises :exc:`NeedHandshakeError`.
|
||||
|
||||
This also means that if you register a SNI callback using
|
||||
`~ssl.SSLContext.sni_callback`, then the first argument your callback
|
||||
receives will be a :class:`ssl.SSLObject`.
|
||||
|
||||
"""
|
||||
|
||||
# Note: any new arguments here should likely also be added to
|
||||
# SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers.
|
||||
def __init__(
|
||||
self,
|
||||
transport_stream: T_Stream,
|
||||
ssl_context: _stdlib_ssl.SSLContext,
|
||||
*,
|
||||
server_hostname: str | bytes | None = None,
|
||||
server_side: bool = False,
|
||||
https_compatible: bool = False,
|
||||
) -> None:
|
||||
self.transport_stream: T_Stream = transport_stream
|
||||
self._state = _State.OK
|
||||
self._https_compatible = https_compatible
|
||||
self._outgoing = _stdlib_ssl.MemoryBIO()
|
||||
self._delayed_outgoing: bytes | None = None
|
||||
self._incoming = _stdlib_ssl.MemoryBIO()
|
||||
self._ssl_object = ssl_context.wrap_bio(
|
||||
self._incoming,
|
||||
self._outgoing,
|
||||
server_side=server_side,
|
||||
server_hostname=server_hostname,
|
||||
)
|
||||
# Tracks whether we've already done the initial handshake
|
||||
self._handshook = _Once(self._do_handshake)
|
||||
|
||||
# These are used to synchronize access to self.transport_stream
|
||||
self._inner_send_lock = _sync.StrictFIFOLock()
|
||||
self._inner_recv_count = 0
|
||||
self._inner_recv_lock = _sync.Lock()
|
||||
|
||||
# These are used to make sure that our caller doesn't attempt to make
|
||||
# multiple concurrent calls to send_all/wait_send_all_might_not_block
|
||||
# or to receive_some.
|
||||
self._outer_send_conflict_detector = ConflictDetector(
|
||||
"another task is currently sending data on this SSLStream",
|
||||
)
|
||||
self._outer_recv_conflict_detector = ConflictDetector(
|
||||
"another task is currently receiving data on this SSLStream",
|
||||
)
|
||||
|
||||
self._estimated_receive_size = STARTING_RECEIVE_SIZE
|
||||
|
||||
_forwarded: ClassVar = {
|
||||
"context",
|
||||
"server_side",
|
||||
"server_hostname",
|
||||
"session",
|
||||
"session_reused",
|
||||
"getpeercert",
|
||||
"selected_npn_protocol",
|
||||
"cipher",
|
||||
"shared_ciphers",
|
||||
"compression",
|
||||
"pending",
|
||||
"get_channel_binding",
|
||||
"selected_alpn_protocol",
|
||||
"version",
|
||||
}
|
||||
|
||||
_after_handshake: ClassVar = {
|
||||
"session_reused",
|
||||
"getpeercert",
|
||||
"selected_npn_protocol",
|
||||
"cipher",
|
||||
"shared_ciphers",
|
||||
"compression",
|
||||
"get_channel_binding",
|
||||
"selected_alpn_protocol",
|
||||
"version",
|
||||
}
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._forwarded:
|
||||
if name in self._after_handshake and not self._handshook.done:
|
||||
raise NeedHandshakeError(f"call do_handshake() before calling {name!r}")
|
||||
|
||||
return getattr(self._ssl_object, name)
|
||||
else:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __setattr__(self, name: str, value: object) -> None:
|
||||
if name in self._forwarded:
|
||||
setattr(self._ssl_object, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __dir__(self) -> list[str]:
|
||||
return list(super().__dir__()) + list(self._forwarded)
|
||||
|
||||
def _check_status(self) -> None:
|
||||
if self._state is _State.OK:
|
||||
return
|
||||
elif self._state is _State.BROKEN:
|
||||
raise trio.BrokenResourceError
|
||||
elif self._state is _State.CLOSED:
|
||||
raise trio.ClosedResourceError
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
# This is probably the single trickiest function in Trio. It has lots of
|
||||
# comments, though, just make sure to think carefully if you ever have to
|
||||
# touch it. The big comment at the top of this file will help explain
|
||||
# too.
|
||||
async def _retry(
|
||||
self,
|
||||
fn: Callable[..., T],
|
||||
*args: object,
|
||||
ignore_want_read: bool = False,
|
||||
is_handshake: bool = False,
|
||||
) -> T | None:
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
yielded = False
|
||||
finished = False
|
||||
while not finished:
|
||||
# WARNING: this code needs to be very careful with when it
|
||||
# calls 'await'! There might be multiple tasks calling this
|
||||
# function at the same time trying to do different operations,
|
||||
# so we need to be careful to:
|
||||
#
|
||||
# 1) interact with the SSLObject, then
|
||||
# 2) await on exactly one thing that lets us make forward
|
||||
# progress, then
|
||||
# 3) loop or exit
|
||||
#
|
||||
# In particular we don't want to yield while interacting with
|
||||
# the SSLObject (because it's shared state, so someone else
|
||||
# might come in and mess with it while we're suspended), and
|
||||
# we don't want to yield *before* starting the operation that
|
||||
# will help us make progress, because then someone else might
|
||||
# come in and leapfrog us.
|
||||
|
||||
# Call the SSLObject method, and get its result.
|
||||
#
|
||||
# NB: despite what the docs say, SSLWantWriteError can't
|
||||
# happen – "Writes to memory BIOs will always succeed if
|
||||
# memory is available: that is their size can grow
|
||||
# indefinitely."
|
||||
# https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3)
|
||||
want_read = False
|
||||
ret = None
|
||||
try:
|
||||
ret = fn(*args)
|
||||
except _stdlib_ssl.SSLWantReadError:
|
||||
want_read = True
|
||||
except (_stdlib_ssl.SSLError, _stdlib_ssl.CertificateError) as exc:
|
||||
self._state = _State.BROKEN
|
||||
raise trio.BrokenResourceError from exc
|
||||
else:
|
||||
finished = True
|
||||
if ignore_want_read:
|
||||
want_read = False
|
||||
finished = True
|
||||
to_send = self._outgoing.read()
|
||||
|
||||
# Some versions of SSL_do_handshake have a bug in how they handle
|
||||
# the TLS 1.3 handshake on the server side: after the handshake
|
||||
# finishes, they automatically send session tickets, even though
|
||||
# the client may not be expecting data to arrive at this point and
|
||||
# sending it could cause a deadlock or lost data. This applies at
|
||||
# least to OpenSSL 1.1.1c and earlier, and the OpenSSL devs
|
||||
# currently have no plans to fix it:
|
||||
#
|
||||
# https://github.com/openssl/openssl/issues/7948
|
||||
# https://github.com/openssl/openssl/issues/7967
|
||||
#
|
||||
# The correct behavior is to wait to send session tickets on the
|
||||
# first call to SSL_write. (This is what BoringSSL does.) So, we
|
||||
# use a heuristic to detect when OpenSSL has tried to send session
|
||||
# tickets, and we manually delay sending them until the
|
||||
# appropriate moment. For more discussion see:
|
||||
#
|
||||
# https://github.com/python-trio/trio/issues/819#issuecomment-517529763
|
||||
if (
|
||||
is_handshake
|
||||
and not want_read
|
||||
and self._ssl_object.server_side
|
||||
and self._ssl_object.version() == "TLSv1.3"
|
||||
):
|
||||
assert self._delayed_outgoing is None
|
||||
self._delayed_outgoing = to_send
|
||||
to_send = b""
|
||||
|
||||
# Outputs from the above code block are:
|
||||
#
|
||||
# - to_send: bytestring; if non-empty then we need to send
|
||||
# this data to make forward progress
|
||||
#
|
||||
# - want_read: True if we need to receive_some some data to make
|
||||
# forward progress
|
||||
#
|
||||
# - finished: False means that we need to retry the call to
|
||||
# fn(*args) again, after having pushed things forward. True
|
||||
# means we still need to do whatever was said (in particular
|
||||
# send any data in to_send), but once we do then we're
|
||||
# done.
|
||||
#
|
||||
# - ret: the operation's return value. (Meaningless unless
|
||||
# finished is True.)
|
||||
#
|
||||
# Invariant: want_read and finished can't both be True at the
|
||||
# same time.
|
||||
#
|
||||
# Now we need to move things forward. There are two things we
|
||||
# might have to do, and any given operation might require
|
||||
# either, both, or neither to proceed:
|
||||
#
|
||||
# - send the data in to_send
|
||||
#
|
||||
# - receive_some some data and put it into the incoming BIO
|
||||
#
|
||||
# Our strategy is: if there's data to send, send it;
|
||||
# *otherwise* if there's data to receive_some, receive_some it.
|
||||
#
|
||||
# If both need to happen, then we only send. Why? Well, we
|
||||
# know that *right now* we have to both send and receive_some
|
||||
# before the operation can complete. But as soon as we yield,
|
||||
# that information becomes potentially stale – e.g. while
|
||||
# we're sending, some other task might go and receive_some the
|
||||
# data we need and put it into the incoming BIO. And if it
|
||||
# does, then we *definitely don't* want to do a receive_some –
|
||||
# there might not be any more data coming, and we'd deadlock!
|
||||
# We could do something tricky to keep track of whether a
|
||||
# receive_some happens while we're sending, but the case where
|
||||
# we have to do both is very unusual (only during a
|
||||
# renegotiation), so it's better to keep things simple. So we
|
||||
# do just one potentially-blocking operation, then check again
|
||||
# for fresh information.
|
||||
#
|
||||
# And we prioritize sending over receiving because, if there
|
||||
# are multiple tasks that want to receive_some, then it
|
||||
# doesn't matter what order they go in. But if there are
|
||||
# multiple tasks that want to send, then they each have
|
||||
# different data, and the data needs to get put onto the wire
|
||||
# in the same order that it was retrieved from the outgoing
|
||||
# BIO. So if we have data to send, that *needs* to be the
|
||||
# *very* *next* *thing* we do, to make sure no-one else sneaks
|
||||
# in before us. Or if we can't send immediately because
|
||||
# someone else is, then we at least need to get in line
|
||||
# immediately.
|
||||
if to_send:
|
||||
# NOTE: This relies on the lock being strict FIFO fair!
|
||||
async with self._inner_send_lock:
|
||||
yielded = True
|
||||
try:
|
||||
if self._delayed_outgoing is not None:
|
||||
to_send = self._delayed_outgoing + to_send
|
||||
self._delayed_outgoing = None
|
||||
await self.transport_stream.send_all(to_send)
|
||||
except:
|
||||
# Some unknown amount of our data got sent, and we
|
||||
# don't know how much. This stream is doomed.
|
||||
self._state = _State.BROKEN
|
||||
raise
|
||||
elif want_read:
|
||||
# It's possible that someone else is already blocked in
|
||||
# transport_stream.receive_some. If so then we want to
|
||||
# wait for them to finish, but we don't want to call
|
||||
# transport_stream.receive_some again ourselves; we just
|
||||
# want to loop around and check if their contribution
|
||||
# helped anything. So we make a note of how many times
|
||||
# some task has been through here before taking the lock,
|
||||
# and if it's changed by the time we get the lock, then we
|
||||
# skip calling transport_stream.receive_some and loop
|
||||
# around immediately.
|
||||
recv_count = self._inner_recv_count
|
||||
async with self._inner_recv_lock:
|
||||
yielded = True
|
||||
if recv_count == self._inner_recv_count:
|
||||
data = await self.transport_stream.receive_some()
|
||||
if not data:
|
||||
self._incoming.write_eof()
|
||||
else:
|
||||
self._estimated_receive_size = max(
|
||||
self._estimated_receive_size,
|
||||
len(data),
|
||||
)
|
||||
self._incoming.write(data)
|
||||
self._inner_recv_count += 1
|
||||
if not yielded:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
return ret
|
||||
|
||||
async def _do_handshake(self) -> None:
|
||||
try:
|
||||
await self._retry(self._ssl_object.do_handshake, is_handshake=True)
|
||||
except:
|
||||
self._state = _State.BROKEN
|
||||
raise
|
||||
|
||||
async def do_handshake(self) -> None:
|
||||
"""Ensure that the initial handshake has completed.
|
||||
|
||||
The SSL protocol requires an initial handshake to exchange
|
||||
certificates, select cryptographic keys, and so forth, before any
|
||||
actual data can be sent or received. You don't have to call this
|
||||
method; if you don't, then :class:`SSLStream` will automatically
|
||||
perform the handshake as needed, the first time you try to send or
|
||||
receive data. But if you want to trigger it manually – for example,
|
||||
because you want to look at the peer's certificate before you start
|
||||
talking to them – then you can call this method.
|
||||
|
||||
If the initial handshake is already in progress in another task, this
|
||||
waits for it to complete and then returns.
|
||||
|
||||
If the initial handshake has already completed, this returns
|
||||
immediately without doing anything (except executing a checkpoint).
|
||||
|
||||
.. warning:: If this method is cancelled, then it may leave the
|
||||
:class:`SSLStream` in an unusable state. If this happens then any
|
||||
future attempt to use the object will raise
|
||||
:exc:`trio.BrokenResourceError`.
|
||||
|
||||
"""
|
||||
self._check_status()
|
||||
await self._handshook.ensure(checkpoint=True)
|
||||
|
||||
# Most things work if we don't explicitly force do_handshake to be called
|
||||
# before calling receive_some or send_all, because openssl will
|
||||
# automatically perform the handshake on the first SSL_{read,write}
|
||||
# call. BUT, allowing openssl to do this will disable Python's hostname
|
||||
# checking!!! See:
|
||||
# https://bugs.python.org/issue30141
|
||||
# So we *definitely* have to make sure that do_handshake is called
|
||||
# before doing anything else.
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
|
||||
"""Read some data from the underlying transport, decrypt it, and
|
||||
return it.
|
||||
|
||||
See :meth:`trio.abc.ReceiveStream.receive_some` for details.
|
||||
|
||||
.. warning:: If this method is cancelled while the initial handshake
|
||||
or a renegotiation are in progress, then it may leave the
|
||||
:class:`SSLStream` in an unusable state. If this happens then any
|
||||
future attempt to use the object will raise
|
||||
:exc:`trio.BrokenResourceError`.
|
||||
|
||||
"""
|
||||
with self._outer_recv_conflict_detector:
|
||||
self._check_status()
|
||||
try:
|
||||
await self._handshook.ensure(checkpoint=False)
|
||||
except trio.BrokenResourceError as exc:
|
||||
# For some reason, EOF before handshake sometimes raises
|
||||
# SSLSyscallError instead of SSLEOFError (e.g. on my linux
|
||||
# laptop, but not on appveyor). Thanks openssl.
|
||||
if self._https_compatible and (
|
||||
isinstance(exc.__cause__, _stdlib_ssl.SSLSyscallError)
|
||||
or _is_eof(exc.__cause__)
|
||||
):
|
||||
await trio.lowlevel.checkpoint()
|
||||
return b""
|
||||
else:
|
||||
raise
|
||||
if max_bytes is None:
|
||||
# If we somehow have more data already in our pending buffer
|
||||
# than the estimate receive size, bump up our size a bit for
|
||||
# this read only.
|
||||
max_bytes = max(self._estimated_receive_size, self._incoming.pending)
|
||||
else:
|
||||
max_bytes = _operator.index(max_bytes)
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
try:
|
||||
received = await self._retry(self._ssl_object.read, max_bytes)
|
||||
assert received is not None
|
||||
return received
|
||||
except trio.BrokenResourceError as exc:
|
||||
# This isn't quite equivalent to just returning b"" in the
|
||||
# first place, because we still end up with self._state set to
|
||||
# BROKEN. But that's actually fine, because after getting an
|
||||
# EOF on TLS then the only thing you can do is close the
|
||||
# stream, and closing doesn't care about the state.
|
||||
|
||||
if self._https_compatible and _is_eof(exc.__cause__):
|
||||
await trio.lowlevel.checkpoint()
|
||||
return b""
|
||||
else:
|
||||
raise
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Encrypt some data and then send it on the underlying transport.
|
||||
|
||||
See :meth:`trio.abc.SendStream.send_all` for details.
|
||||
|
||||
.. warning:: If this method is cancelled, then it may leave the
|
||||
:class:`SSLStream` in an unusable state. If this happens then any
|
||||
attempt to use the object will raise
|
||||
:exc:`trio.BrokenResourceError`.
|
||||
|
||||
"""
|
||||
with self._outer_send_conflict_detector:
|
||||
self._check_status()
|
||||
await self._handshook.ensure(checkpoint=False)
|
||||
# SSLObject interprets write(b"") as an EOF for some reason, which
|
||||
# is not what we want.
|
||||
if not data:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return
|
||||
await self._retry(self._ssl_object.write, data)
|
||||
|
||||
async def unwrap(self) -> tuple[Stream, bytes | bytearray]:
|
||||
"""Cleanly close down the SSL/TLS encryption layer, allowing the
|
||||
underlying stream to be used for unencrypted communication.
|
||||
|
||||
You almost certainly don't need this.
|
||||
|
||||
Returns:
|
||||
A pair ``(transport_stream, trailing_bytes)``, where
|
||||
``transport_stream`` is the underlying transport stream, and
|
||||
``trailing_bytes`` is a byte string. Since :class:`SSLStream`
|
||||
doesn't necessarily know where the end of the encrypted data will
|
||||
be, it can happen that it accidentally reads too much from the
|
||||
underlying stream. ``trailing_bytes`` contains this extra data; you
|
||||
should process it as if it was returned from a call to
|
||||
``transport_stream.receive_some(...)``.
|
||||
|
||||
"""
|
||||
with self._outer_recv_conflict_detector, self._outer_send_conflict_detector:
|
||||
self._check_status()
|
||||
await self._handshook.ensure(checkpoint=False)
|
||||
await self._retry(self._ssl_object.unwrap)
|
||||
transport_stream = self.transport_stream
|
||||
self._state = _State.CLOSED
|
||||
self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use
|
||||
return (transport_stream, self._incoming.read())
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Gracefully shut down this connection, and close the underlying
|
||||
transport.
|
||||
|
||||
If ``https_compatible`` is False (the default), then this attempts to
|
||||
first send a ``close_notify`` and then close the underlying stream by
|
||||
calling its :meth:`~trio.abc.AsyncResource.aclose` method.
|
||||
|
||||
If ``https_compatible`` is set to True, then this simply closes the
|
||||
underlying stream and marks this stream as closed.
|
||||
|
||||
"""
|
||||
if self._state is _State.CLOSED:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return
|
||||
if self._state is _State.BROKEN or self._https_compatible:
|
||||
self._state = _State.CLOSED
|
||||
await self.transport_stream.aclose()
|
||||
return
|
||||
try:
|
||||
# https_compatible=False, so we're in spec-compliant mode and have
|
||||
# to send close_notify so that the other side gets a cryptographic
|
||||
# assurance that we've called aclose. Of course, we can't do
|
||||
# anything cryptographic until after we've completed the
|
||||
# handshake:
|
||||
await self._handshook.ensure(checkpoint=False)
|
||||
# Then, we call SSL_shutdown *once*, because we want to send a
|
||||
# close_notify but *not* wait for the other side to send back a
|
||||
# response. In principle it would be more polite to wait for the
|
||||
# other side to reply with their own close_notify. However, if
|
||||
# they aren't paying attention (e.g., if they're just sending
|
||||
# data and not receiving) then we will never notice our
|
||||
# close_notify and we'll be waiting forever. Eventually we'll time
|
||||
# out (hopefully), but it's still kind of nasty. And we can't
|
||||
# require the other side to always be receiving, because (a)
|
||||
# backpressure is kind of important, and (b) I bet there are
|
||||
# broken TLS implementations out there that don't receive all the
|
||||
# time. (Like e.g. anyone using Python ssl in synchronous mode.)
|
||||
#
|
||||
# The send-then-immediately-close behavior is explicitly allowed
|
||||
# by the TLS specs, so we're ok on that.
|
||||
#
|
||||
# Subtlety: SSLObject.unwrap will immediately call it a second
|
||||
# time, and the second time will raise SSLWantReadError because
|
||||
# there hasn't been time for the other side to respond
|
||||
# yet. (Unless they spontaneously sent a close_notify before we
|
||||
# called this, and it's either already been processed or gets
|
||||
# pulled out of the buffer by Python's second call.) So the way to
|
||||
# do what we want is to ignore SSLWantReadError on this call.
|
||||
#
|
||||
# Also, because the other side might have already sent
|
||||
# close_notify and closed their connection then it's possible that
|
||||
# our attempt to send close_notify will raise
|
||||
# BrokenResourceError. This is totally legal, and in fact can happen
|
||||
# with two well-behaved Trio programs talking to each other, so we
|
||||
# don't want to raise an error. So we suppress BrokenResourceError
|
||||
# here. (This is safe, because literally the only thing this call
|
||||
# to _retry will do is send the close_notify alert, so that's
|
||||
# surely where the error comes from.)
|
||||
#
|
||||
# FYI in some cases this could also raise SSLSyscallError which I
|
||||
# think is because SSL_shutdown is terrible. (Check out that note
|
||||
# at the bottom of the man page saying that it sometimes gets
|
||||
# raised spuriously.) I haven't seen this since we switched to
|
||||
# immediately closing the socket, and I don't know exactly what
|
||||
# conditions cause it and how to respond, so for now we're just
|
||||
# letting that happen. But if you start seeing it, then hopefully
|
||||
# this will give you a little head start on tracking it down,
|
||||
# because whoa did this puzzle us at the 2017 PyCon sprints.
|
||||
#
|
||||
# Also, if someone else is blocked in send/receive, then we aren't
|
||||
# going to be able to do a clean shutdown. If that happens, we'll
|
||||
# just do an unclean shutdown.
|
||||
with contextlib.suppress(trio.BrokenResourceError, trio.BusyResourceError):
|
||||
await self._retry(self._ssl_object.unwrap, ignore_want_read=True)
|
||||
except:
|
||||
# Failure! Kill the stream and move on.
|
||||
await aclose_forcefully(self.transport_stream)
|
||||
raise
|
||||
else:
|
||||
# Success! Gracefully close the underlying stream.
|
||||
await self.transport_stream.aclose()
|
||||
finally:
|
||||
self._state = _State.CLOSED
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
"""See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`."""
|
||||
# This method's implementation is deceptively simple.
|
||||
#
|
||||
# First, we take the outer send lock, because of Trio's standard
|
||||
# semantics that wait_send_all_might_not_block and send_all
|
||||
# conflict.
|
||||
with self._outer_send_conflict_detector:
|
||||
self._check_status()
|
||||
# Then we take the inner send lock. We know that no other tasks
|
||||
# are calling self.send_all or self.wait_send_all_might_not_block,
|
||||
# because we have the outer_send_lock. But! There might be another
|
||||
# task calling self.receive_some -> transport_stream.send_all, in
|
||||
# which case if we were to call
|
||||
# transport_stream.wait_send_all_might_not_block directly we'd
|
||||
# have two tasks doing write-related operations on
|
||||
# transport_stream simultaneously, which is not allowed. We
|
||||
# *don't* want to raise this conflict to our caller, because it's
|
||||
# purely an internal affair – all they did was call
|
||||
# wait_send_all_might_not_block and receive_some at the same time,
|
||||
# which is totally valid. And waiting for the lock is OK, because
|
||||
# a call to send_all certainly wouldn't complete while the other
|
||||
# task holds the lock.
|
||||
async with self._inner_send_lock:
|
||||
# Now we have the lock, which creates another potential
|
||||
# problem: what if a call to self.receive_some attempts to do
|
||||
# transport_stream.send_all now? It'll have to wait for us to
|
||||
# finish! But that's OK, because we release the lock as soon
|
||||
# as the underlying stream becomes writable, and the
|
||||
# self.receive_some call wasn't going to make any progress
|
||||
# until then anyway.
|
||||
#
|
||||
# Of course, this does mean we might return *before* the
|
||||
# stream is logically writable, because immediately after we
|
||||
# return self.receive_some might write some data and make it
|
||||
# non-writable again. But that's OK too,
|
||||
# wait_send_all_might_not_block only guarantees that it
|
||||
# doesn't return late.
|
||||
await self.transport_stream.wait_send_all_might_not_block()
|
||||
|
||||
|
||||
# this is necessary for Sphinx, see also `_abc.py`
|
||||
SSLStream.__module__ = SSLStream.__module__.replace("._ssl", "")
|
||||
|
||||
|
||||
@final
|
||||
class SSLListener(Listener[SSLStream[T_Stream]]):
|
||||
"""A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers.
|
||||
|
||||
:class:`SSLListener` wraps around another Listener, and converts
|
||||
all incoming connections to encrypted connections by wrapping them
|
||||
in a :class:`SSLStream`.
|
||||
|
||||
Args:
|
||||
transport_listener (~trio.abc.Listener): The listener whose incoming
|
||||
connections will be wrapped in :class:`SSLStream`.
|
||||
|
||||
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` that will be
|
||||
used for incoming connections.
|
||||
|
||||
https_compatible (bool): Passed on to :class:`SSLStream`.
|
||||
|
||||
Attributes:
|
||||
transport_listener (trio.abc.Listener): The underlying listener that was
|
||||
passed to ``__init__``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport_listener: Listener[T_Stream],
|
||||
ssl_context: _stdlib_ssl.SSLContext,
|
||||
*,
|
||||
https_compatible: bool = False,
|
||||
) -> None:
|
||||
self.transport_listener = transport_listener
|
||||
self._ssl_context = ssl_context
|
||||
self._https_compatible = https_compatible
|
||||
|
||||
async def accept(self) -> SSLStream[T_Stream]:
|
||||
"""Accept the next connection and wrap it in an :class:`SSLStream`.
|
||||
|
||||
See :meth:`trio.abc.Listener.accept` for details.
|
||||
|
||||
"""
|
||||
transport_stream = await self.transport_listener.accept()
|
||||
return SSLStream(
|
||||
transport_stream,
|
||||
self._ssl_context,
|
||||
server_side=True,
|
||||
https_compatible=self._https_compatible,
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the transport listener."""
|
||||
await self.transport_listener.aclose()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,123 @@
|
||||
# Platform-specific subprocess bits'n'pieces.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core, _subprocess
|
||||
from .._abc import ReceiveStream, SendStream # noqa: TCH001
|
||||
|
||||
_wait_child_exiting_error: ImportError | None = None
|
||||
_create_child_pipe_error: ImportError | None = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# internal types for the pipe representations used in type checking only
|
||||
class ClosableSendStream(SendStream):
|
||||
def close(self) -> None: ...
|
||||
|
||||
class ClosableReceiveStream(ReceiveStream):
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
# Fallback versions of the functions provided -- implementations
|
||||
# per OS are imported atop these at the bottom of the module.
|
||||
async def wait_child_exiting(process: _subprocess.Process) -> None:
|
||||
"""Block until the child process managed by ``process`` is exiting.
|
||||
|
||||
It is invalid to call this function if the process has already
|
||||
been waited on; that is, ``process.returncode`` must be None.
|
||||
|
||||
When this function returns, it indicates that a call to
|
||||
:meth:`subprocess.Popen.wait` will immediately be able to
|
||||
return the process's exit status. The actual exit status is not
|
||||
consumed by this call, since :class:`~subprocess.Popen` wants
|
||||
to be able to do that itself.
|
||||
"""
|
||||
raise NotImplementedError from _wait_child_exiting_error # pragma: no cover
|
||||
|
||||
|
||||
def create_pipe_to_child_stdin() -> tuple[ClosableSendStream, int]:
|
||||
"""Create a new pipe suitable for sending data from this
|
||||
process to the standard input of a child we're about to spawn.
|
||||
|
||||
Returns:
|
||||
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
|
||||
:class:`~trio.abc.SendStream` and ``subprocess_end`` is
|
||||
something suitable for passing as the ``stdin`` argument of
|
||||
:class:`subprocess.Popen`.
|
||||
"""
|
||||
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
|
||||
|
||||
|
||||
def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]:
|
||||
"""Create a new pipe suitable for receiving data into this
|
||||
process from the standard output or error stream of a child
|
||||
we're about to spawn.
|
||||
|
||||
Returns:
|
||||
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
|
||||
:class:`~trio.abc.ReceiveStream` and ``subprocess_end`` is
|
||||
something suitable for passing as the ``stdin`` argument of
|
||||
:class:`subprocess.Popen`.
|
||||
"""
|
||||
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
|
||||
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
from .windows import wait_child_exiting
|
||||
elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")):
|
||||
from .kqueue import wait_child_exiting
|
||||
else:
|
||||
# as it's an exported symbol, noqa'd
|
||||
from .waitid import wait_child_exiting # noqa: F401
|
||||
except ImportError as ex: # pragma: no cover
|
||||
_wait_child_exiting_error = ex
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
# Not worth type checking these definitions
|
||||
pass
|
||||
|
||||
elif os.name == "posix":
|
||||
|
||||
def create_pipe_to_child_stdin():
|
||||
rfd, wfd = os.pipe()
|
||||
return trio.lowlevel.FdStream(wfd), rfd
|
||||
|
||||
def create_pipe_from_child_output():
|
||||
rfd, wfd = os.pipe()
|
||||
return trio.lowlevel.FdStream(rfd), wfd
|
||||
|
||||
elif os.name == "nt":
|
||||
import msvcrt
|
||||
|
||||
# This isn't exported or documented, but it's also not
|
||||
# underscore-prefixed, and seems kosher to use. The asyncio docs
|
||||
# for 3.5 included an example that imported socketpair from
|
||||
# windows_utils (before socket.socketpair existed on Windows), and
|
||||
# when asyncio.windows_utils.socketpair was removed in 3.7, the
|
||||
# removal was mentioned in the release notes.
|
||||
from asyncio.windows_utils import pipe as windows_pipe
|
||||
|
||||
from .._windows_pipes import PipeReceiveStream, PipeSendStream
|
||||
|
||||
def create_pipe_to_child_stdin():
|
||||
# for stdin, we want the write end (our end) to use overlapped I/O
|
||||
rh, wh = windows_pipe(overlapped=(False, True))
|
||||
return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY)
|
||||
|
||||
def create_pipe_from_child_output():
|
||||
# for stdout/err, it's the read end that's overlapped
|
||||
rh, wh = windows_pipe(overlapped=(True, False))
|
||||
return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0)
|
||||
|
||||
else: # pragma: no cover
|
||||
raise ImportError("pipes not implemented on this platform")
|
||||
|
||||
except ImportError as ex: # pragma: no cover
|
||||
_create_child_pipe_error = ex
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import _core, _subprocess
|
||||
|
||||
assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING
|
||||
|
||||
|
||||
async def wait_child_exiting(process: _subprocess.Process) -> None:
|
||||
kqueue = _core.current_kqueue()
|
||||
try:
|
||||
from select import KQ_NOTE_EXIT
|
||||
except ImportError: # pragma: no cover
|
||||
# pypy doesn't define KQ_NOTE_EXIT:
|
||||
# https://bitbucket.org/pypy/pypy/issues/2921/
|
||||
# I verified this value against both Darwin and FreeBSD
|
||||
KQ_NOTE_EXIT = 0x80000000
|
||||
|
||||
def make_event(flags: int) -> select.kevent:
|
||||
return select.kevent(
|
||||
process.pid,
|
||||
filter=select.KQ_FILTER_PROC,
|
||||
flags=flags,
|
||||
fflags=KQ_NOTE_EXIT,
|
||||
)
|
||||
|
||||
try:
|
||||
kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0)
|
||||
except ProcessLookupError: # pragma: no cover
|
||||
# This can supposedly happen if the process is in the process
|
||||
# of exiting, and it can even be the case that kqueue says the
|
||||
# process doesn't exist before waitpid(WNOHANG) says it hasn't
|
||||
# exited yet. See the discussion in https://chromium.googlesource.com/
|
||||
# chromium/src/base/+/master/process/kill_mac.cc .
|
||||
# We haven't actually seen this error occur since we added
|
||||
# locking to prevent multiple calls to wait_child_exiting()
|
||||
# for the same process simultaneously, but given the explanation
|
||||
# in Chromium it seems we should still keep the check.
|
||||
return
|
||||
|
||||
def abort(_: _core.RaiseCancelT) -> _core.Abort:
|
||||
kqueue.control([make_event(select.KQ_EV_DELETE)], 0)
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
await _core.wait_kevent(process.pid, select.KQ_FILTER_PROC, abort)
|
||||
@@ -0,0 +1,113 @@
|
||||
import errno
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import _core, _subprocess
|
||||
from .._sync import CapacityLimiter, Event
|
||||
from .._threads import to_thread_run_sync
|
||||
|
||||
assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
|
||||
|
||||
try:
|
||||
from os import waitid
|
||||
|
||||
def sync_wait_reapable(pid: int) -> None:
|
||||
waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT)
|
||||
|
||||
except ImportError:
|
||||
# pypy doesn't define os.waitid so we need to pull it out ourselves
|
||||
# using cffi: https://bitbucket.org/pypy/pypy/issues/2922/
|
||||
import cffi
|
||||
|
||||
waitid_ffi = cffi.FFI()
|
||||
|
||||
# Believe it or not, siginfo_t starts with fields in the
|
||||
# same layout on both Linux and Darwin. The Linux structure
|
||||
# is bigger so that's what we use to size `pad`; while
|
||||
# there are a few extra fields in there, most of it is
|
||||
# true padding which would not be written by the syscall.
|
||||
waitid_ffi.cdef(
|
||||
"""
|
||||
typedef struct siginfo_s {
|
||||
int si_signo;
|
||||
int si_errno;
|
||||
int si_code;
|
||||
int si_pid;
|
||||
int si_uid;
|
||||
int si_status;
|
||||
int pad[26];
|
||||
} siginfo_t;
|
||||
int waitid(int idtype, int id, siginfo_t* result, int options);
|
||||
""",
|
||||
)
|
||||
waitid_cffi = waitid_ffi.dlopen(None).waitid # type: ignore[attr-defined]
|
||||
|
||||
def sync_wait_reapable(pid: int) -> None:
|
||||
P_PID = 1
|
||||
WEXITED = 0x00000004
|
||||
if sys.platform == "darwin": # pragma: no cover
|
||||
# waitid() is not exposed on Python on Darwin but does
|
||||
# work through CFFI; note that we typically won't get
|
||||
# here since Darwin also defines kqueue
|
||||
WNOWAIT = 0x00000020
|
||||
else:
|
||||
WNOWAIT = 0x01000000
|
||||
result = waitid_ffi.new("siginfo_t *")
|
||||
while waitid_cffi(P_PID, pid, result, WEXITED | WNOWAIT) < 0:
|
||||
got_errno = waitid_ffi.errno
|
||||
if got_errno == errno.EINTR:
|
||||
continue
|
||||
raise OSError(got_errno, os.strerror(got_errno))
|
||||
|
||||
|
||||
# adapted from
|
||||
# https://github.com/python-trio/trio/issues/4#issuecomment-398967572
|
||||
|
||||
waitid_limiter = CapacityLimiter(math.inf)
|
||||
|
||||
|
||||
async def _waitid_system_task(pid: int, event: Event) -> None:
|
||||
"""Spawn a thread that waits for ``pid`` to exit, then wake any tasks
|
||||
that were waiting on it.
|
||||
"""
|
||||
# abandon_on_cancel=True: if this task is cancelled, then we abandon the
|
||||
# thread to keep running waitpid in the background. Since this is
|
||||
# always run as a system task, this will only happen if the whole
|
||||
# call to trio.run is shutting down.
|
||||
|
||||
try:
|
||||
await to_thread_run_sync(
|
||||
sync_wait_reapable,
|
||||
pid,
|
||||
abandon_on_cancel=True,
|
||||
limiter=waitid_limiter,
|
||||
)
|
||||
except OSError:
|
||||
# If waitid fails, waitpid will fail too, so it still makes
|
||||
# sense to wake up the callers of wait_process_exiting(). The
|
||||
# most likely reason for this error in practice is a child
|
||||
# exiting when wait() is not possible because SIGCHLD is
|
||||
# ignored.
|
||||
pass
|
||||
finally:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_child_exiting(process: "_subprocess.Process") -> None:
|
||||
# Logic of this function:
|
||||
# - The first time we get called, we create an Event and start
|
||||
# an instance of _waitid_system_task that will set the Event
|
||||
# when waitid() completes. If that Event is set before
|
||||
# we get cancelled, we're good.
|
||||
# - Otherwise, a following call after the cancellation must
|
||||
# reuse the Event created during the first call, lest we
|
||||
# create an arbitrary number of threads waiting on the same
|
||||
# process.
|
||||
|
||||
if process._wait_for_exit_data is None:
|
||||
process._wait_for_exit_data = event = Event()
|
||||
_core.spawn_system_task(_waitid_system_task, process.pid, event)
|
||||
assert isinstance(process._wait_for_exit_data, Event)
|
||||
await process._wait_for_exit_data.wait()
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .._wait_for_object import WaitForSingleObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import _subprocess
|
||||
|
||||
|
||||
async def wait_child_exiting(process: "_subprocess.Process") -> None:
|
||||
# _handle is not in Popen stubs, though it is present on Windows.
|
||||
await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined]
|
||||
@@ -0,0 +1,876 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import attrs
|
||||
|
||||
import trio
|
||||
|
||||
from . import _core
|
||||
from ._core import (
|
||||
Abort,
|
||||
ParkingLot,
|
||||
RaiseCancelT,
|
||||
add_parking_lot_breaker,
|
||||
enable_ki_protection,
|
||||
remove_parking_lot_breaker,
|
||||
)
|
||||
from ._util import final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from ._core import Task
|
||||
from ._core._parking_lot import ParkingLotStatistics
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class EventStatistics:
|
||||
"""An object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting``: The number of tasks blocked on this event's
|
||||
:meth:`trio.Event.wait` method.
|
||||
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@final
|
||||
@attrs.define(repr=False, eq=False)
|
||||
class Event:
|
||||
"""A waitable boolean value useful for inter-task synchronization,
|
||||
inspired by :class:`threading.Event`.
|
||||
|
||||
An event object has an internal boolean flag, representing whether
|
||||
the event has happened yet. The flag is initially False, and the
|
||||
:meth:`wait` method waits until the flag is True. If the flag is
|
||||
already True, then :meth:`wait` returns immediately. (If the event has
|
||||
already happened, there's nothing to wait for.) The :meth:`set` method
|
||||
sets the flag to True, and wakes up any waiters.
|
||||
|
||||
This behavior is useful because it helps avoid race conditions and
|
||||
lost wakeups: it doesn't matter whether :meth:`set` gets called just
|
||||
before or after :meth:`wait`. If you want a lower-level wakeup
|
||||
primitive that doesn't have this protection, consider :class:`Condition`
|
||||
or :class:`trio.lowlevel.ParkingLot`.
|
||||
|
||||
.. note:: Unlike `threading.Event`, `trio.Event` has no
|
||||
`~threading.Event.clear` method. In Trio, once an `Event` has happened,
|
||||
it cannot un-happen. If you need to represent a series of events,
|
||||
consider creating a new `Event` object for each one (they're cheap!),
|
||||
or other synchronization methods like :ref:`channels <channels>` or
|
||||
`trio.lowlevel.ParkingLot`.
|
||||
|
||||
"""
|
||||
|
||||
_tasks: set[Task] = attrs.field(factory=set, init=False)
|
||||
_flag: bool = attrs.field(default=False, init=False)
|
||||
|
||||
def is_set(self) -> bool:
|
||||
"""Return the current value of the internal flag."""
|
||||
return self._flag
|
||||
|
||||
@enable_ki_protection
|
||||
def set(self) -> None:
|
||||
"""Set the internal flag value to True, and wake any waiting tasks."""
|
||||
if not self._flag:
|
||||
self._flag = True
|
||||
for task in self._tasks:
|
||||
_core.reschedule(task)
|
||||
self._tasks.clear()
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""Block until the internal flag value becomes True.
|
||||
|
||||
If it's already True, then this method returns immediately.
|
||||
|
||||
"""
|
||||
if self._flag:
|
||||
await trio.lowlevel.checkpoint()
|
||||
else:
|
||||
task = _core.current_task()
|
||||
self._tasks.add(task)
|
||||
|
||||
def abort_fn(_: RaiseCancelT) -> Abort:
|
||||
self._tasks.remove(task)
|
||||
return _core.Abort.SUCCEEDED
|
||||
|
||||
await _core.wait_task_rescheduled(abort_fn)
|
||||
|
||||
def statistics(self) -> EventStatistics:
|
||||
"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting``: The number of tasks blocked on this event's
|
||||
:meth:`wait` method.
|
||||
|
||||
"""
|
||||
return EventStatistics(tasks_waiting=len(self._tasks))
|
||||
|
||||
|
||||
class _HasAcquireRelease(Protocol):
|
||||
"""Only classes with acquire() and release() can use the mixin's implementations."""
|
||||
|
||||
async def acquire(self) -> object: ...
|
||||
|
||||
def release(self) -> object: ...
|
||||
|
||||
|
||||
class AsyncContextManagerMixin:
|
||||
@enable_ki_protection
|
||||
async def __aenter__(self: _HasAcquireRelease) -> None:
|
||||
await self.acquire()
|
||||
|
||||
@enable_ki_protection
|
||||
async def __aexit__(
|
||||
self: _HasAcquireRelease,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class CapacityLimiterStatistics:
|
||||
"""An object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``borrowed_tokens``: The number of tokens currently borrowed from
|
||||
the sack.
|
||||
* ``total_tokens``: The total number of tokens in the sack. Usually
|
||||
this will be larger than ``borrowed_tokens``, but it's possibly for
|
||||
it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased.
|
||||
* ``borrowers``: A list of all tasks or other entities that currently
|
||||
hold a token.
|
||||
* ``tasks_waiting``: The number of tasks blocked on this
|
||||
:class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or
|
||||
:meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods.
|
||||
|
||||
"""
|
||||
|
||||
borrowed_tokens: int
|
||||
total_tokens: int | float
|
||||
borrowers: list[Task | object]
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
# Can be a generic type with a default of Task if/when PEP 696 is released
|
||||
# and implemented in type checkers. Making it fully generic would currently
|
||||
# introduce a lot of unnecessary hassle.
|
||||
@final
|
||||
class CapacityLimiter(AsyncContextManagerMixin):
|
||||
"""An object for controlling access to a resource with limited capacity.
|
||||
|
||||
Sometimes you need to put a limit on how many tasks can do something at
|
||||
the same time. For example, you might want to use some threads to run
|
||||
multiple blocking I/O operations in parallel... but if you use too many
|
||||
threads at once, then your system can become overloaded and it'll actually
|
||||
make things slower. One popular solution is to impose a policy like "run
|
||||
up to 40 threads at the same time, but no more". But how do you implement
|
||||
a policy like this?
|
||||
|
||||
That's what :class:`CapacityLimiter` is for. You can think of a
|
||||
:class:`CapacityLimiter` object as a sack that starts out holding some fixed
|
||||
number of tokens::
|
||||
|
||||
limit = trio.CapacityLimiter(40)
|
||||
|
||||
Then tasks can come along and borrow a token out of the sack::
|
||||
|
||||
# Borrow a token:
|
||||
async with limit:
|
||||
# We are holding a token!
|
||||
await perform_expensive_operation()
|
||||
# Exiting the 'async with' block puts the token back into the sack
|
||||
|
||||
And crucially, if you try to borrow a token but the sack is empty, then
|
||||
you have to wait for another task to finish what it's doing and put its
|
||||
token back first before you can take it and continue.
|
||||
|
||||
Another way to think of it: a :class:`CapacityLimiter` is like a sofa with a
|
||||
fixed number of seats, and if they're all taken then you have to wait for
|
||||
someone to get up before you can sit down.
|
||||
|
||||
By default, :func:`trio.to_thread.run_sync` uses a
|
||||
:class:`CapacityLimiter` to limit the number of threads running at once;
|
||||
see `trio.to_thread.current_default_thread_limiter` for details.
|
||||
|
||||
If you're familiar with semaphores, then you can think of this as a
|
||||
restricted semaphore that's specialized for one common use case, with
|
||||
additional error checking. For a more traditional semaphore, see
|
||||
:class:`Semaphore`.
|
||||
|
||||
.. note::
|
||||
|
||||
Don't confuse this with the `"leaky bucket"
|
||||
<https://en.wikipedia.org/wiki/Leaky_bucket>`__ or `"token bucket"
|
||||
<https://en.wikipedia.org/wiki/Token_bucket>`__ algorithms used to
|
||||
limit bandwidth usage on networks. The basic idea of using tokens to
|
||||
track a resource limit is similar, but this is a very simple sack where
|
||||
tokens aren't automatically created or destroyed over time; they're
|
||||
just borrowed and then put back.
|
||||
|
||||
"""
|
||||
|
||||
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
|
||||
def __init__(self, total_tokens: int | float): # noqa: PYI041
|
||||
self._lot = ParkingLot()
|
||||
self._borrowers: set[Task | object] = set()
|
||||
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
|
||||
self._pending_borrowers: dict[Task, Task | object] = {}
|
||||
# invoke the property setter for validation
|
||||
self.total_tokens: int | float = total_tokens
|
||||
assert self._total_tokens == total_tokens
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<trio.CapacityLimiter at {id(self):#x}, {len(self._borrowers)}/{self._total_tokens} with {len(self._lot)} waiting>"
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int | float:
|
||||
"""The total capacity available.
|
||||
|
||||
You can change :attr:`total_tokens` by assigning to this attribute. If
|
||||
you make it larger, then the appropriate number of waiting tasks will
|
||||
be woken immediately to take the new tokens. If you decrease
|
||||
total_tokens below the number of tasks that are currently using the
|
||||
resource, then all current tasks will be allowed to finish as normal,
|
||||
but no new tasks will be allowed in until the total number of tasks
|
||||
drops below the new total_tokens.
|
||||
|
||||
"""
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, new_total_tokens: int | float) -> None: # noqa: PYI041
|
||||
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
|
||||
raise TypeError("total_tokens must be an int or math.inf")
|
||||
if new_total_tokens < 1:
|
||||
raise ValueError("total_tokens must be >= 1")
|
||||
self._total_tokens = new_total_tokens
|
||||
self._wake_waiters()
|
||||
|
||||
def _wake_waiters(self) -> None:
|
||||
available = self._total_tokens - len(self._borrowers)
|
||||
for woken in self._lot.unpark(count=available):
|
||||
self._borrowers.add(self._pending_borrowers.pop(woken))
|
||||
|
||||
@property
|
||||
def borrowed_tokens(self) -> int:
|
||||
"""The amount of capacity that's currently in use."""
|
||||
return len(self._borrowers)
|
||||
|
||||
@property
|
||||
def available_tokens(self) -> int | float:
|
||||
"""The amount of capacity that's available to use."""
|
||||
return self.total_tokens - self.borrowed_tokens
|
||||
|
||||
@enable_ki_protection
|
||||
def acquire_nowait(self) -> None:
|
||||
"""Borrow a token from the sack, without blocking.
|
||||
|
||||
Raises:
|
||||
WouldBlock: if no tokens are available.
|
||||
RuntimeError: if the current task already holds one of this sack's
|
||||
tokens.
|
||||
|
||||
"""
|
||||
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
|
||||
|
||||
@enable_ki_protection
|
||||
def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None:
|
||||
"""Borrow a token from the sack on behalf of ``borrower``, without
|
||||
blocking.
|
||||
|
||||
Args:
|
||||
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
|
||||
used to record who is borrowing this token. This is used by
|
||||
:func:`trio.to_thread.run_sync` to allow threads to "hold
|
||||
tokens", with the intention in the future of using it to `allow
|
||||
deadlock detection and other useful things
|
||||
<https://github.com/python-trio/trio/issues/182>`__
|
||||
|
||||
Raises:
|
||||
WouldBlock: if no tokens are available.
|
||||
RuntimeError: if ``borrower`` already holds one of this sack's
|
||||
tokens.
|
||||
|
||||
"""
|
||||
if borrower in self._borrowers:
|
||||
raise RuntimeError(
|
||||
"this borrower is already holding one of this CapacityLimiter's tokens",
|
||||
)
|
||||
if len(self._borrowers) < self._total_tokens and not self._lot:
|
||||
self._borrowers.add(borrower)
|
||||
else:
|
||||
raise trio.WouldBlock
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self) -> None:
|
||||
"""Borrow a token from the sack, blocking if necessary.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the current task already holds one of this sack's
|
||||
tokens.
|
||||
|
||||
"""
|
||||
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire_on_behalf_of(self, borrower: Task | object) -> None:
|
||||
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
|
||||
used to record who is borrowing this token; see
|
||||
:meth:`acquire_on_behalf_of_nowait` for details.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``borrower`` task already holds one of this sack's
|
||||
tokens.
|
||||
|
||||
"""
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_on_behalf_of_nowait(borrower)
|
||||
except trio.WouldBlock:
|
||||
task = trio.lowlevel.current_task()
|
||||
self._pending_borrowers[task] = borrower
|
||||
try:
|
||||
await self._lot.park()
|
||||
except trio.Cancelled:
|
||||
self._pending_borrowers.pop(task)
|
||||
raise
|
||||
else:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
def release(self) -> None:
|
||||
"""Put a token back into the sack.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the current task has not acquired one of this
|
||||
sack's tokens.
|
||||
|
||||
"""
|
||||
self.release_on_behalf_of(trio.lowlevel.current_task())
|
||||
|
||||
@enable_ki_protection
|
||||
def release_on_behalf_of(self, borrower: Task | object) -> None:
|
||||
"""Put a token back into the sack on behalf of ``borrower``.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the given borrower has not acquired one of this
|
||||
sack's tokens.
|
||||
|
||||
"""
|
||||
if borrower not in self._borrowers:
|
||||
raise RuntimeError(
|
||||
"this borrower isn't holding any of this CapacityLimiter's tokens",
|
||||
)
|
||||
self._borrowers.remove(borrower)
|
||||
self._wake_waiters()
|
||||
|
||||
def statistics(self) -> CapacityLimiterStatistics:
|
||||
"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``borrowed_tokens``: The number of tokens currently borrowed from
|
||||
the sack.
|
||||
* ``total_tokens``: The total number of tokens in the sack. Usually
|
||||
this will be larger than ``borrowed_tokens``, but it's possibly for
|
||||
it to be smaller if :attr:`total_tokens` was recently decreased.
|
||||
* ``borrowers``: A list of all tasks or other entities that currently
|
||||
hold a token.
|
||||
* ``tasks_waiting``: The number of tasks blocked on this
|
||||
:class:`CapacityLimiter`\'s :meth:`acquire` or
|
||||
:meth:`acquire_on_behalf_of` methods.
|
||||
|
||||
"""
|
||||
return CapacityLimiterStatistics(
|
||||
borrowed_tokens=len(self._borrowers),
|
||||
total_tokens=self._total_tokens,
|
||||
# Use a list instead of a frozenset just in case we start to allow
|
||||
# one borrower to hold multiple tokens in the future
|
||||
borrowers=list(self._borrowers),
|
||||
tasks_waiting=len(self._lot),
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class Semaphore(AsyncContextManagerMixin):
|
||||
"""A `semaphore <https://en.wikipedia.org/wiki/Semaphore_(programming)>`__.
|
||||
|
||||
A semaphore holds an integer value, which can be incremented by
|
||||
calling :meth:`release` and decremented by calling :meth:`acquire` – but
|
||||
the value is never allowed to drop below zero. If the value is zero, then
|
||||
:meth:`acquire` will block until someone calls :meth:`release`.
|
||||
|
||||
If you're looking for a :class:`Semaphore` to limit the number of tasks
|
||||
that can access some resource simultaneously, then consider using a
|
||||
:class:`CapacityLimiter` instead.
|
||||
|
||||
This object's interface is similar to, but different from, that of
|
||||
:class:`threading.Semaphore`.
|
||||
|
||||
A :class:`Semaphore` object can be used as an async context manager; it
|
||||
blocks on entry but not on exit.
|
||||
|
||||
Args:
|
||||
initial_value (int): A non-negative integer giving semaphore's initial
|
||||
value.
|
||||
max_value (int or None): If given, makes this a "bounded" semaphore that
|
||||
raises an error if the value is about to exceed the given
|
||||
``max_value``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, initial_value: int, *, max_value: int | None = None):
|
||||
if not isinstance(initial_value, int):
|
||||
raise TypeError("initial_value must be an int")
|
||||
if initial_value < 0:
|
||||
raise ValueError("initial value must be >= 0")
|
||||
if max_value is not None:
|
||||
if not isinstance(max_value, int):
|
||||
raise TypeError("max_value must be None or an int")
|
||||
if max_value < initial_value:
|
||||
raise ValueError("max_values must be >= initial_value")
|
||||
|
||||
# Invariants:
|
||||
# bool(self._lot) implies self._value == 0
|
||||
# (or equivalently: self._value > 0 implies not self._lot)
|
||||
self._lot = trio.lowlevel.ParkingLot()
|
||||
self._value = initial_value
|
||||
self._max_value = max_value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self._max_value is None:
|
||||
max_value_str = ""
|
||||
else:
|
||||
max_value_str = f", max_value={self._max_value}"
|
||||
return f"<trio.Semaphore({self._value}{max_value_str}) at {id(self):#x}>"
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
"""The current value of the semaphore."""
|
||||
return self._value
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
"""The maximum allowed value. May be None to indicate no limit."""
|
||||
return self._max_value
|
||||
|
||||
@enable_ki_protection
|
||||
def acquire_nowait(self) -> None:
|
||||
"""Attempt to decrement the semaphore value, without blocking.
|
||||
|
||||
Raises:
|
||||
WouldBlock: if the value is zero.
|
||||
|
||||
"""
|
||||
if self._value > 0:
|
||||
assert not self._lot
|
||||
self._value -= 1
|
||||
else:
|
||||
raise trio.WouldBlock
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self) -> None:
|
||||
"""Decrement the semaphore value, blocking if necessary to avoid
|
||||
letting it drop below zero.
|
||||
|
||||
"""
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
await self._lot.park()
|
||||
else:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
def release(self) -> None:
|
||||
"""Increment the semaphore value, possibly waking a task blocked in
|
||||
:meth:`acquire`.
|
||||
|
||||
Raises:
|
||||
ValueError: if incrementing the value would cause it to exceed
|
||||
:attr:`max_value`.
|
||||
|
||||
"""
|
||||
if self._lot:
|
||||
assert self._value == 0
|
||||
self._lot.unpark(count=1)
|
||||
else:
|
||||
if self._max_value is not None and self._value == self._max_value:
|
||||
raise ValueError("semaphore released too many times")
|
||||
self._value += 1
|
||||
|
||||
def statistics(self) -> ParkingLotStatistics:
|
||||
"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting``: The number of tasks blocked on this semaphore's
|
||||
:meth:`acquire` method.
|
||||
|
||||
"""
|
||||
return self._lot.statistics()
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class LockStatistics:
|
||||
"""An object containing debugging information for a Lock.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``locked`` (boolean): indicating whether the lock is held.
|
||||
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
|
||||
or None if the lock is not held.
|
||||
* ``tasks_waiting`` (int): The number of tasks blocked on this lock's
|
||||
:meth:`trio.Lock.acquire` method.
|
||||
|
||||
"""
|
||||
|
||||
locked: bool
|
||||
owner: Task | None
|
||||
tasks_waiting: int
|
||||
|
||||
|
||||
@attrs.define(eq=False, repr=False, slots=False)
|
||||
class _LockImpl(AsyncContextManagerMixin):
|
||||
_lot: ParkingLot = attrs.field(factory=ParkingLot, init=False)
|
||||
_owner: Task | None = attrs.field(default=None, init=False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.locked():
|
||||
s1 = "locked"
|
||||
s2 = f" with {len(self._lot)} waiters"
|
||||
else:
|
||||
s1 = "unlocked"
|
||||
s2 = ""
|
||||
return f"<{s1} {self.__class__.__name__} object at {id(self):#x}{s2}>"
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Check whether the lock is currently held.
|
||||
|
||||
Returns:
|
||||
bool: True if the lock is held, False otherwise.
|
||||
|
||||
"""
|
||||
return self._owner is not None
|
||||
|
||||
@enable_ki_protection
|
||||
def acquire_nowait(self) -> None:
|
||||
"""Attempt to acquire the lock, without blocking.
|
||||
|
||||
Raises:
|
||||
WouldBlock: if the lock is held.
|
||||
|
||||
"""
|
||||
|
||||
task = trio.lowlevel.current_task()
|
||||
if self._owner is task:
|
||||
raise RuntimeError("attempt to re-acquire an already held Lock")
|
||||
elif self._owner is None and not self._lot:
|
||||
# No-one owns it
|
||||
self._owner = task
|
||||
add_parking_lot_breaker(task, self._lot)
|
||||
else:
|
||||
raise trio.WouldBlock
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the lock, blocking if necessary.
|
||||
|
||||
Raises:
|
||||
BrokenResourceError: if the owner of the lock exits without releasing.
|
||||
"""
|
||||
await trio.lowlevel.checkpoint_if_cancelled()
|
||||
try:
|
||||
self.acquire_nowait()
|
||||
except trio.WouldBlock:
|
||||
try:
|
||||
# NOTE: it's important that the contended acquire path is just
|
||||
# "_lot.park()", because that's how Condition.wait() acquires the
|
||||
# lock as well.
|
||||
await self._lot.park()
|
||||
except trio.BrokenResourceError:
|
||||
raise trio.BrokenResourceError(
|
||||
f"Owner of this lock exited without releasing: {self._owner}",
|
||||
) from None
|
||||
else:
|
||||
await trio.lowlevel.cancel_shielded_checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
def release(self) -> None:
|
||||
"""Release the lock.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the calling task does not hold the lock.
|
||||
|
||||
"""
|
||||
task = trio.lowlevel.current_task()
|
||||
if task is not self._owner:
|
||||
raise RuntimeError("can't release a Lock you don't own")
|
||||
remove_parking_lot_breaker(self._owner, self._lot)
|
||||
if self._lot:
|
||||
(self._owner,) = self._lot.unpark(count=1)
|
||||
add_parking_lot_breaker(self._owner, self._lot)
|
||||
else:
|
||||
self._owner = None
|
||||
|
||||
def statistics(self) -> LockStatistics:
|
||||
"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``locked``: boolean indicating whether the lock is held.
|
||||
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
|
||||
or None if the lock is not held.
|
||||
* ``tasks_waiting``: The number of tasks blocked on this lock's
|
||||
:meth:`acquire` method.
|
||||
|
||||
"""
|
||||
return LockStatistics(
|
||||
locked=self.locked(),
|
||||
owner=self._owner,
|
||||
tasks_waiting=len(self._lot),
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class Lock(_LockImpl):
|
||||
"""A classic `mutex
|
||||
<https://en.wikipedia.org/wiki/Lock_(computer_science)>`__.
|
||||
|
||||
This is a non-reentrant, single-owner lock. Unlike
|
||||
:class:`threading.Lock`, only the owner of the lock is allowed to release
|
||||
it.
|
||||
|
||||
A :class:`Lock` object can be used as an async context manager; it
|
||||
blocks on entry but not on exit.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class StrictFIFOLock(_LockImpl):
|
||||
r"""A variant of :class:`Lock` where tasks are guaranteed to acquire the
|
||||
lock in strict first-come-first-served order.
|
||||
|
||||
An example of when this is useful is if you're implementing something like
|
||||
:class:`trio.SSLStream` or an HTTP/2 server using `h2
|
||||
<https://hyper-h2.readthedocs.io/>`__, where you have multiple concurrent
|
||||
tasks that are interacting with a shared state machine, and at
|
||||
unpredictable moments the state machine requests that a chunk of data be
|
||||
sent over the network. (For example, when using h2 simply reading incoming
|
||||
data can occasionally `create outgoing data to send
|
||||
<https://http2.github.io/http2-spec/#PING>`__.) The challenge is to make
|
||||
sure that these chunks are sent in the correct order, without being
|
||||
garbled.
|
||||
|
||||
One option would be to use a regular :class:`Lock`, and wrap it around
|
||||
every interaction with the state machine::
|
||||
|
||||
# This approach is sometimes workable but often sub-optimal; see below
|
||||
async with lock:
|
||||
state_machine.do_something()
|
||||
if state_machine.has_data_to_send():
|
||||
await conn.sendall(state_machine.get_data_to_send())
|
||||
|
||||
But this can be problematic. If you're using h2 then *usually* reading
|
||||
incoming data doesn't create the need to send any data, so we don't want
|
||||
to force every task that tries to read from the network to sit and wait
|
||||
a potentially long time for ``sendall`` to finish. And in some situations
|
||||
this could even potentially cause a deadlock, if the remote peer is
|
||||
waiting for you to read some data before it accepts the data you're
|
||||
sending.
|
||||
|
||||
:class:`StrictFIFOLock` provides an alternative. We can rewrite our
|
||||
example like::
|
||||
|
||||
# Note: no awaits between when we start using the state machine and
|
||||
# when we block to take the lock!
|
||||
state_machine.do_something()
|
||||
if state_machine.has_data_to_send():
|
||||
# Notice that we fetch the data to send out of the state machine
|
||||
# *before* sleeping, so that other tasks won't see it.
|
||||
chunk = state_machine.get_data_to_send()
|
||||
async with strict_fifo_lock:
|
||||
await conn.sendall(chunk)
|
||||
|
||||
First we do all our interaction with the state machine in a single
|
||||
scheduling quantum (notice there are no ``await``\s in there), so it's
|
||||
automatically atomic with respect to other tasks. And then if and only if
|
||||
we have data to send, we get in line to send it – and
|
||||
:class:`StrictFIFOLock` guarantees that each task will send its data in
|
||||
the same order that the state machine generated it.
|
||||
|
||||
Currently, :class:`StrictFIFOLock` is identical to :class:`Lock`,
|
||||
but (a) this may not always be true in the future, especially if Trio ever
|
||||
implements `more sophisticated scheduling policies
|
||||
<https://github.com/python-trio/trio/issues/32>`__, and (b) the above code
|
||||
is relying on a pretty subtle property of its lock. Using a
|
||||
:class:`StrictFIFOLock` acts as an executable reminder that you're relying
|
||||
on this property.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
class ConditionStatistics:
|
||||
r"""An object containing debugging information for a Condition.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting`` (int): The number of tasks blocked on this condition's
|
||||
:meth:`trio.Condition.wait` method.
|
||||
* ``lock_statistics``: The result of calling the underlying
|
||||
:class:`Lock`\s :meth:`~Lock.statistics` method.
|
||||
|
||||
"""
|
||||
|
||||
tasks_waiting: int
|
||||
lock_statistics: LockStatistics
|
||||
|
||||
|
||||
@final
|
||||
class Condition(AsyncContextManagerMixin):
|
||||
"""A classic `condition variable
|
||||
<https://en.wikipedia.org/wiki/Monitor_(synchronization)>`__, similar to
|
||||
:class:`threading.Condition`.
|
||||
|
||||
A :class:`Condition` object can be used as an async context manager to
|
||||
acquire the underlying lock; it blocks on entry but not on exit.
|
||||
|
||||
Args:
|
||||
lock (Lock): the lock object to use. If given, must be a
|
||||
:class:`trio.Lock`. If None, a new :class:`Lock` will be allocated
|
||||
and used.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, lock: Lock | None = None):
|
||||
if lock is None:
|
||||
lock = Lock()
|
||||
if type(lock) is not Lock:
|
||||
raise TypeError("lock must be a trio.Lock")
|
||||
self._lock = lock
|
||||
self._lot = trio.lowlevel.ParkingLot()
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Check whether the underlying lock is currently held.
|
||||
|
||||
Returns:
|
||||
bool: True if the lock is held, False otherwise.
|
||||
|
||||
"""
|
||||
return self._lock.locked()
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
"""Attempt to acquire the underlying lock, without blocking.
|
||||
|
||||
Raises:
|
||||
WouldBlock: if the lock is currently held.
|
||||
|
||||
"""
|
||||
return self._lock.acquire_nowait()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Acquire the underlying lock, blocking if necessary.
|
||||
|
||||
Raises:
|
||||
BrokenResourceError: if the owner of the underlying lock exits without releasing.
|
||||
"""
|
||||
await self._lock.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release the underlying lock."""
|
||||
self._lock.release()
|
||||
|
||||
@enable_ki_protection
|
||||
async def wait(self) -> None:
|
||||
"""Wait for another task to call :meth:`notify` or
|
||||
:meth:`notify_all`.
|
||||
|
||||
When calling this method, you must hold the lock. It releases the lock
|
||||
while waiting, and then re-acquires it before waking up.
|
||||
|
||||
There is a subtlety with how this method interacts with cancellation:
|
||||
when cancelled it will block to re-acquire the lock before raising
|
||||
:exc:`Cancelled`. This may cause cancellation to be less prompt than
|
||||
expected. The advantage is that it makes code like this work::
|
||||
|
||||
async with condition:
|
||||
await condition.wait()
|
||||
|
||||
If we didn't re-acquire the lock before waking up, and :meth:`wait`
|
||||
were cancelled here, then we'd crash in ``condition.__aexit__`` when
|
||||
we tried to release the lock we no longer held.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the calling task does not hold the lock.
|
||||
BrokenResourceError: if the owner of the lock exits without releasing, when attempting to re-acquire.
|
||||
|
||||
"""
|
||||
if trio.lowlevel.current_task() is not self._lock._owner:
|
||||
raise RuntimeError("must hold the lock to wait")
|
||||
self.release()
|
||||
# NOTE: we go to sleep on self._lot, but we'll wake up on
|
||||
# self._lock._lot. That's all that's required to acquire a Lock.
|
||||
try:
|
||||
await self._lot.park()
|
||||
except:
|
||||
with trio.CancelScope(shield=True):
|
||||
await self.acquire()
|
||||
raise
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""Wake one or more tasks that are blocked in :meth:`wait`.
|
||||
|
||||
Args:
|
||||
n (int): The number of tasks to wake.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the calling task does not hold the lock.
|
||||
|
||||
"""
|
||||
if trio.lowlevel.current_task() is not self._lock._owner:
|
||||
raise RuntimeError("must hold the lock to notify")
|
||||
self._lot.repark(self._lock._lot, count=n)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Wake all tasks that are currently blocked in :meth:`wait`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the calling task does not hold the lock.
|
||||
|
||||
"""
|
||||
if trio.lowlevel.current_task() is not self._lock._owner:
|
||||
raise RuntimeError("must hold the lock to notify")
|
||||
self._lot.repark_all(self._lock._lot)
|
||||
|
||||
def statistics(self) -> ConditionStatistics:
|
||||
r"""Return an object containing debugging information.
|
||||
|
||||
Currently the following fields are defined:
|
||||
|
||||
* ``tasks_waiting``: The number of tasks blocked on this condition's
|
||||
:meth:`wait` method.
|
||||
* ``lock_statistics``: The result of calling the underlying
|
||||
:class:`Lock`\s :meth:`~Lock.statistics` method.
|
||||
|
||||
"""
|
||||
return ConditionStatistics(
|
||||
tasks_waiting=len(self._lot),
|
||||
lock_statistics=self._lock.statistics(),
|
||||
)
|
||||
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things:
|
||||
1. give an error if docstrings are missing.
|
||||
pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value.
|
||||
2. filter out specific errors we don't care about.
|
||||
this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors.
|
||||
|
||||
If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# this file is not run as part of the tests, instead it's run standalone from check.sh
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
# not needed if everything is working, but if somebody does something to generate
|
||||
# tons of errors, we can be nice and stop them from getting 3*tons of output
|
||||
printed_diagnostics: set[str] = set()
|
||||
|
||||
|
||||
# TODO: consider checking manually without `--ignoreexternal`, and/or
|
||||
# removing it from the below call later on.
|
||||
def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
|
||||
return subprocess.run(
|
||||
[
|
||||
"pyright",
|
||||
# Specify a platform and version to keep imported modules consistent.
|
||||
f"--pythonplatform={platform}",
|
||||
"--pythonversion=3.8",
|
||||
"--verifytypes=trio",
|
||||
"--outputjson",
|
||||
"--ignoreexternal",
|
||||
],
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
def has_docstring_at_runtime(name: str) -> bool:
|
||||
"""Pyright gives us an object identifier of xx.yy.zz
|
||||
This function tries to decompose that into its constituent parts, such that we
|
||||
can resolve it, in order to check whether it has a `__doc__` at runtime and
|
||||
verifytypes misses it because we're doing overly fancy stuff.
|
||||
"""
|
||||
# This assert is solely for stopping isort from removing our imports of trio & trio.testing
|
||||
# It could also be done with isort:skip, but that'd also disable import sorting and the like.
|
||||
assert trio.testing
|
||||
|
||||
# figure out what part of the name is the module, so we can "import" it
|
||||
name_parts = name.split(".")
|
||||
assert name_parts[0] == "trio"
|
||||
if name_parts[1] == "tests":
|
||||
return True
|
||||
|
||||
# traverse down the remaining identifiers with getattr
|
||||
obj = trio
|
||||
try:
|
||||
for obj_name in name_parts[1:]:
|
||||
obj = getattr(obj, obj_name)
|
||||
except AttributeError as exc:
|
||||
# asynciowrapper does funky getattr stuff
|
||||
if "AsyncIOWrapper" in str(exc) or name in (
|
||||
# Symbols not existing on all platforms, so we can't dynamically inspect them.
|
||||
# Manually confirmed to have docstrings but pyright doesn't see them due to
|
||||
# export shenanigans. TODO: actually manually confirm that.
|
||||
# In theory we could verify these at runtime, probably by running the script separately
|
||||
# on separate platforms. It might also be a decent idea to work the other way around,
|
||||
# a la test_static_tool_sees_class_members
|
||||
# darwin
|
||||
"trio.lowlevel.current_kqueue",
|
||||
"trio.lowlevel.monitor_kevent",
|
||||
"trio.lowlevel.wait_kevent",
|
||||
"trio._core._io_kqueue._KqueueStatistics",
|
||||
# windows
|
||||
"trio._socket.SocketType.share",
|
||||
"trio._core._io_windows._WindowsStatistics",
|
||||
"trio._core._windows_cffi.Handle",
|
||||
"trio.lowlevel.current_iocp",
|
||||
"trio.lowlevel.monitor_completion_key",
|
||||
"trio.lowlevel.readinto_overlapped",
|
||||
"trio.lowlevel.register_with_iocp",
|
||||
"trio.lowlevel.wait_overlapped",
|
||||
"trio.lowlevel.write_overlapped",
|
||||
"trio.lowlevel.WaitForSingleObject",
|
||||
"trio.socket.fromshare",
|
||||
# linux
|
||||
# this test will fail on linux, but I don't develop on linux. So the next
|
||||
# person to do so is very welcome to open a pull request and populate with
|
||||
# objects
|
||||
# TODO: these are erroring on all platforms, why?
|
||||
"trio._highlevel_generic.StapledStream.send_stream",
|
||||
"trio._highlevel_generic.StapledStream.receive_stream",
|
||||
"trio._ssl.SSLStream.transport_stream",
|
||||
"trio._file_io._HasFileNo",
|
||||
"trio._file_io._HasFileNo.fileno",
|
||||
):
|
||||
return True
|
||||
|
||||
else:
|
||||
print(
|
||||
f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
return bool(obj.__doc__)
|
||||
|
||||
|
||||
def check_type(
|
||||
platform: str,
|
||||
full_diagnostics_file: Path | None,
|
||||
expected_errors: list[object],
|
||||
) -> list[object]:
|
||||
# convince isort we use the trio import
|
||||
assert trio
|
||||
|
||||
# run pyright, load output into json
|
||||
res = run_pyright(platform)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
if res.stderr:
|
||||
print(res.stderr, file=sys.stderr)
|
||||
|
||||
if full_diagnostics_file:
|
||||
with open(full_diagnostics_file, "a") as f:
|
||||
json.dump(current_result, f, sort_keys=True, indent=4)
|
||||
|
||||
errors = []
|
||||
|
||||
for symbol in current_result["typeCompleteness"]["symbols"]:
|
||||
diagnostics = symbol["diagnostics"]
|
||||
name = symbol["name"]
|
||||
for diagnostic in diagnostics:
|
||||
message = diagnostic["message"]
|
||||
if name in (
|
||||
"trio._path.PosixPath",
|
||||
"trio._path.WindowsPath",
|
||||
) and message.startswith("Type of base class "):
|
||||
continue
|
||||
|
||||
if name.startswith("trio._path.Path"):
|
||||
if message.startswith("No docstring found for"):
|
||||
continue
|
||||
if message.startswith(
|
||||
"Type is missing type annotation and could be inferred differently by type checkers",
|
||||
):
|
||||
continue
|
||||
|
||||
# ignore errors about missing docstrings if they're available at runtime
|
||||
if message.startswith("No docstring found for"):
|
||||
if has_docstring_at_runtime(symbol["name"]):
|
||||
continue
|
||||
else:
|
||||
# Missing docstring messages include the name of the object.
|
||||
# Other errors don't, so we add it.
|
||||
message = f"{name}: {message}"
|
||||
if message not in expected_errors and message not in printed_diagnostics:
|
||||
print(f"new error: {message}", file=sys.stderr)
|
||||
errors.append(message)
|
||||
printed_diagnostics.add(message)
|
||||
|
||||
continue
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> int:
|
||||
if args.full_diagnostics_file:
|
||||
full_diagnostics_file = Path(args.full_diagnostics_file)
|
||||
full_diagnostics_file.write_text("")
|
||||
else:
|
||||
full_diagnostics_file = None
|
||||
|
||||
errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json"
|
||||
if errors_by_platform_file.exists():
|
||||
with open(errors_by_platform_file) as f:
|
||||
errors_by_platform = json.load(f)
|
||||
else:
|
||||
errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []}
|
||||
|
||||
changed = False
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
platform_errors = errors_by_platform[platform] + errors_by_platform["all"]
|
||||
print("*" * 20, f"\nChecking {platform}...")
|
||||
errors = check_type(platform, full_diagnostics_file, platform_errors)
|
||||
|
||||
new_errors = [e for e in errors if e not in platform_errors]
|
||||
missing_errors = [e for e in platform_errors if e not in errors]
|
||||
|
||||
if new_errors:
|
||||
print(
|
||||
f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
if missing_errors:
|
||||
print(
|
||||
f"Congratulations, you have resolved existing errors! Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
print(missing_errors, file=sys.stderr)
|
||||
|
||||
errors_by_platform[platform] = errors
|
||||
print("*" * 20)
|
||||
|
||||
# cut down the size of the json file by a lot, and make it easier to parse for
|
||||
# humans, by moving errors that appear on all platforms to a separate category
|
||||
errors_by_platform["all"] = []
|
||||
for e in errors_by_platform["Linux"].copy():
|
||||
if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]:
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
errors_by_platform[platform].remove(e)
|
||||
errors_by_platform["all"].append(e)
|
||||
|
||||
if changed and args.overwrite_file:
|
||||
with open(errors_by_platform_file, "w") as f:
|
||||
json.dump(errors_by_platform, f, indent=4, sort_keys=True)
|
||||
# newline at end of file
|
||||
f.write("\n")
|
||||
|
||||
# True -> 1 -> non-zero exit value -> error
|
||||
return changed
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--overwrite-file",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-diagnostics-file",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert __name__ == "__main__", "This script should be run standalone"
|
||||
sys.exit(main(args))
|
||||
@@ -0,0 +1,24 @@
|
||||
regular = "hi"
|
||||
|
||||
from .. import _deprecate
|
||||
|
||||
_deprecate.enable_attribute_deprecations(__name__)
|
||||
|
||||
# Make sure that we don't trigger infinite recursion when accessing module
|
||||
# attributes in between calling enable_attribute_deprecations and defining
|
||||
# __deprecated_attributes__:
|
||||
import sys
|
||||
|
||||
this_mod = sys.modules[__name__]
|
||||
assert this_mod.regular == "hi"
|
||||
assert not hasattr(this_mod, "dep1")
|
||||
|
||||
__deprecated_attributes__ = {
|
||||
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
|
||||
"dep2": _deprecate.DeprecatedAttribute(
|
||||
"value2",
|
||||
"1.2",
|
||||
issue=1,
|
||||
instead="instead-string",
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from ..testing import MockClock, trio_test
|
||||
|
||||
RUN_SLOW = True
|
||||
SKIP_OPTIONAL_IMPORTS = False
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addoption("--run-slow", action="store_true", help="run slow tests")
|
||||
parser.addoption(
|
||||
"--skip-optional-imports",
|
||||
action="store_true",
|
||||
help="skip tests that rely on libraries not required by trio itself",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
global RUN_SLOW
|
||||
RUN_SLOW = config.getoption("--run-slow", default=True)
|
||||
global SKIP_OPTIONAL_IMPORTS
|
||||
SKIP_OPTIONAL_IMPORTS = config.getoption("--skip-optional-imports", default=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clock() -> MockClock:
|
||||
return MockClock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def autojump_clock() -> MockClock:
|
||||
return MockClock(autojump_threshold=0)
|
||||
|
||||
|
||||
# FIXME: split off into a package (or just make part of Trio's public
|
||||
# interface?), with config file to enable? and I guess a mark option too; I
|
||||
# guess it's useful with the class- and file-level marking machinery (where
|
||||
# the raw @trio_test decorator isn't enough).
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
|
||||
if inspect.iscoroutinefunction(pyfuncitem.obj):
|
||||
pyfuncitem.obj = trio_test(pyfuncitem.obj)
|
||||
|
||||
|
||||
def skip_if_optional_else_raise(error: ImportError) -> NoReturn:
|
||||
if SKIP_OPTIONAL_IMPORTS:
|
||||
pytest.skip(error.msg, allow_module_level=True)
|
||||
else: # pragma: no cover
|
||||
raise error
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .. import abc as tabc
|
||||
from ..lowlevel import Task
|
||||
|
||||
|
||||
def test_instrument_implements_hook_methods() -> None:
|
||||
attrs = {
|
||||
"before_run": (),
|
||||
"after_run": (),
|
||||
"task_spawned": (Task,),
|
||||
"task_scheduled": (Task,),
|
||||
"before_task_step": (Task,),
|
||||
"after_task_step": (Task,),
|
||||
"task_exited": (Task,),
|
||||
"before_io_wait": (3.3,),
|
||||
"after_io_wait": (3.3,),
|
||||
}
|
||||
|
||||
mayonnaise = tabc.Instrument()
|
||||
|
||||
for method_name, args in attrs.items():
|
||||
assert hasattr(mayonnaise, method_name)
|
||||
method = getattr(mayonnaise, method_name)
|
||||
assert callable(method)
|
||||
method(*args)
|
||||
|
||||
|
||||
async def test_AsyncResource_defaults() -> None:
|
||||
@attrs.define(slots=False)
|
||||
class MyAR(tabc.AsyncResource):
|
||||
record: list[str] = attrs.Factory(list)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("ac")
|
||||
|
||||
async with MyAR() as myar:
|
||||
assert isinstance(myar, MyAR)
|
||||
assert myar.record == []
|
||||
|
||||
assert myar.record == ["ac"]
|
||||
|
||||
|
||||
def test_abc_generics() -> None:
|
||||
# Pythons below 3.5.2 had a typing.Generic that would throw
|
||||
# errors when instantiating or subclassing a parameterized
|
||||
# version of a class with any __slots__. This is why RunVar
|
||||
# (which has slots) is not generic. This tests that
|
||||
# the generic ABCs are fine, because while they are slotted
|
||||
# they don't actually define any slots.
|
||||
|
||||
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
|
||||
__slots__ = ("x",)
|
||||
|
||||
def send_nowait(self, value: object) -> None:
|
||||
raise RuntimeError
|
||||
|
||||
async def send(self, value: object) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
def clone(self) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
channel = SlottedChannel()
|
||||
with pytest.raises(RuntimeError):
|
||||
channel.send_nowait(None)
|
||||
@@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import EndOfChannel, open_memory_channel
|
||||
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
|
||||
async def test_channel() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
open_memory_channel(1.0)
|
||||
with pytest.raises(ValueError, match="^max_buffer_size must be >= 0$"):
|
||||
open_memory_channel(-1)
|
||||
|
||||
s, r = open_memory_channel[Union[int, str, None]](2)
|
||||
repr(s) # smoke test
|
||||
repr(r) # smoke test
|
||||
|
||||
s.send_nowait(1)
|
||||
with assert_checkpoints():
|
||||
await s.send(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
assert r.receive_nowait() == 2
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
|
||||
s.send_nowait("last")
|
||||
await s.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.clone()
|
||||
await s.aclose()
|
||||
|
||||
assert r.receive_nowait() == "last"
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
await r.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
await r.aclose()
|
||||
|
||||
|
||||
async def test_553(autojump_clock: trio.abc.Clock) -> None:
|
||||
s, r = open_memory_channel[str](1)
|
||||
with trio.move_on_after(10) as timeout_scope:
|
||||
await r.receive()
|
||||
assert timeout_scope.cancelled_caught
|
||||
await s.send("Test for PR #553")
|
||||
|
||||
|
||||
async def test_channel_multiple_producers() -> None:
|
||||
async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None:
|
||||
# We close our handle when we're done with it
|
||||
async with send_channel:
|
||||
for j in range(3 * i, 3 * (i + 1)):
|
||||
await send_channel.send(j)
|
||||
|
||||
send_channel, receive_channel = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
# We hand out clones to all the new producers, and then close the
|
||||
# original.
|
||||
async with send_channel:
|
||||
for i in range(10):
|
||||
nursery.start_soon(producer, send_channel.clone(), i)
|
||||
|
||||
got = [value async for value in receive_channel]
|
||||
|
||||
got.sort()
|
||||
assert got == list(range(30))
|
||||
|
||||
|
||||
async def test_channel_multiple_consumers() -> None:
|
||||
successful_receivers = set()
|
||||
received = []
|
||||
|
||||
async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None:
|
||||
async for value in receive_channel:
|
||||
successful_receivers.add(i)
|
||||
received.append(value)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
send_channel, receive_channel = trio.open_memory_channel[int](1)
|
||||
async with send_channel:
|
||||
for i in range(5):
|
||||
nursery.start_soon(consumer, receive_channel, i)
|
||||
await wait_all_tasks_blocked()
|
||||
for i in range(10):
|
||||
await send_channel.send(i)
|
||||
|
||||
assert successful_receivers == set(range(5))
|
||||
assert len(received) == 10
|
||||
assert set(received) == set(range(10))
|
||||
|
||||
|
||||
async def test_close_basics() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await r.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
s2, r2 = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r2)
|
||||
await wait_all_tasks_blocked()
|
||||
await r2.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r2.receive()
|
||||
|
||||
|
||||
async def test_close_sync() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
|
||||
async def test_receive_channel_clone_and_close() -> None:
|
||||
s, r = open_memory_channel[None](10)
|
||||
|
||||
r2 = r.clone()
|
||||
r3 = r.clone()
|
||||
|
||||
s.send_nowait(None)
|
||||
await r.aclose()
|
||||
with r2:
|
||||
pass
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.clone()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.clone()
|
||||
|
||||
# Can still send, r3 is still open
|
||||
s.send_nowait(None)
|
||||
|
||||
await r3.aclose()
|
||||
|
||||
# But now the receiver is really closed
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
|
||||
|
||||
async def test_close_multiple_send_handles() -> None:
|
||||
# With multiple send handles, closing one handle only wakes senders on
|
||||
# that handle, but others can continue just fine
|
||||
s1, r = open_memory_channel[str](0)
|
||||
s2 = s1.clone()
|
||||
|
||||
async def send_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s1.send("nope")
|
||||
|
||||
async def send_will_succeed() -> None:
|
||||
await s2.send("ok")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_will_close)
|
||||
nursery.start_soon(send_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await s1.aclose()
|
||||
assert await r.receive() == "ok"
|
||||
|
||||
|
||||
async def test_close_multiple_receive_handles() -> None:
|
||||
# With multiple receive handles, closing one handle only wakes receivers on
|
||||
# that handle, but others can continue just fine
|
||||
s, r1 = open_memory_channel[str](0)
|
||||
r2 = r1.clone()
|
||||
|
||||
async def receive_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r1.receive()
|
||||
|
||||
async def receive_will_succeed() -> None:
|
||||
assert await r2.receive() == "ok"
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_will_close)
|
||||
nursery.start_soon(receive_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await r1.aclose()
|
||||
await s.send("ok")
|
||||
|
||||
|
||||
async def test_inf_capacity() -> None:
|
||||
send, receive = open_memory_channel[int](float("inf"))
|
||||
|
||||
# It's accepted, and we can send all day without blocking
|
||||
with send:
|
||||
for i in range(10):
|
||||
send.send_nowait(i)
|
||||
|
||||
got = [i async for i in receive]
|
||||
assert got == list(range(10))
|
||||
|
||||
|
||||
async def test_statistics() -> None:
|
||||
s, r = open_memory_channel[None](2)
|
||||
|
||||
assert s.statistics() == r.statistics()
|
||||
stats = s.statistics()
|
||||
assert stats.current_buffer_used == 0
|
||||
assert stats.max_buffer_size == 2
|
||||
assert stats.open_send_channels == 1
|
||||
assert stats.open_receive_channels == 1
|
||||
assert stats.tasks_waiting_send == 0
|
||||
assert stats.tasks_waiting_receive == 0
|
||||
|
||||
s.send_nowait(None)
|
||||
assert s.statistics().current_buffer_used == 1
|
||||
|
||||
s2 = s.clone()
|
||||
assert s.statistics().open_send_channels == 2
|
||||
await s.aclose()
|
||||
assert s2.statistics().open_send_channels == 1
|
||||
|
||||
r2 = r.clone()
|
||||
assert s2.statistics().open_receive_channels == 2
|
||||
await r2.aclose()
|
||||
assert s2.statistics().open_receive_channels == 1
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
s2.send_nowait(None) # fill up the buffer
|
||||
assert s.statistics().current_buffer_used == 2
|
||||
nursery.start_soon(s2.send, None)
|
||||
nursery.start_soon(s2.send, None)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_send == 2
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_send == 0
|
||||
|
||||
# empty out the buffer again
|
||||
try:
|
||||
while True:
|
||||
r.receive_nowait()
|
||||
except trio.WouldBlock:
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(r.receive)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_receive == 1
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_receive == 0
|
||||
|
||||
|
||||
async def test_channel_fairness() -> None:
|
||||
# We can remove an item we just sent, and send an item back in after, if
|
||||
# no-one else is waiting.
|
||||
s, r = open_memory_channel[Union[int, None]](1)
|
||||
s.send_nowait(1)
|
||||
assert r.receive_nowait() == 1
|
||||
s.send_nowait(2)
|
||||
assert r.receive_nowait() == 2
|
||||
|
||||
# But if someone else is waiting to receive, then they "own" the item we
|
||||
# send, so we can't receive it (even though we run first):
|
||||
|
||||
result: int | None = None
|
||||
|
||||
async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None:
|
||||
nonlocal result
|
||||
result = await r.receive()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive, r)
|
||||
await wait_all_tasks_blocked()
|
||||
s.send_nowait(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
assert result == 2
|
||||
|
||||
# And the analogous situation for send: if we free up a space, we can't
|
||||
# immediately send something in it if someone is already waiting to do
|
||||
# that
|
||||
s, r = open_memory_channel[Union[int, None]](1)
|
||||
s.send_nowait(1)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(s.send, 2)
|
||||
await wait_all_tasks_blocked()
|
||||
assert r.receive_nowait() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(3)
|
||||
assert (await r.receive()) == 2
|
||||
|
||||
|
||||
async def test_unbuffered() -> None:
|
||||
s, r = open_memory_channel[int](0)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(1)
|
||||
|
||||
async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
|
||||
with assert_checkpoints():
|
||||
await s.send(v)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send, s, 1)
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
|
||||
from .. import _core
|
||||
|
||||
trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"trio_testing_contextvar",
|
||||
)
|
||||
|
||||
|
||||
async def test_contextvars_default() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
assert record == ["main"]
|
||||
|
||||
|
||||
async def test_contextvars_set() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
trio_testing_contextvar.set("child")
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert record == ["child"]
|
||||
assert value == "main"
|
||||
|
||||
|
||||
async def test_contextvars_copy() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
context = contextvars.copy_context()
|
||||
trio_testing_contextvar.set("second_main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
context.run(nursery.start_soon, child)
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert set(record) == {"main", "second_main"}
|
||||
assert value == "second_main"
|
||||
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from .._deprecate import (
|
||||
TrioDeprecationWarning,
|
||||
deprecated,
|
||||
deprecated_alias,
|
||||
warn_deprecated,
|
||||
)
|
||||
from . import module_with_deprecations
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder:
|
||||
warnings.simplefilter("always")
|
||||
# ResourceWarnings about unclosed sockets can occur nondeterministically
|
||||
# (during GC) which throws off the tests in this file
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
return recwarn
|
||||
|
||||
|
||||
def _here() -> tuple[str, int]:
|
||||
frame = inspect.currentframe()
|
||||
assert frame is not None
|
||||
assert frame.f_back is not None
|
||||
info = inspect.getframeinfo(frame.f_back)
|
||||
return (info.filename, info.lineno)
|
||||
|
||||
|
||||
def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def deprecated_thing() -> None:
|
||||
warn_deprecated("ice", "1.2", issue=1, instead="water")
|
||||
|
||||
deprecated_thing()
|
||||
filename, lineno = _here()
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "ice is deprecated" in got.message.args[0]
|
||||
assert "Trio 1.2" in got.message.args[0]
|
||||
assert "water instead" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno - 1
|
||||
|
||||
|
||||
def test_warn_deprecated_no_instead_or_issue(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
# Explicitly no instead or issue
|
||||
warn_deprecated("water", "1.3", issue=None, instead=None)
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "water is deprecated" in got.message.args[0]
|
||||
assert "no replacement" in got.message.args[0]
|
||||
assert "Trio 1.3" in got.message.args[0]
|
||||
|
||||
|
||||
def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def nested1() -> None:
|
||||
nested2()
|
||||
|
||||
def nested2() -> None:
|
||||
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
|
||||
|
||||
filename, lineno = _here()
|
||||
nested1()
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
|
||||
def old() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def new() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
warn_deprecated(old, "1.0", issue=1, instead=new)
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old is deprecated" in got.message.args[0]
|
||||
assert "test_deprecate.new instead" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.5", issue=123, instead=new)
|
||||
def deprecated_old() -> int:
|
||||
return 3
|
||||
|
||||
|
||||
def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert deprecated_old() == 3
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
|
||||
assert "1.5" in got.message.args[0]
|
||||
assert "test_deprecate.new" in got.message.args[0]
|
||||
assert "issues/123" in got.message.args[0]
|
||||
|
||||
|
||||
class Foo:
|
||||
@deprecated("1.0", issue=123, instead="crying")
|
||||
def method(self) -> int:
|
||||
return 7
|
||||
|
||||
|
||||
def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
f = Foo()
|
||||
assert f.method() == 7
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.2", thing="the thing", issue=None, instead=None)
|
||||
def deprecated_with_thing() -> int:
|
||||
return 72
|
||||
|
||||
|
||||
def test_deprecated_decorator_with_explicit_thing(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
assert deprecated_with_thing() == 72
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "the thing is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
def new_hotness() -> str:
|
||||
return "new hotness"
|
||||
|
||||
|
||||
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
|
||||
|
||||
|
||||
def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert old_hotness() == "new hotness"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
|
||||
assert "1.23" in got.message.args[0]
|
||||
assert "test_deprecate.new_hotness instead" in got.message.args[0]
|
||||
assert "issues/1" in got.message.args[0]
|
||||
|
||||
assert isinstance(old_hotness.__doc__, str)
|
||||
assert ".. deprecated:: 1.23" in old_hotness.__doc__
|
||||
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
|
||||
assert "issues/1>`__" in old_hotness.__doc__
|
||||
|
||||
|
||||
class Alias:
|
||||
def new_hotness_method(self) -> str:
|
||||
return "new hotness method"
|
||||
|
||||
old_hotness_method = deprecated_alias(
|
||||
"Alias.old_hotness_method",
|
||||
new_hotness_method,
|
||||
"3.21",
|
||||
issue=1,
|
||||
)
|
||||
|
||||
|
||||
def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
obj = Alias()
|
||||
assert obj.old_hotness_method() == "new hotness method"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
msg = got.message.args[0]
|
||||
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
|
||||
assert "test_deprecate.Alias.new_hotness_method instead" in msg
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead="hi")
|
||||
def docstring_test1() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead="hi")
|
||||
def docstring_test2() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead=None)
|
||||
def docstring_test3() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead=None)
|
||||
def docstring_test4() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
def test_deprecated_docstring_munging() -> None:
|
||||
assert (
|
||||
docstring_test1.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test2.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test3.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test4.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert module_with_deprecations.regular == "hi"
|
||||
assert len(recwarn_always) == 0
|
||||
|
||||
filename, lineno = _here()
|
||||
assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
assert "module_with_deprecations.dep1" in got.message.args[0]
|
||||
assert "Trio 1.1" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert "value1 instead" in got.message.args[0]
|
||||
|
||||
assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "instead-string instead" in got.message.args[0]
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression"
|
||||
|
||||
|
||||
def test_warning_class() -> None:
|
||||
with pytest.deprecated_call():
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
# essentially the same as the above check
|
||||
with pytest.warns(DeprecationWarning):
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
with pytest.warns(TrioDeprecationWarning):
|
||||
warn_deprecated(
|
||||
"foo",
|
||||
"bar",
|
||||
issue=None,
|
||||
instead=None,
|
||||
use_triodeprecationwarning=True,
|
||||
)
|
||||
+64
@@ -0,0 +1,64 @@
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
async def test_deprecation_warning_open_nursery() -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
assert len(record) == 1
|
||||
async with trio.open_nursery(strict_exception_groups=True):
|
||||
...
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
|
||||
def test_deprecation_warning_run() -> None:
|
||||
async def foo() -> None: ...
|
||||
|
||||
async def foo_nursery() -> None:
|
||||
# this should not raise a warning, even if it's implied loose
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
async def foo_loose_nursery() -> None:
|
||||
# this should raise a warning, even if specifying the parameter is redundant
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
|
||||
def helper(fun: Callable[..., Awaitable[None]], num: int) -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trio.run(fun, strict_exception_groups=False)
|
||||
assert len(record) == num
|
||||
|
||||
helper(foo, 1)
|
||||
helper(foo_nursery, 1)
|
||||
helper(foo_loose_nursery, 2)
|
||||
|
||||
|
||||
def test_deprecation_warning_start_guest_run() -> None:
|
||||
# "The simplest possible "host" loop."
|
||||
from .._core._tests.test_guest_mode import trivial_guest_run
|
||||
|
||||
async def trio_return(in_host: object) -> str:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return "ok"
|
||||
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trivial_guest_run(
|
||||
trio_return,
|
||||
strict_exception_groups=False,
|
||||
)
|
||||
assert len(record) == 1
|
||||
@@ -0,0 +1,900 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from itertools import count
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
try:
|
||||
import trustme
|
||||
from OpenSSL import SSL
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio import DTLSChannel, DTLSEndpoint
|
||||
from trio.testing._fake_net import FakeNet, UDPPacket
|
||||
|
||||
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
ca = trustme.CA()
|
||||
server_cert = ca.issue_cert("example.com")
|
||||
|
||||
server_ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
server_cert.configure_cert(server_ctx)
|
||||
|
||||
client_ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
ca.configure_trust(client_ctx)
|
||||
|
||||
|
||||
parametrize_ipv6 = pytest.mark.parametrize(
|
||||
"ipv6",
|
||||
[False, pytest.param(True, marks=binds_ipv6)],
|
||||
ids=["ipv4", "ipv6"],
|
||||
)
|
||||
|
||||
|
||||
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
|
||||
ipv6 = kwargs.pop("ipv6", False)
|
||||
family = trio.socket.AF_INET6 if ipv6 else trio.socket.AF_INET
|
||||
sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
|
||||
return DTLSEndpoint(sock, **kwargs)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def dtls_echo_server(
|
||||
*,
|
||||
autocancel: bool = True,
|
||||
mtu: int | None = None,
|
||||
ipv6: bool = False,
|
||||
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
|
||||
with endpoint(ipv6=ipv6) as server:
|
||||
localhost = "::1" if ipv6 else "127.0.0.1"
|
||||
await server.socket.bind((localhost, 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def echo_handler(dtls_channel: DTLSChannel) -> None:
|
||||
print(
|
||||
"echo handler started: "
|
||||
f"server {dtls_channel.endpoint.socket.getsockname()!r} "
|
||||
f"client {dtls_channel.peer_address!r}",
|
||||
)
|
||||
if mtu is not None:
|
||||
dtls_channel.set_ciphertext_mtu(mtu)
|
||||
try:
|
||||
print("server starting do_handshake")
|
||||
await dtls_channel.do_handshake()
|
||||
print("server finished do_handshake")
|
||||
async for packet in dtls_channel:
|
||||
print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
|
||||
await dtls_channel.send(packet)
|
||||
except trio.BrokenResourceError: # pragma: no cover
|
||||
print("echo handler channel broken")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, echo_handler)
|
||||
|
||||
yield server, server.socket.getsockname()
|
||||
|
||||
if autocancel:
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_smoke(ipv6: bool) -> None:
|
||||
async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client_channel = client_endpoint.connect(address, client_ctx)
|
||||
with pytest.raises(trio.NeedHandshakeError):
|
||||
client_channel.get_cleartext_mtu()
|
||||
|
||||
await client_channel.do_handshake()
|
||||
await client_channel.send(b"hello")
|
||||
assert await client_channel.receive() == b"hello"
|
||||
await client_channel.send(b"goodbye")
|
||||
assert await client_channel.receive() == b"goodbye"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^openssl doesn't support sending empty DTLS packets$",
|
||||
):
|
||||
await client_channel.send(b"")
|
||||
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
|
||||
client_channel.set_ciphertext_mtu(4321)
|
||||
assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
|
||||
|
||||
|
||||
@slow
|
||||
async def test_handshake_over_terrible_network(
|
||||
autojump_clock: trio.testing.MockClock,
|
||||
) -> None:
|
||||
HANDSHAKES = 100
|
||||
r = random.Random(0)
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
# avoid spurious timeouts on slow machines
|
||||
autojump_clock.autojump_threshold = 0.001
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def route_packet(packet: UDPPacket) -> None:
|
||||
while True:
|
||||
op = r.choices(
|
||||
["deliver", "drop", "dupe", "delay"],
|
||||
weights=[0.7, 0.1, 0.1, 0.1],
|
||||
)[0]
|
||||
print(f"{packet.source} -> {packet.destination}: {op}")
|
||||
if op == "drop":
|
||||
return
|
||||
elif op == "dupe":
|
||||
fn.send_packet(packet)
|
||||
elif op == "delay":
|
||||
await trio.sleep(r.random() * 3)
|
||||
# I wanted to test random packet corruption too, but it turns out
|
||||
# openssl has a bug in the following scenario:
|
||||
#
|
||||
# - client sends ClientHello
|
||||
# - server sends HelloVerifyRequest with cookie -- but cookie is
|
||||
# invalid b/c either the ClientHello or HelloVerifyRequest was
|
||||
# corrupted
|
||||
# - client re-sends ClientHello with invalid cookie
|
||||
# - server replies with new HelloVerifyRequest and correct cookie
|
||||
#
|
||||
# At this point, the client *should* switch to the new, valid
|
||||
# cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
|
||||
# the original, invalid cookie over and over. In theory we could
|
||||
# work around this by detecting cookie changes and starting over
|
||||
# with a whole new SSL object, but (a) it doesn't seem worth it, (b)
|
||||
# when I tried then I ran into another issue where OpenSSL got stuck
|
||||
# in an infinite loop sending alerts over and over, which I didn't
|
||||
# dig into because see (a).
|
||||
#
|
||||
# elif op == "distort":
|
||||
# payload = bytearray(packet.payload)
|
||||
# payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
|
||||
# packet = attrs.evolve(packet, payload=payload)
|
||||
else:
|
||||
assert op == "deliver"
|
||||
print(
|
||||
f"{packet.source} -> {packet.destination}: delivered"
|
||||
f" {packet.payload.hex()}",
|
||||
)
|
||||
fn.deliver_packet(packet)
|
||||
break
|
||||
|
||||
def route_packet_wrapper(packet: UDPPacket) -> None:
|
||||
try: # noqa: SIM105 # suppressible-exception
|
||||
nursery.start_soon(route_packet, packet)
|
||||
except RuntimeError: # pragma: no cover
|
||||
# We're exiting the nursery, so any remaining packets can just get
|
||||
# dropped
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
for i in range(HANDSHAKES):
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
print("client starting do_handshake")
|
||||
await client.do_handshake()
|
||||
print("client finished do_handshake")
|
||||
msg = str(i).encode()
|
||||
# Make multiple attempts to send data, because the network might
|
||||
# drop it
|
||||
while True:
|
||||
with trio.move_on_after(10) as cscope:
|
||||
await client.send(msg)
|
||||
assert await client.receive() == msg
|
||||
if not cscope.cancelled_caught:
|
||||
break
|
||||
|
||||
|
||||
async def test_implicit_handshake() -> None:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
|
||||
# Implicit handshake
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
|
||||
|
||||
async def test_full_duplex() -> None:
|
||||
# Tests simultaneous send/receive, and also multiple methods implicitly invoking
|
||||
# do_handshake simultaneously.
|
||||
with endpoint() as server_endpoint, endpoint() as client_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as server_nursery:
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(channel.send, b"from server")
|
||||
nursery.start_soon(channel.receive)
|
||||
|
||||
await server_nursery.start(server_endpoint.serve, server_ctx, handler)
|
||||
|
||||
client = client_endpoint.connect(
|
||||
server_endpoint.socket.getsockname(),
|
||||
client_ctx,
|
||||
)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(client.send, b"from client")
|
||||
nursery.start_soon(client.receive)
|
||||
|
||||
server_nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_channel_closing() -> None:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
client.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.send(b"abc")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.receive()
|
||||
|
||||
# close is idempotent
|
||||
client.close()
|
||||
# can also aclose
|
||||
await client.aclose()
|
||||
|
||||
|
||||
async def test_serve_exits_cleanly_on_close() -> None:
|
||||
async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
|
||||
server_endpoint.close()
|
||||
# Testing that the nursery exits even without being cancelled
|
||||
# close is idempotent
|
||||
server_endpoint.close()
|
||||
|
||||
|
||||
async def test_client_multiplex() -> None:
|
||||
async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
|
||||
with endpoint() as client_endpoint:
|
||||
client1 = client_endpoint.connect(address1, client_ctx)
|
||||
client2 = client_endpoint.connect(address2, client_ctx)
|
||||
|
||||
await client1.send(b"abc")
|
||||
await client2.send(b"xyz")
|
||||
assert await client2.receive() == b"xyz"
|
||||
assert await client1.receive() == b"abc"
|
||||
|
||||
client_endpoint.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client1.send(b"xxx")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client2.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
client_endpoint.connect(address1, client_ctx)
|
||||
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await nursery.start(client_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
|
||||
async def test_dtls_over_dgram_only() -> None:
|
||||
with trio.socket.socket() as s:
|
||||
with pytest.raises(ValueError, match="^DTLS requires a SOCK_DGRAM socket$"):
|
||||
DTLSEndpoint(s)
|
||||
|
||||
|
||||
async def test_double_serve() -> None:
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
with endpoint() as server_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
with pytest.raises(trio.BusyResourceError):
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
with endpoint() as client1, endpoint() as client2:
|
||||
await client1.socket.bind(("127.0.0.1", 0))
|
||||
# This should just time out
|
||||
with trio.move_on_after(100) as cscope:
|
||||
channel = client2.connect(client1.socket.getsockname(), client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
for buffer_size in [10, 20]:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
|
||||
assert client_endpoint.incoming_packets_buffer == buffer_size
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
for i in range(buffer_size + 15):
|
||||
await client.send(str(i).encode())
|
||||
await trio.sleep(1)
|
||||
stats = client.statistics()
|
||||
assert stats.incoming_packets_dropped_in_trio == 15
|
||||
for i in range(buffer_size):
|
||||
assert await client.receive() == str(i).encode()
|
||||
await client.send(b"buffer clear now")
|
||||
assert await client.receive() == b"buffer clear now"
|
||||
|
||||
|
||||
async def test_server_socket_doesnt_crash_on_garbage(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import (
|
||||
ContentType,
|
||||
HandshakeFragment,
|
||||
HandshakeType,
|
||||
ProtocolVersion,
|
||||
Record,
|
||||
encode_handshake_fragment,
|
||||
encode_record,
|
||||
)
|
||||
|
||||
client_hello = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=10,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_extended = client_hello + b"\x00"
|
||||
client_hello_short = client_hello[:-1]
|
||||
# cuts off in middle of handshake message header
|
||||
client_hello_really_short = client_hello[:14]
|
||||
client_hello_corrupt_record_len = bytearray(client_hello)
|
||||
client_hello_corrupt_record_len[11] = 0xFF
|
||||
|
||||
client_hello_fragmented = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_trailing_data_in_record = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
)
|
||||
+ b"\x00",
|
||||
),
|
||||
)
|
||||
|
||||
handshake_empty = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=b"",
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_truncated_in_cookie = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=bytes(2 + 32 + 1) + b"\xff",
|
||||
),
|
||||
)
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
|
||||
for bad_packet in [
|
||||
b"",
|
||||
b"xyz",
|
||||
client_hello_extended,
|
||||
client_hello_short,
|
||||
client_hello_really_short,
|
||||
client_hello_corrupt_record_len,
|
||||
client_hello_fragmented,
|
||||
client_hello_trailing_data_in_record,
|
||||
handshake_empty,
|
||||
client_hello_truncated_in_cookie,
|
||||
]:
|
||||
await sock.sendto(bad_packet, address)
|
||||
await trio.sleep(1)
|
||||
|
||||
|
||||
async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import BadPacket, decode_client_hello_untrusted
|
||||
|
||||
with trio.CancelScope() as cscope:
|
||||
# the first 11 bytes of ClientHello aren't protected by the cookie, so only test
|
||||
# corrupting bytes after that.
|
||||
offset_to_corrupt = count(11)
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
try:
|
||||
_, cookie, _ = decode_client_hello_untrusted(packet.payload)
|
||||
except BadPacket:
|
||||
pass
|
||||
else:
|
||||
if len(cookie) != 0:
|
||||
# this is a challenge response packet
|
||||
# let's corrupt the next offset so the handshake should fail
|
||||
payload = bytearray(packet.payload)
|
||||
offset = next(offset_to_corrupt)
|
||||
if offset >= len(payload):
|
||||
# We've tried all offsets. Clamp offset to the end of the
|
||||
# payload, and terminate the test.
|
||||
offset = len(payload) - 1
|
||||
cscope.cancel()
|
||||
payload[offset] ^= 0x01
|
||||
packet = attrs.evolve(packet, payload=payload)
|
||||
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
while True:
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
async def test_client_cancels_handshake_and_starts_new_one(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
# if a client disappears during the handshake, and then starts a new handshake from
|
||||
# scratch, then the first handler's channel should fail, and a new handler get
|
||||
# started
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
with endpoint() as server, endpoint() as client:
|
||||
await server.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
first_time = True
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
nonlocal first_time
|
||||
if first_time:
|
||||
first_time = False
|
||||
print("handler: first time, cancelling connect")
|
||||
connect_cscope.cancel()
|
||||
await trio.sleep(0.5)
|
||||
print("handler: handshake should fail now")
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await channel.do_handshake()
|
||||
else:
|
||||
print("handler: not first time, sending hello")
|
||||
await channel.send(b"hello")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, handler)
|
||||
|
||||
print("client: starting first connect")
|
||||
with trio.CancelScope() as connect_cscope:
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert connect_cscope.cancelled_caught
|
||||
|
||||
print("client: starting second connect")
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx)
|
||||
assert await channel.receive() == b"hello"
|
||||
|
||||
# Give handlers a chance to finish
|
||||
await trio.sleep(10)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_swap_client_server() -> None:
|
||||
with endpoint() as a, endpoint() as b:
|
||||
await a.socket.bind(("127.0.0.1", 0))
|
||||
await b.socket.bind(("127.0.0.1", 0))
|
||||
|
||||
async def echo_handler(channel: DTLSChannel) -> None:
|
||||
async for packet in channel:
|
||||
await channel.send(packet)
|
||||
|
||||
async def crashing_echo_handler(channel: DTLSChannel) -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await echo_handler(channel)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(a.serve, server_ctx, crashing_echo_handler)
|
||||
await nursery.start(b.serve, server_ctx, echo_handler)
|
||||
|
||||
b_to_a = b.connect(a.socket.getsockname(), client_ctx)
|
||||
await b_to_a.send(b"b as client")
|
||||
assert await b_to_a.receive() == b"b as client"
|
||||
|
||||
a_to_b = a.connect(b.socket.getsockname(), client_ctx)
|
||||
await a_to_b.do_handshake()
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await b_to_a.send(b"association broken")
|
||||
await a_to_b.send(b"a as client")
|
||||
assert await a_to_b.receive() == b"a as client"
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_openssl_retransmit_doesnt_break_stuff() -> None:
|
||||
# can't use autojump_clock here, because the point of the test is to wait for
|
||||
# openssl's built-in retransmit timer to expire, which is hard-coded to use
|
||||
# wall-clock time.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if blackholed:
|
||||
print("dropped packet", packet)
|
||||
return
|
||||
print("delivered packet", packet)
|
||||
# packets.append(
|
||||
# scapy.all.IP(
|
||||
# src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
|
||||
# )
|
||||
# / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
|
||||
# / packet.payload
|
||||
# )
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (server_endpoint, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def connecter() -> None:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake(initial_retransmit_timeout=1.5)
|
||||
await client.send(b"hi")
|
||||
assert await client.receive() == b"hi"
|
||||
|
||||
nursery.start_soon(connecter)
|
||||
|
||||
# openssl's default timeout is 1 second, so this ensures that it thinks
|
||||
# the timeout has expired
|
||||
await trio.sleep(1.1)
|
||||
# disable blackholing and send a garbage packet to wake up openssl so it
|
||||
# notices the timeout has expired
|
||||
blackholed = False
|
||||
await server_endpoint.socket.sendto(
|
||||
b"xxx",
|
||||
client_endpoint.socket.getsockname(),
|
||||
)
|
||||
# now the client task should finish connecting and exit cleanly
|
||||
|
||||
# scapy.all.wrpcap("/tmp/trace.pcap", packets)
|
||||
|
||||
|
||||
async def test_initial_retransmit_timeout_configuration(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
nonlocal blackholed
|
||||
if blackholed:
|
||||
blackholed = False
|
||||
else:
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
for t in [1, 2, 4]:
|
||||
with endpoint() as client:
|
||||
before = trio.current_time()
|
||||
blackholed = True
|
||||
channel = client.connect(address, client_ctx)
|
||||
await channel.do_handshake(initial_retransmit_timeout=t)
|
||||
after = trio.current_time()
|
||||
assert after - before == t
|
||||
|
||||
|
||||
async def test_explicit_tiny_mtu_is_respected() -> None:
|
||||
# ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
|
||||
# be larger than that. (300 is still smaller than any real network though.)
|
||||
MTU = 300
|
||||
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
print(f"delivering {packet}")
|
||||
print(f"payload size: {len(packet.payload)}")
|
||||
assert len(packet.payload) <= MTU
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(mtu=MTU) as (server, address):
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
channel.set_ciphertext_mtu(MTU)
|
||||
await channel.do_handshake()
|
||||
await channel.send(b"hi")
|
||||
assert await channel.receive() == b"hi"
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_handshake_handles_minimum_network_mtu(
|
||||
ipv6: bool,
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
# Fake network that has the minimum allowable MTU for whatever protocol we're using.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
mtu = 1280 - 48 if ipv6 else 576 - 28
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if len(packet.payload) > mtu:
|
||||
print(f"dropping {packet}")
|
||||
else:
|
||||
print(f"delivering {packet}")
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
# See if we can successfully do a handshake -- some of the volleys will get dropped,
|
||||
# and the retransmit logic should detect this and back off the MTU to something
|
||||
# smaller until it succeeds.
|
||||
async with dtls_echo_server(ipv6=ipv6) as (_, address):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
# the handshake mtu backoff shouldn't affect the return value from
|
||||
# get_cleartext_mtu, b/c that's under the user's control via
|
||||
# set_ciphertext_mtu
|
||||
client.set_ciphertext_mtu(9999)
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
assert client.get_cleartext_mtu() > 9000 # as vegeta said
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_system_task_cleaned_up_on_gc() -> None:
|
||||
before_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
|
||||
# We put this into a sub-function so that everything automatically becomes garbage
|
||||
# when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
|
||||
# with coverage enabled -- I think we were hitting this bug:
|
||||
# https://foss.heptapod.net/pypy/pypy/-/issues/3656
|
||||
async def start_and_forget_endpoint() -> int:
|
||||
e = endpoint()
|
||||
|
||||
# This connection/handshake attempt can't succeed. The only purpose is to force
|
||||
# the endpoint to set up a receive loop.
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
c = e.connect(s.getsockname(), client_ctx)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(c.do_handshake)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
during_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
return during_tasks
|
||||
|
||||
with pytest.warns(ResourceWarning):
|
||||
during_tasks = await start_and_forget_endpoint()
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
after_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
assert before_tasks < during_tasks
|
||||
assert before_tasks == after_tasks
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_before_system_task_starts() -> None:
|
||||
e = endpoint()
|
||||
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_as_packet_received() -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
e = endpoint()
|
||||
await e.socket.bind(("127.0.0.1", 0))
|
||||
e._ensure_receive_loop()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.sendto(b"xxx", e.socket.getsockname())
|
||||
# At this point, the endpoint's receive loop has been marked runnable because it
|
||||
# just received a packet; closing the endpoint socket won't interrupt that. But by
|
||||
# the time it wakes up to process the packet, the endpoint will be gone.
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
def test_gc_after_trio_exits() -> None:
|
||||
async def main() -> DTLSEndpoint:
|
||||
# We use fakenet just to make sure no real sockets can leak out of the test
|
||||
# case - on pypy somehow the socket was outliving the gc_collect_harder call
|
||||
# below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
|
||||
# when called after trio exits, it doesn't need a real socket.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return endpoint()
|
||||
|
||||
e = trio.run(main)
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
async def test_already_closed_socket_doesnt_crash() -> None:
|
||||
with endpoint() as e:
|
||||
# We close the socket before checkpointing, so the socket will already be closed
|
||||
# when the system task starts up
|
||||
e.socket.close()
|
||||
# Now give it a chance to start up, and hopefully not crash
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
async def test_socket_closed_while_processing_clienthello(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# Check what happens if the socket is discovered to be closed when sending a
|
||||
# HelloVerifyRequest, since that has its own sending logic
|
||||
async with dtls_echo_server() as (server, address):
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
fn.deliver_packet(packet)
|
||||
server.socket.close()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
with endpoint() as client_endpoint:
|
||||
with trio.move_on_after(10):
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
|
||||
|
||||
async def test_association_replaced_while_handshake_running(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
c1 = client_endpoint.connect(address, client_ctx)
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def doomed_handshake() -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
nursery.start_soon(doomed_handshake)
|
||||
|
||||
await trio.sleep(10)
|
||||
|
||||
client_endpoint.connect(address, client_ctx)
|
||||
|
||||
|
||||
async def test_association_replaced_before_handshake_starts() -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# This test shouldn't send any packets
|
||||
def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
c1 = client_endpoint.connect(address, client_ctx)
|
||||
client_endpoint.connect(address, client_ctx)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
|
||||
async def test_send_to_closed_local_port() -> None:
|
||||
# On Windows, sending a UDP packet to a closed local port can cause a weird
|
||||
# ECONNRESET error later, inside the receive task. Make sure we're handling it
|
||||
# properly.
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(1, 10):
|
||||
channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
|
||||
nursery.start_soon(channel.do_handshake)
|
||||
channel = client_endpoint.connect(address, client_ctx)
|
||||
await channel.send(b"xxx")
|
||||
assert await channel.receive() == b"xxx"
|
||||
nursery.cancel_scope.cancel()
|
||||
@@ -0,0 +1,574 @@
|
||||
from __future__ import annotations # isort: split
|
||||
|
||||
import __future__ # Regular import, not special!
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path, PurePath
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
from .. import _core, _util
|
||||
from .._core._tests.tutil import slow
|
||||
from .pytest_plugin import RUN_SLOW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
mypy_cache_updated = False
|
||||
|
||||
|
||||
try: # If installed, check both versions of this class.
|
||||
from typing_extensions import Protocol as Protocol_ext
|
||||
except ImportError: # pragma: no cover
|
||||
Protocol_ext = Protocol # type: ignore[assignment]
|
||||
|
||||
|
||||
def _ensure_mypy_cache_updated() -> None:
|
||||
# This pollutes the `empty` dir. Should this be changed?
|
||||
try:
|
||||
from mypy.api import run
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
global mypy_cache_updated
|
||||
if not mypy_cache_updated:
|
||||
# mypy cache was *probably* already updated by the other tests,
|
||||
# but `pytest -k ...` might run just this test on its own
|
||||
result = run(
|
||||
[
|
||||
"--config-file=",
|
||||
"--cache-dir=./.mypy_cache",
|
||||
"--no-error-summary",
|
||||
"-c",
|
||||
"import trio",
|
||||
],
|
||||
)
|
||||
assert not result[1] # stderr
|
||||
assert not result[0] # stdout
|
||||
mypy_cache_updated = True
|
||||
|
||||
|
||||
def test_core_is_properly_reexported() -> None:
|
||||
# Each export from _core should be re-exported by exactly one of these
|
||||
# three modules:
|
||||
sources = [trio, trio.lowlevel, trio.testing]
|
||||
for symbol in dir(_core):
|
||||
if symbol.startswith("_"):
|
||||
continue
|
||||
found = 0
|
||||
for source in sources:
|
||||
if symbol in dir(source) and getattr(source, symbol) is getattr(
|
||||
_core,
|
||||
symbol,
|
||||
):
|
||||
found += 1
|
||||
print(symbol, found)
|
||||
assert found == 1
|
||||
|
||||
|
||||
def class_is_final(cls: type) -> bool:
|
||||
"""Check if a class cannot be subclassed."""
|
||||
try:
|
||||
# new_class() handles metaclasses properly, type(...) does not.
|
||||
types.new_class("SubclassTester", (cls,))
|
||||
except TypeError:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def iter_modules(
|
||||
module: types.ModuleType,
|
||||
only_public: bool,
|
||||
) -> Iterator[types.ModuleType]:
|
||||
yield module
|
||||
for name, class_ in module.__dict__.items():
|
||||
if name.startswith("_") and only_public:
|
||||
continue
|
||||
if not isinstance(class_, ModuleType):
|
||||
continue
|
||||
if not class_.__name__.startswith(module.__name__): # pragma: no cover
|
||||
continue
|
||||
if class_ is module: # pragma: no cover
|
||||
continue
|
||||
yield from iter_modules(class_, only_public)
|
||||
|
||||
|
||||
PUBLIC_MODULES = list(iter_modules(trio, only_public=True))
|
||||
ALL_MODULES = list(iter_modules(trio, only_public=False))
|
||||
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
|
||||
|
||||
|
||||
# It doesn't make sense for downstream redistributors to run this test, since
|
||||
# they might be using a newer version of Python with additional symbols which
|
||||
# won't be reflected in trio.socket, and this shouldn't cause downstream test
|
||||
# runs to start failing.
|
||||
@pytest.mark.redistributors_should_skip()
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
|
||||
@pytest.mark.filterwarnings(
|
||||
# https://github.com/pypa/setuptools/issues/3274
|
||||
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
|
||||
)
|
||||
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
|
||||
module = importlib.import_module(modname)
|
||||
|
||||
def no_underscores(symbols: Iterable[str]) -> set[str]:
|
||||
return {symbol for symbol in symbols if not symbol.startswith("_")}
|
||||
|
||||
runtime_names = no_underscores(dir(module))
|
||||
|
||||
# ignore deprecated module `tests` being invisible
|
||||
if modname == "trio":
|
||||
runtime_names.discard("tests")
|
||||
|
||||
# Ignore any __future__ feature objects, if imported under that name.
|
||||
for name in __future__.all_feature_names:
|
||||
if getattr(module, name, None) is getattr(__future__, name):
|
||||
runtime_names.remove(name)
|
||||
|
||||
if tool == "pylint":
|
||||
try:
|
||||
from pylint.lint import PyLinter
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
linter = PyLinter()
|
||||
assert module.__file__ is not None
|
||||
ast = linter.get_ast(module.__file__, modname)
|
||||
static_names = no_underscores(ast) # type: ignore[arg-type]
|
||||
elif tool == "jedi":
|
||||
if sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
# Simulate typing "import trio; trio.<TAB>"
|
||||
script = jedi.Script(f"import {modname}; {modname}.")
|
||||
completions = script.complete()
|
||||
static_names = no_underscores(c.name for c in completions)
|
||||
elif tool == "mypy":
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against mypy")
|
||||
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
static_names = no_underscores(
|
||||
key
|
||||
for key, value in cache_json["names"].items()
|
||||
if not key.startswith(".") and value["kind"] == "Gdef"
|
||||
)
|
||||
elif tool == "pyright_verifytypes":
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against pyright")
|
||||
|
||||
try:
|
||||
import pyright # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
import subprocess
|
||||
|
||||
res = subprocess.run(
|
||||
["pyright", f"--verifytypes={modname}", "--outputjson"],
|
||||
capture_output=True,
|
||||
)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
static_names = {
|
||||
x["name"][len(modname) + 1 :]
|
||||
for x in current_result["typeCompleteness"]["symbols"]
|
||||
if x["name"].startswith(modname)
|
||||
}
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
# It's expected that the static set will contain more names than the
|
||||
# runtime set:
|
||||
# - static tools are sometimes sloppy and include deleted names
|
||||
# - some symbols are platform-specific at runtime, but always show up in
|
||||
# static analysis (e.g. in trio.socket or trio.lowlevel)
|
||||
# So we check that the runtime names are a subset of the static names.
|
||||
missing_names = runtime_names - static_names
|
||||
|
||||
# ignore warnings about deprecated module tests
|
||||
missing_names -= {"tests"}
|
||||
|
||||
if missing_names: # pragma: no cover
|
||||
print(f"{tool} can't see the following names in {modname}:")
|
||||
print()
|
||||
for name in sorted(missing_names):
|
||||
print(f" {name}")
|
||||
raise AssertionError()
|
||||
|
||||
|
||||
# this could be sped up by only invoking mypy once per module, or even once for all
|
||||
# modules, instead of once per class.
|
||||
@slow
|
||||
# see comment on test_static_tool_sees_all_symbols
|
||||
@pytest.mark.redistributors_should_skip()
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
|
||||
def test_static_tool_sees_class_members(
|
||||
tool: str,
|
||||
module_name: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
|
||||
|
||||
# ignore hidden, but not dunder, symbols
|
||||
def no_hidden(symbols: Iterable[str]) -> set[str]:
|
||||
return {
|
||||
symbol
|
||||
for symbol in symbols
|
||||
if (not symbol.startswith("_")) or symbol.startswith("__")
|
||||
}
|
||||
|
||||
if tool == "jedi" and sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
if tool == "mypy":
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
modname = module_name
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir():
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
|
||||
# skip a bunch of file-system activity (probably can un-memoize?)
|
||||
@functools.lru_cache
|
||||
def lookup_symbol(symbol: str) -> dict[str, str]:
|
||||
topname, *modname, name = symbol.split(".")
|
||||
version = next(cache.glob("3.*/"))
|
||||
mod_cache = version / topname
|
||||
if not mod_cache.is_dir():
|
||||
mod_cache = version / (topname + ".data.json")
|
||||
|
||||
if modname:
|
||||
for piece in modname[:-1]:
|
||||
mod_cache /= piece
|
||||
next_cache = mod_cache / modname[-1]
|
||||
if next_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = next_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = mod_cache / (modname[-1] + ".data.json")
|
||||
elif mod_cache.is_dir():
|
||||
mod_cache /= "__init__.data.json"
|
||||
with mod_cache.open() as f:
|
||||
return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
|
||||
|
||||
errors: dict[str, object] = {}
|
||||
for class_name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
|
||||
continue
|
||||
|
||||
# ignore class that does dirty tricks
|
||||
if class_ is trio.testing.RaisesGroup:
|
||||
continue
|
||||
|
||||
# dir() and inspect.getmembers doesn't display properties from the metaclass
|
||||
# also ignore some dunder methods that tend to differ but are of no consequence
|
||||
ignore_names = set(dir(type(class_))) | {
|
||||
"__annotations__",
|
||||
"__attrs_attrs__",
|
||||
"__attrs_own_setattr__",
|
||||
"__callable_proto_members_only__",
|
||||
"__class_getitem__",
|
||||
"__final__",
|
||||
"__getstate__",
|
||||
"__match_args__",
|
||||
"__order__",
|
||||
"__orig_bases__",
|
||||
"__parameters__",
|
||||
"__protocol_attrs__",
|
||||
"__setstate__",
|
||||
"__slots__",
|
||||
"__weakref__",
|
||||
# ignore errors about dunders inherited from stdlib that tools might
|
||||
# not see
|
||||
"__copy__",
|
||||
"__deepcopy__",
|
||||
}
|
||||
|
||||
if type(class_) is type:
|
||||
# C extension classes don't have these dunders, but Python classes do
|
||||
ignore_names.add("__firstlineno__")
|
||||
ignore_names.add("__static_attributes__")
|
||||
|
||||
# pypy seems to have some additional dunders that differ
|
||||
if sys.implementation.name == "pypy":
|
||||
ignore_names |= {
|
||||
"__basicsize__",
|
||||
"__dictoffset__",
|
||||
"__itemsize__",
|
||||
"__sizeof__",
|
||||
"__weakrefoffset__",
|
||||
"__unicode__",
|
||||
}
|
||||
|
||||
# inspect.getmembers sees `name` and `value` in Enums, otherwise
|
||||
# it behaves the same way as `dir`
|
||||
# runtime_names = no_underscores(dir(class_))
|
||||
runtime_names = (
|
||||
no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
|
||||
)
|
||||
|
||||
if tool == "jedi":
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
script = jedi.Script(
|
||||
f"from {module_name} import {class_name}; {class_name}.",
|
||||
)
|
||||
completions = script.complete()
|
||||
static_names = no_hidden(c.name for c in completions) - ignore_names
|
||||
|
||||
elif tool == "mypy":
|
||||
# load the cached type information
|
||||
cached_type_info = cache_json["names"][class_name]
|
||||
if "node" not in cached_type_info:
|
||||
cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
|
||||
|
||||
assert "node" in cached_type_info
|
||||
node = cached_type_info["node"]
|
||||
static_names = no_hidden(k for k in node["names"] if not k.startswith("."))
|
||||
for symbol in node["mro"][1:]:
|
||||
node = lookup_symbol(symbol)["node"]
|
||||
static_names |= no_hidden(
|
||||
k for k in node["names"] if not k.startswith(".")
|
||||
)
|
||||
static_names -= ignore_names
|
||||
|
||||
else: # pragma: no cover
|
||||
raise AssertionError("unknown tool")
|
||||
|
||||
missing = runtime_names - static_names
|
||||
extra = static_names - runtime_names
|
||||
|
||||
# using .remove() instead of .delete() to get an error in case they start not
|
||||
# being missing
|
||||
|
||||
if (
|
||||
tool == "jedi"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
missing.remove("add_note")
|
||||
|
||||
if (
|
||||
tool == "mypy"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
extra.remove("__notes__")
|
||||
|
||||
if tool == "mypy" and attrs.has(class_):
|
||||
# e.g. __trio__core__run_CancelScope_AttrsAttributes__
|
||||
before = len(extra)
|
||||
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
|
||||
assert len(extra) == before - 1
|
||||
|
||||
# mypy does not see these attributes in Enum subclasses
|
||||
if (
|
||||
tool == "mypy"
|
||||
and enum.Enum in class_.__mro__
|
||||
and sys.version_info >= (3, 12)
|
||||
):
|
||||
# Another attribute, in 3.12+ only.
|
||||
extra.remove("__signature__")
|
||||
|
||||
# TODO: this *should* be visible via `dir`!!
|
||||
if tool == "mypy" and class_ == trio.Nursery:
|
||||
extra.remove("cancel_scope")
|
||||
|
||||
# These are (mostly? solely?) *runtime* attributes, often set in
|
||||
# __init__, which doesn't show up with dir() or inspect.getmembers,
|
||||
# but we get them in the way we query mypy & jedi
|
||||
EXTRAS = {
|
||||
trio.DTLSChannel: {"peer_address", "endpoint"},
|
||||
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
|
||||
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
|
||||
trio.SSLListener: {"transport_listener"},
|
||||
trio.SSLStream: {"transport_stream"},
|
||||
trio.SocketListener: {"socket"},
|
||||
trio.SocketStream: {"socket"},
|
||||
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
|
||||
trio.testing.MemorySendStream: {
|
||||
"close_hook",
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
},
|
||||
trio.testing.Matcher: {
|
||||
"exception_type",
|
||||
"match",
|
||||
"check",
|
||||
},
|
||||
}
|
||||
if tool == "mypy" and class_ in EXTRAS:
|
||||
before = len(extra)
|
||||
extra -= EXTRAS[class_]
|
||||
assert len(extra) == before - len(EXTRAS[class_])
|
||||
|
||||
# TODO: why is this? Is it a problem?
|
||||
# see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
|
||||
if class_ == trio.StapledStream:
|
||||
extra.remove("receive_stream")
|
||||
extra.remove("send_stream")
|
||||
|
||||
# I have not researched why these are missing, should maybe create an issue
|
||||
# upstream with jedi
|
||||
if tool == "jedi" and sys.version_info >= (3, 11):
|
||||
if class_ in (
|
||||
trio.DTLSChannel,
|
||||
trio.MemoryReceiveChannel,
|
||||
trio.MemorySendChannel,
|
||||
trio.SSLListener,
|
||||
trio.SocketListener,
|
||||
):
|
||||
missing.remove("__aenter__")
|
||||
missing.remove("__aexit__")
|
||||
if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
|
||||
missing.remove("__aiter__")
|
||||
missing.remove("__anext__")
|
||||
|
||||
if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
|
||||
# These are from inherited subclasses.
|
||||
missing -= PurePath.__dict__.keys()
|
||||
# These are unix-only.
|
||||
if tool == "mypy" and sys.platform == "win32":
|
||||
missing -= {"owner", "is_mount", "group"}
|
||||
if tool == "jedi" and sys.platform == "win32":
|
||||
extra -= {"owner", "is_mount", "group"}
|
||||
|
||||
# not sure why jedi in particular ignores this (static?) method in 3.13
|
||||
# (especially given the method is from 3.12....)
|
||||
if (
|
||||
tool == "jedi"
|
||||
and sys.version_info >= (3, 13)
|
||||
and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
|
||||
):
|
||||
missing.remove("with_segments")
|
||||
|
||||
if missing or extra: # pragma: no cover
|
||||
errors[f"{module_name}.{class_name}"] = {
|
||||
"missing": missing,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
# `assert not errors` will not print the full content of errors, even with
|
||||
# `--verbose`, so we manually print it
|
||||
if errors: # pragma: no cover
|
||||
from pprint import pprint
|
||||
|
||||
print(f"\n{tool} can't see the following symbols in {module_name}:")
|
||||
pprint(errors)
|
||||
assert not errors
|
||||
|
||||
|
||||
def test_nopublic_is_final() -> None:
|
||||
"""Check all NoPublicConstructor classes are also @final."""
|
||||
assert class_is_final(_util.NoPublicConstructor) # This is itself final.
|
||||
|
||||
for module in ALL_MODULES:
|
||||
for class_ in module.__dict__.values():
|
||||
if isinstance(class_, _util.NoPublicConstructor):
|
||||
assert class_is_final(class_)
|
||||
|
||||
|
||||
def test_classes_are_final() -> None:
|
||||
# Sanity checks.
|
||||
assert not class_is_final(object)
|
||||
assert class_is_final(bool)
|
||||
|
||||
for module in PUBLIC_MODULES:
|
||||
for name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
# Deprecated classes are exported with a leading underscore
|
||||
if name.startswith("_"): # pragma: no cover
|
||||
continue
|
||||
|
||||
# Abstract classes can be subclassed, because that's the whole
|
||||
# point of ABCs
|
||||
if inspect.isabstract(class_):
|
||||
continue
|
||||
# Same with protocols, but only direct children.
|
||||
if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__:
|
||||
continue
|
||||
# Exceptions are allowed to be subclassed, because exception
|
||||
# subclassing isn't used to inherit behavior.
|
||||
if issubclass(class_, BaseException):
|
||||
continue
|
||||
# These are classes that are conceptually abstract, but
|
||||
# inspect.isabstract returns False for boring reasons.
|
||||
if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType:
|
||||
continue
|
||||
# ... insert other special cases here ...
|
||||
|
||||
# The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`.
|
||||
if class_ is trio.Path:
|
||||
continue
|
||||
# don't care about the *Statistics classes
|
||||
if name.endswith("Statistics"):
|
||||
continue
|
||||
|
||||
assert class_is_final(class_)
|
||||
@@ -0,0 +1,313 @@
|
||||
import errno
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing._fake_net import FakeNet
|
||||
|
||||
# ENOTCONN gives different messages on different platforms
|
||||
if sys.platform == "linux":
|
||||
ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$"
|
||||
elif sys.platform == "darwin":
|
||||
ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$"
|
||||
else:
|
||||
ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$"
|
||||
|
||||
|
||||
def fn() -> FakeNet:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return fn
|
||||
|
||||
|
||||
async def test_basic_udp() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] Invalid argument$",
|
||||
) as exc: # Cannot rebind.
|
||||
await s1.bind(("192.0.2.1", 0))
|
||||
assert exc.value.errno == errno.EINVAL
|
||||
|
||||
# Cannot bind multiple sockets to the same address
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$",
|
||||
) as exc:
|
||||
await s2.bind(("127.0.0.1", port))
|
||||
assert exc.value.errno == errno.EADDRINUSE
|
||||
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"xyz"
|
||||
assert addr == s2.getsockname()
|
||||
await s1.sendto(b"abc", s2.getsockname())
|
||||
data, addr = await s2.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
|
||||
async def test_msg_trunc() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
|
||||
|
||||
async def test_recv_methods() -> None:
|
||||
"""Test all recv methods for codecov"""
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# receiving on an unbound socket is a bad idea (I think?)
|
||||
with pytest.raises(NotImplementedError, match="code will most likely hang"):
|
||||
await s2.recv(10)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# recvfrom
|
||||
await s2.sendto(b"abc", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s2.getsockname()
|
||||
|
||||
# recv
|
||||
await s1.sendto(b"def", s2.getsockname())
|
||||
data = await s2.recv(10)
|
||||
assert data == b"def"
|
||||
|
||||
# recvfrom_into
|
||||
assert await s1.sendto(b"ghi", s2.getsockname()) == 3
|
||||
buf = bytearray(10)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="^partial recvfrom_into$"):
|
||||
(nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2)
|
||||
|
||||
(nbytes, addr) = await s2.recvfrom_into(buf)
|
||||
assert nbytes == 3
|
||||
assert buf == b"ghi" + b"\x00" * 7
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recv_into
|
||||
assert await s1.sendto(b"jkl", s2.getsockname()) == 3
|
||||
buf2 = bytearray(10)
|
||||
nbytes = await s2.recv_into(buf2)
|
||||
assert nbytes == 3
|
||||
assert buf2 == b"jkl" + b"\x00" * 7
|
||||
|
||||
if sys.platform == "linux" and sys.implementation.name == "cpython":
|
||||
flags: int = socket.MSG_MORE
|
||||
else:
|
||||
flags = 1
|
||||
|
||||
# Send seems explicitly non-functional
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.send(b"mno")
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
|
||||
await s2.send(b"mno", flags)
|
||||
|
||||
# sendto errors
|
||||
# it's successfully used earlier
|
||||
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname())
|
||||
with pytest.raises(TypeError, match="wrong number of arguments$"):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="functions not in socket on windows",
|
||||
)
|
||||
async def test_nonwindows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform != "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s2.bind(("127.0.0.1", 0))
|
||||
|
||||
# sendmsg
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.sendmsg([b"mno"])
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3
|
||||
(data, ancdata, msg_flags, addr) = await s2.recvmsg(10)
|
||||
assert data == b"jkl"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# TODO: recvmsg
|
||||
|
||||
# recvmsg_into
|
||||
assert await s1.sendto(b"xyzw", s2.getsockname()) == 4
|
||||
buf1 = bytearray(2)
|
||||
buf2 = bytearray(3)
|
||||
ret = await s2.recvmsg_into([buf1, buf2])
|
||||
(nbytes, ancdata, msg_flags, addr) = ret
|
||||
assert nbytes == 4
|
||||
assert buf1 == b"xy"
|
||||
assert buf2 == b"zw" + b"\x00"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recvmsg_into with MSG_TRUNC set
|
||||
assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5
|
||||
buf1 = bytearray(2)
|
||||
ret = await s2.recvmsg_into([buf1])
|
||||
(nbytes, ancdata, msg_flags, addr) = ret
|
||||
assert nbytes == 2
|
||||
assert buf1 == b"xy"
|
||||
assert ancdata == []
|
||||
assert msg_flags == socket.MSG_TRUNC
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'share'$",
|
||||
):
|
||||
await s1.share(0) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "win32",
|
||||
reason="windows-specific fakesocket testing",
|
||||
)
|
||||
async def test_windows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform == "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'sendmsg'$",
|
||||
):
|
||||
await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'recvmsg'$",
|
||||
):
|
||||
s2.recvmsg(0) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'recvmsg_into'$",
|
||||
):
|
||||
s2.recvmsg_into([]) # type: ignore[attr-defined]
|
||||
with pytest.raises(NotImplementedError):
|
||||
s1.share(0)
|
||||
|
||||
|
||||
async def test_basic_tcp() -> None:
|
||||
fn()
|
||||
with pytest.raises(NotImplementedError):
|
||||
trio.socket.socket()
|
||||
|
||||
|
||||
async def test_not_implemented_functions() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# getsockopt
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$",
|
||||
):
|
||||
s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
|
||||
|
||||
# setsockopt
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="^FakeNet always has IPV6_V6ONLY=True$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
# set_inheritable
|
||||
s1.set_inheritable(False)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="^FakeNet can't make inheritable sockets$",
|
||||
):
|
||||
s1.set_inheritable(True)
|
||||
|
||||
# get_inheritable
|
||||
assert not s1.get_inheritable()
|
||||
|
||||
|
||||
async def test_getpeername() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
s1.getpeername()
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="^This method seems to assume that self._binding has a remote UDPEndpoint$",
|
||||
):
|
||||
s1.getpeername()
|
||||
|
||||
|
||||
async def test_init() -> None:
|
||||
fn()
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=re.escape(
|
||||
f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}",
|
||||
),
|
||||
):
|
||||
s1 = trio.socket.socket()
|
||||
|
||||
# getsockname on unbound ipv4 socket
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
assert s1.getsockname() == ("0.0.0.0", 0)
|
||||
|
||||
# getsockname on bound ipv4 socket
|
||||
await s1.bind(("0.0.0.0", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# getsockname on unbound ipv6 socket
|
||||
s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
|
||||
assert s2.getsockname() == ("::", 0)
|
||||
|
||||
# getsockname on bound ipv6 socket
|
||||
await s2.bind(("::", 0))
|
||||
ip, port, *_ = s2.getsockname()
|
||||
assert ip == "::1"
|
||||
assert port != 0
|
||||
assert _ == [0, 0]
|
||||
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
from unittest.mock import sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import _core, _file_io
|
||||
from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pathlib
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> str:
|
||||
return os.fspath(tmp_path / "test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrapped() -> mock.Mock:
|
||||
return mock.Mock(spec_set=io.StringIO)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
|
||||
return trio.wrap_file(wrapped)
|
||||
|
||||
|
||||
def test_wrap_invalid() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file("")
|
||||
|
||||
|
||||
def test_wrap_non_iobase() -> None:
|
||||
class FakeFile:
|
||||
def close(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
def write(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
wrapped = FakeFile()
|
||||
assert not isinstance(wrapped, io.IOBase)
|
||||
|
||||
async_file = trio.wrap_file(wrapped)
|
||||
assert isinstance(async_file, AsyncIOWrapper)
|
||||
|
||||
del FakeFile.write
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file(FakeFile())
|
||||
|
||||
|
||||
def test_wrapped_property(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
assert async_file.wrapped is wrapped
|
||||
|
||||
|
||||
def test_dir_matches_wrapped(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
|
||||
|
||||
# all supported attrs in wrapped should be available in async_file
|
||||
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
|
||||
# all supported attrs not in wrapped should not be available in async_file
|
||||
assert not any(
|
||||
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
|
||||
)
|
||||
|
||||
|
||||
def test_unsupported_not_forwarded() -> None:
|
||||
class FakeFile(io.RawIOBase):
|
||||
def unsupported_attr(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async_file = trio.wrap_file(FakeFile())
|
||||
|
||||
assert hasattr(async_file.wrapped, "unsupported_attr")
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
# B018 "useless expression"
|
||||
async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
|
||||
|
||||
|
||||
def test_type_stubs_match_lists() -> None:
|
||||
"""Check the manual stubs match the list of wrapped methods."""
|
||||
# Fetch the module's source code.
|
||||
assert _file_io.__spec__ is not None
|
||||
loader = _file_io.__spec__.loader
|
||||
assert isinstance(loader, importlib.abc.SourceLoader)
|
||||
source = io.StringIO(loader.get_source("trio._file_io"))
|
||||
|
||||
# Find the class, then find the TYPE_CHECKING block.
|
||||
for line in source:
|
||||
if "class AsyncIOWrapper" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No class definition line?")
|
||||
|
||||
for line in source:
|
||||
if "if TYPE_CHECKING" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No TYPE CHECKING line?")
|
||||
|
||||
# Now we should be at the type checking block.
|
||||
found: list[tuple[str, str]] = []
|
||||
for line in source: # pragma: no branch - expected to break early
|
||||
if line.strip() and not line.startswith(" " * 8):
|
||||
break # Dedented out of the if TYPE_CHECKING block.
|
||||
match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
|
||||
if match is not None:
|
||||
kind = "async" if match.group(1) is not None else "sync"
|
||||
found.append((match.group(2), kind))
|
||||
|
||||
# Compare two lists so that we can easily see duplicates, and see what is different overall.
|
||||
expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS]
|
||||
expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS]
|
||||
# Ignore order, error if duplicates are present.
|
||||
found.sort()
|
||||
expected.sort()
|
||||
assert found == expected
|
||||
|
||||
|
||||
def test_sync_attrs_forwarded(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_sync_attrs_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, attr_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
|
||||
|
||||
|
||||
# I gave up on typing this one
|
||||
def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
# use read as a representative of all async methods
|
||||
assert async_file.read.__name__ == "read"
|
||||
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
|
||||
|
||||
assert async_file.read.__doc__ is not None
|
||||
assert "io.StringIO.read" in async_file.read.__doc__
|
||||
|
||||
|
||||
async def test_async_methods_wrap(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
meth = getattr(async_file, meth_name)
|
||||
wrapped_meth = getattr(wrapped, meth_name)
|
||||
|
||||
value = await meth(sentinel.argument, keyword=sentinel.keyword)
|
||||
|
||||
wrapped_meth.assert_called_once_with(
|
||||
sentinel.argument,
|
||||
keyword=sentinel.keyword,
|
||||
)
|
||||
assert value == wrapped_meth()
|
||||
|
||||
wrapped.reset_mock()
|
||||
|
||||
|
||||
async def test_async_methods_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, meth_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, meth_name)
|
||||
|
||||
|
||||
async def test_open(path: pathlib.Path) -> None:
|
||||
f = await trio.open_file(path, "w")
|
||||
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
await f.aclose()
|
||||
|
||||
|
||||
async def test_open_context_manager(path: pathlib.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
assert not f.closed
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_async_iter() -> None:
|
||||
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
|
||||
expected = list(async_file.wrapped)
|
||||
async_file.wrapped.seek(0)
|
||||
|
||||
result = [line async for line in async_file]
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
async def test_aclose_cancelled(path: pathlib.Path) -> None:
|
||||
with _core.CancelScope() as cscope:
|
||||
f = await trio.open_file(path, "w")
|
||||
cscope.cancel()
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.write("a")
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.aclose()
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
|
||||
tmp_file = tmp_path / "filename"
|
||||
tmp_file.touch()
|
||||
# flake8-async does not like opening files in async mode
|
||||
with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230
|
||||
buffered = io.BufferedReader(raw)
|
||||
|
||||
async_file = trio.wrap_file(buffered)
|
||||
|
||||
detached = await async_file.detach()
|
||||
|
||||
assert isinstance(detached, AsyncIOWrapper)
|
||||
assert detached.wrapped is raw
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .._highlevel_generic import StapledStream
|
||||
from ..abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordSendStream(SendStream):
|
||||
record: list[str | tuple[str, object]] = attrs.Factory(list)
|
||||
|
||||
async def send_all(self, data: object) -> None:
|
||||
self.record.append(("send_all", data))
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
self.record.append("wait_send_all_might_not_block")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordReceiveStream(ReceiveStream):
|
||||
record: list[str | tuple[str, int | None]] = attrs.Factory(list)
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
self.record.append(("receive_some", max_bytes))
|
||||
return b""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
async def test_StapledStream() -> None:
|
||||
send_stream = RecordSendStream()
|
||||
receive_stream = RecordReceiveStream()
|
||||
stapled = StapledStream(send_stream, receive_stream)
|
||||
|
||||
assert stapled.send_stream is send_stream
|
||||
assert stapled.receive_stream is receive_stream
|
||||
|
||||
await stapled.send_all(b"foo")
|
||||
await stapled.wait_send_all_might_not_block()
|
||||
assert send_stream.record == [
|
||||
("send_all", b"foo"),
|
||||
"wait_send_all_might_not_block",
|
||||
]
|
||||
send_stream.record.clear()
|
||||
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["aclose"]
|
||||
send_stream.record.clear()
|
||||
|
||||
async def fake_send_eof() -> None:
|
||||
send_stream.record.append("send_eof")
|
||||
|
||||
send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["send_eof"]
|
||||
|
||||
send_stream.record.clear()
|
||||
assert receive_stream.record == []
|
||||
|
||||
await stapled.receive_some(1234)
|
||||
assert receive_stream.record == [("receive_some", 1234)]
|
||||
assert send_stream.record == []
|
||||
receive_stream.record.clear()
|
||||
|
||||
await stapled.aclose()
|
||||
assert receive_stream.record == ["aclose"]
|
||||
assert send_stream.record == ["aclose"]
|
||||
|
||||
|
||||
async def test_StapledStream_with_erroring_close() -> None:
|
||||
# Make sure that if one of the aclose methods errors out, then the other
|
||||
# one still gets called.
|
||||
class BrokenSendStream(RecordSendStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("send error")
|
||||
|
||||
class BrokenReceiveStream(RecordReceiveStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("recv error")
|
||||
|
||||
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
|
||||
|
||||
with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
|
||||
await stapled.aclose()
|
||||
assert isinstance(excinfo.value.__context__, ValueError)
|
||||
|
||||
assert stapled.send_stream.record == ["aclose"]
|
||||
assert stapled.receive_stream.record == ["aclose"]
|
||||
@@ -0,0 +1,410 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING, Any, Sequence, overload
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import (
|
||||
SocketListener,
|
||||
open_tcp_listeners,
|
||||
open_tcp_stream,
|
||||
serve_tcp,
|
||||
)
|
||||
from trio.abc import HostnameResolver, SendStream, SocketFactory
|
||||
from trio.testing import open_stream_to_socket_listener
|
||||
|
||||
from .. import socket as tsocket
|
||||
from .._core._tests.tutil import binds_ipv6
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Buffer
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_basic() -> None:
|
||||
listeners = await open_tcp_listeners(0)
|
||||
assert isinstance(listeners, list)
|
||||
for obj in listeners:
|
||||
assert isinstance(obj, SocketListener)
|
||||
# Binds to wildcard address by default
|
||||
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
|
||||
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
|
||||
|
||||
listener = listeners[0]
|
||||
# Make sure the backlog is at least 2
|
||||
c1 = await open_stream_to_socket_listener(listener)
|
||||
c2 = await open_stream_to_socket_listener(listener)
|
||||
|
||||
s1 = await listener.accept()
|
||||
s2 = await listener.accept()
|
||||
|
||||
# Note that we don't know which client stream is connected to which server
|
||||
# stream
|
||||
await s1.send_all(b"x")
|
||||
await s2.send_all(b"x")
|
||||
assert await c1.receive_some(1) == b"x"
|
||||
assert await c2.receive_some(1) == b"x"
|
||||
|
||||
for resource in [c1, c2, s1, s2, *listeners]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_specific_port_specific_host() -> None:
|
||||
# Pick a port
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
host, port = sock.getsockname()
|
||||
sock.close()
|
||||
|
||||
(listener,) = await open_tcp_listeners(port, host=host)
|
||||
async with listener:
|
||||
assert listener.socket.getsockname() == (host, port)
|
||||
|
||||
|
||||
@binds_ipv6
|
||||
async def test_open_tcp_listeners_ipv6_v6only() -> None:
|
||||
# Check IPV6_V6ONLY is working properly
|
||||
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
|
||||
async with ipv6_listener:
|
||||
_, port, *_ = ipv6_listener.socket.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$",
|
||||
):
|
||||
await open_tcp_stream("127.0.0.1", port)
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_rebind() -> None:
|
||||
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
|
||||
sockaddr1 = l1.socket.getsockname()
|
||||
|
||||
# Plain old rebinding while it's still there should fail, even if we have
|
||||
# SO_REUSEADDR set
|
||||
with stdlib_socket.socket() as probe:
|
||||
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match="(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
|
||||
):
|
||||
probe.bind(sockaddr1)
|
||||
|
||||
# Now use the first listener to set up some connections in various states,
|
||||
# and make sure that they don't create any obstacle to rebinding a second
|
||||
# listener after the first one is closed.
|
||||
c_established = await open_stream_to_socket_listener(l1)
|
||||
s_established = await l1.accept()
|
||||
|
||||
c_time_wait = await open_stream_to_socket_listener(l1)
|
||||
s_time_wait = await l1.accept()
|
||||
# Server-initiated close leaves socket in TIME_WAIT
|
||||
await s_time_wait.aclose()
|
||||
|
||||
await l1.aclose()
|
||||
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
|
||||
sockaddr2 = l2.socket.getsockname()
|
||||
|
||||
assert sockaddr1 == sockaddr2
|
||||
assert s_established.socket.getsockname() == sockaddr2
|
||||
assert c_time_wait.socket.getpeername() == sockaddr2
|
||||
|
||||
for resource in [
|
||||
l1,
|
||||
l2,
|
||||
c_established,
|
||||
s_established,
|
||||
c_time_wait,
|
||||
s_time_wait,
|
||||
]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
class FakeOSError(OSError):
|
||||
pass
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
_family: AddressFamily = attrs.field(converter=AddressFamily)
|
||||
_type: SocketKind = attrs.field(converter=SocketKind)
|
||||
_proto: int
|
||||
|
||||
closed: bool = False
|
||||
poison_listen: bool = False
|
||||
backlog: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily:
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
|
||||
|
||||
def getsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
|
||||
return True
|
||||
raise AssertionError() # pragma: no cover
|
||||
|
||||
@overload
|
||||
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def bind(self, address: Any) -> None:
|
||||
pass
|
||||
|
||||
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
|
||||
assert self.backlog is None
|
||||
assert backlog is not None
|
||||
self.backlog = backlog
|
||||
if self.poison_listen:
|
||||
raise FakeOSError("whoops")
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocketFactory(SocketFactory):
|
||||
poison_after: int
|
||||
sockets: list[tsocket.SocketType] = attrs.Factory(list)
|
||||
raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int = 0,
|
||||
) -> tsocket.SocketType:
|
||||
assert family is not None
|
||||
assert type_ is not None
|
||||
if isinstance(family, int) and not isinstance(family, AddressFamily):
|
||||
family = AddressFamily(family) # pragma: no cover
|
||||
if family in self.raise_on_family:
|
||||
raise OSError(self.raise_on_family[family], "nope")
|
||||
sock = FakeSocket(family, type_, proto)
|
||||
self.poison_after -= 1
|
||||
if self.poison_after == 0:
|
||||
sock.poison_listen = True
|
||||
self.sockets.append(sock)
|
||||
return sock
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(HostnameResolver):
|
||||
family_addr_pairs: Sequence[tuple[AddressFamily, str]]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
assert isinstance(port, int)
|
||||
return [
|
||||
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
|
||||
for family, addr in self.family_addr_pairs
|
||||
]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
|
||||
# If we were trying to bind to multiple hosts and one of them failed, they
|
||||
# call get cleaned up before returning
|
||||
fsf = FakeSocketFactory(3)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver(
|
||||
[
|
||||
(tsocket.AF_INET, "1.1.1.1"),
|
||||
(tsocket.AF_INET, "2.2.2.2"),
|
||||
(tsocket.AF_INET, "3.3.3.3"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(FakeOSError):
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
assert len(fsf.sockets) == 3
|
||||
for sock in fsf.sockets:
|
||||
# property only exists on FakeSocket
|
||||
assert sock.closed # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_port_checking() -> None:
|
||||
for host in ["127.0.0.1", None]:
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_tcp() -> None:
|
||||
async def handler(stream: SendStream) -> None:
|
||||
await stream.send_all(b"x")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# nursery.start is incorrectly typed, awaiting #2773
|
||||
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
|
||||
stream = await open_stream_to_socket_listener(listeners[0])
|
||||
async with stream:
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"try_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"fail_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
async def test_open_tcp_listeners_some_address_families_unavailable(
|
||||
try_families: set[AddressFamily],
|
||||
fail_families: set[AddressFamily],
|
||||
) -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families},
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(family, "foo") for family in try_families]),
|
||||
)
|
||||
|
||||
should_succeed = try_families - fail_families
|
||||
|
||||
if not should_succeed:
|
||||
with pytest.raises(OSError, match="This system doesn't support") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
# open_listeners always creates an exceptiongroup with the
|
||||
# unsupported address families, regardless of the value of
|
||||
# strict_exception_groups or number of unsupported families.
|
||||
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)
|
||||
for subexc in exc_info.value.__cause__.exceptions:
|
||||
assert "nope" in str(subexc)
|
||||
else:
|
||||
listeners = await open_tcp_listeners(80)
|
||||
for listener in listeners:
|
||||
should_succeed.remove(listener.socket.family)
|
||||
assert not should_succeed
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family={
|
||||
tsocket.AF_INET: errno.EAFNOSUPPORT,
|
||||
tsocket.AF_INET6: errno.EINVAL,
|
||||
},
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]),
|
||||
)
|
||||
|
||||
with pytest.raises(OSError, match="nope") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
assert exc_info.value.errno == errno.EINVAL
|
||||
assert exc_info.value.__cause__ is None
|
||||
assert "nope" in str(exc_info.value)
|
||||
|
||||
|
||||
# We used to have an elaborate test that opened a real TCP listening socket
|
||||
# and then tried to measure its backlog by making connections to it. And most
|
||||
# of the time, it worked. But no matter what we tried, it was always fragile,
|
||||
# because it had to do things like use timeouts to guess when the listening
|
||||
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
|
||||
# effectively is no backlog), sometimes the host might not be enough resources
|
||||
# to give us the full requested backlog... it was a mess. So now we just check
|
||||
# that the backlog argument is passed through correctly.
|
||||
async def test_open_tcp_listeners_backlog() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for given, expected in [
|
||||
(None, 0xFFFF),
|
||||
(99999999, 0xFFFF),
|
||||
(10, 10),
|
||||
(1, 1),
|
||||
]:
|
||||
listeners = await open_tcp_listeners(0, backlog=given)
|
||||
assert listeners
|
||||
for listener in listeners:
|
||||
# `backlog` only exists on FakeSocket
|
||||
assert listener.socket.backlog == expected # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_backlog_float_error() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for should_fail in (0.0, 2.18, 3.14, 9.75):
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=f"backlog must be an int or None, not {should_fail!r}",
|
||||
):
|
||||
await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type]
|
||||
@@ -0,0 +1,686 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._highlevel_open_tcp_stream import (
|
||||
close_all,
|
||||
format_host_port,
|
||||
open_tcp_stream,
|
||||
reorder_for_rfc_6555_section_5_4,
|
||||
)
|
||||
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio.testing import MockClock
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
|
||||
def test_close_all() -> None:
|
||||
class CloseMe(SocketType):
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
class CloseKiller(SocketType):
|
||||
def close(self) -> None:
|
||||
raise OSError("os error text")
|
||||
|
||||
c: CloseMe = CloseMe()
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(OSError, match="os error text"):
|
||||
with close_all() as to_close:
|
||||
to_close.add(CloseKiller())
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
|
||||
def test_reorder_for_rfc_6555_section_5_4() -> None:
|
||||
def fake4(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (
|
||||
AF_INET,
|
||||
SOCK_STREAM,
|
||||
IPPROTO_TCP,
|
||||
"",
|
||||
(f"10.0.0.{i}", 80),
|
||||
)
|
||||
|
||||
def fake6(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
|
||||
|
||||
for fake in fake4, fake6:
|
||||
# No effect on homogeneous lists
|
||||
targets = [fake(0), fake(1), fake(2)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0), fake(1), fake(2)]
|
||||
|
||||
# Single item lists also OK
|
||||
targets = [fake(0)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0)]
|
||||
|
||||
# If the list starts out with different families in positions 0 and 1,
|
||||
# then it's left alone
|
||||
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
|
||||
targets = list(orig)
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == orig
|
||||
|
||||
# If not, it's reordered
|
||||
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
|
||||
|
||||
|
||||
def test_format_host_port() -> None:
|
||||
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port("example.com", 443) == "example.com:443"
|
||||
assert format_host_port(b"example.com", 443) == "example.com:443"
|
||||
assert format_host_port("::1", "http") == "[::1]:http"
|
||||
assert format_host_port(b"::1", "http") == "[::1]:http"
|
||||
|
||||
|
||||
# Make sure we can connect to localhost using real kernel sockets
|
||||
async def test_open_tcp_stream_real_socket_smoketest() -> None:
|
||||
listen_sock = trio.socket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
_, listen_port = listen_sock.getsockname()
|
||||
listen_sock.listen(1)
|
||||
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
|
||||
server_sock, _ = await listen_sock.accept()
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_sock.recv(1) == b"x"
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
|
||||
listen_sock.close()
|
||||
|
||||
|
||||
async def test_open_tcp_stream_input_validation() -> None:
|
||||
with pytest.raises(ValueError, match="^host must be str or bytes, not None$"):
|
||||
await open_tcp_stream(None, 80) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def can_bind_127_0_0_2() -> bool:
|
||||
with socket.socket() as s:
|
||||
try:
|
||||
s.bind(("127.0.0.2", 0))
|
||||
except OSError:
|
||||
return False
|
||||
# s.getsockname() is typed as returning Any
|
||||
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
|
||||
|
||||
|
||||
async def test_local_address_real() -> None:
|
||||
with trio.socket.socket() as listener:
|
||||
await listener.bind(("127.0.0.1", 0))
|
||||
listener.listen()
|
||||
|
||||
# It's hard to test local_address properly, because you need multiple
|
||||
# local addresses that you can bind to. Fortunately, on most Linux
|
||||
# systems, you can bind to any 127.*.*.* address, and they all go
|
||||
# through the loopback interface. So we can use a non-standard
|
||||
# loopback address. On other systems, the only address we know for
|
||||
# certain we have is 127.0.0.1, so we can't really test local_address=
|
||||
# properly -- passing local_address=127.0.0.1 is indistinguishable
|
||||
# from not passing local_address= at all. But, we can still do a smoke
|
||||
# test to make sure the local_address= code doesn't crash.
|
||||
local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"
|
||||
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address=local_address,
|
||||
) as client_stream:
|
||||
assert client_stream.socket.getsockname()[0] == local_address
|
||||
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
|
||||
assert client_stream.socket.getsockopt(
|
||||
trio.socket.IPPROTO_IP,
|
||||
trio.socket.IP_BIND_ADDRESS_NO_PORT,
|
||||
)
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
|
||||
assert remote_addr[0] == local_address
|
||||
|
||||
# Trying to connect to an ipv4 address with the ipv6 wildcard
|
||||
# local_address should fail
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$",
|
||||
):
|
||||
await open_tcp_stream(*listener.getsockname(), local_address="::")
|
||||
|
||||
# But the ipv4 wildcard address should work
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address="0.0.0.0",
|
||||
) as client_stream:
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
server_sock.close()
|
||||
assert remote_addr == client_stream.socket.getsockname()
|
||||
|
||||
|
||||
# Now, thorough tests using fake sockets
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class FakeSocket(trio.socket.SocketType):
|
||||
scenario: Scenario
|
||||
_family: AddressFamily
|
||||
_type: SocketKind
|
||||
_proto: int
|
||||
|
||||
ip: str | int | None = None
|
||||
port: str | int | None = None
|
||||
succeeded: bool = False
|
||||
closed: bool = False
|
||||
failing: bool = False
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily: # pragma: no cover
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
|
||||
self.ip = sockaddr[0]
|
||||
self.port = sockaddr[1]
|
||||
assert self.ip not in self.scenario.sockets
|
||||
self.scenario.sockets[self.ip] = self
|
||||
self.scenario.connect_times[self.ip] = trio.current_time()
|
||||
delay, result = self.scenario.ip_dict[self.ip]
|
||||
await trio.sleep(delay)
|
||||
if result == "error":
|
||||
raise OSError("sorry")
|
||||
if result == "postconnect_fail":
|
||||
self.failing = True
|
||||
self.succeeded = True
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
# called when SocketStream is constructed
|
||||
def setsockopt(self, *args: object, **kwargs: object) -> None:
|
||||
if self.failing:
|
||||
# raise something that isn't OSError as SocketStream
|
||||
# ignores those
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
supported_families: set[AddressFamily],
|
||||
) -> None:
|
||||
# ip_list have to be unique
|
||||
ip_order = [ip for (ip, _, _) in ip_list]
|
||||
assert len(set(ip_order)) == len(ip_list)
|
||||
ip_dict: dict[str | int, tuple[float, str]] = {}
|
||||
for ip, delay, result in ip_list:
|
||||
assert delay >= 0
|
||||
assert result in ["error", "success", "postconnect_fail"]
|
||||
ip_dict[ip] = (delay, result)
|
||||
|
||||
self.port = port
|
||||
self.ip_order = ip_order
|
||||
self.ip_dict = ip_dict
|
||||
self.supported_families = supported_families
|
||||
self.socket_count = 0
|
||||
self.sockets: dict[str | int, FakeSocket] = {}
|
||||
self.connect_times: dict[str | int, float] = {}
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int | None = None,
|
||||
) -> SocketType:
|
||||
assert isinstance(family, AddressFamily)
|
||||
assert isinstance(type_, SocketKind)
|
||||
assert proto is not None
|
||||
if family not in self.supported_families:
|
||||
raise OSError("pretending not to support this family")
|
||||
self.socket_count += 1
|
||||
return FakeSocket(self, family, type_, proto)
|
||||
|
||||
def _ip_to_gai_entry(self, ip: str) -> tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int],
|
||||
]:
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int]
|
||||
if ":" in ip:
|
||||
family = trio.socket.AF_INET6
|
||||
sockaddr = (ip, self.port, 0, 0)
|
||||
else:
|
||||
family = trio.socket.AF_INET
|
||||
sockaddr = (ip, self.port)
|
||||
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = -1,
|
||||
type: int = -1,
|
||||
proto: int = -1,
|
||||
flags: int = -1,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int],
|
||||
]
|
||||
]:
|
||||
assert host == b"test.example.com"
|
||||
assert port == self.port
|
||||
assert family == trio.socket.AF_UNSPEC
|
||||
assert type == trio.socket.SOCK_STREAM
|
||||
assert proto == 0
|
||||
assert flags == 0
|
||||
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def check(self, succeeded: SocketType | None) -> None:
|
||||
# sockets only go into self.sockets when connect is called; make sure
|
||||
# all the sockets that were created did in fact go in there.
|
||||
assert self.socket_count == len(self.sockets)
|
||||
|
||||
for ip, socket_ in self.sockets.items():
|
||||
assert ip in self.ip_dict
|
||||
if socket_ is not succeeded:
|
||||
assert socket_.closed
|
||||
assert socket_.port == self.port
|
||||
|
||||
|
||||
async def run_scenario(
|
||||
# The port to connect to
|
||||
port: int,
|
||||
# A list of
|
||||
# (ip, delay, result)
|
||||
# tuples, where delay is in seconds and result is "success" or "error"
|
||||
# The ip's will be returned from getaddrinfo in this order, and then
|
||||
# connect() calls to them will have the given result.
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
*,
|
||||
# If False, AF_INET4/6 sockets error out on creation, before connect is
|
||||
# even called.
|
||||
ipv4_supported: bool = True,
|
||||
ipv6_supported: bool = True,
|
||||
# Normally, we return (winning_sock, scenario object)
|
||||
# If this is True, we require there to be an exception, and return
|
||||
# (exception, scenario object)
|
||||
expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
|
||||
**kwargs: Any,
|
||||
) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
|
||||
supported_families = set()
|
||||
if ipv4_supported:
|
||||
supported_families.add(trio.socket.AF_INET)
|
||||
if ipv6_supported:
|
||||
supported_families.add(trio.socket.AF_INET6)
|
||||
scenario = Scenario(port, ip_list, supported_families)
|
||||
trio.socket.set_custom_hostname_resolver(scenario)
|
||||
trio.socket.set_custom_socket_factory(scenario)
|
||||
|
||||
try:
|
||||
stream = await open_tcp_stream("test.example.com", port, **kwargs)
|
||||
assert expect_error == ()
|
||||
scenario.check(stream.socket)
|
||||
return (stream.socket, scenario)
|
||||
except AssertionError: # pragma: no cover
|
||||
raise
|
||||
except expect_error as exc:
|
||||
scenario.check(None)
|
||||
return (exc, scenario)
|
||||
|
||||
|
||||
async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
82,
|
||||
[("1.2.3.4", 0.123, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 100, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 1, "postconnect_fail")],
|
||||
expect_error=KeyboardInterrupt,
|
||||
)
|
||||
assert isinstance(exc, KeyboardInterrupt)
|
||||
|
||||
|
||||
# With the default 0.250 second delay, the third attempt will win
|
||||
async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_early_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 0.1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "2.2.2.2"
|
||||
assert trio.current_time() == (0.250 + 0.1)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
# 3.3.3.3 was never even started
|
||||
}
|
||||
|
||||
|
||||
# With a 0.450 second delay, the first attempt will win
|
||||
async def test_custom_delay(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=0.450,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.1.1.1"
|
||||
assert trio.current_time() == 1
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.450,
|
||||
"3.3.3.3": 0.900,
|
||||
}
|
||||
|
||||
|
||||
async def test_none_default(autojump_clock: MockClock) -> None:
|
||||
"""Copy of test_basic_fallthrough, but specifying the delay =None"""
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=None,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
# .25 is the default timeout
|
||||
("4.4.4.4", 0.25, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_all_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "error"),
|
||||
("4.4.4.4", 0.250, "error"),
|
||||
],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
|
||||
subexceptions = (Matcher(OSError, match="^sorry$"),) * 4
|
||||
assert RaisesGroup(
|
||||
*subexceptions,
|
||||
match="all attempts to connect to test.example.com:80 failed",
|
||||
).matches(exc.__cause__)
|
||||
|
||||
assert trio.current_time() == (0.1 + 0.2 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_multi_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.5, "error"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10 - 1, "success"),
|
||||
("4.4.4.4", 10 - 2, "success"),
|
||||
("5.5.5.5", 0.5, "error"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert not scenario.sockets["1.1.1.1"].succeeded
|
||||
assert (
|
||||
scenario.sockets["2.2.2.2"].succeeded
|
||||
or scenario.sockets["3.3.3.3"].succeeded
|
||||
or scenario.sockets["4.4.4.4"].succeeded
|
||||
)
|
||||
assert not scenario.sockets["5.5.5.5"].succeeded
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
|
||||
assert trio.current_time() == (0.5 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.5,
|
||||
"3.3.3.3": 1.5,
|
||||
"4.4.4.4": 2.5,
|
||||
"5.5.5.5": 3.5,
|
||||
}
|
||||
|
||||
|
||||
async def test_does_reorder(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "error"),
|
||||
# This would win if we tried it first...
|
||||
("2.2.2.2", 1, "success"),
|
||||
# But in fact we try this first, because of section 5.4
|
||||
("::3", 0.5, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.5
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"::3": 1,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 10, "success"),
|
||||
("2.2.2.2", 0, "success"),
|
||||
("::3", 0.1, "success"),
|
||||
("4.4.4.4", 0, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv4_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"::1": 0,
|
||||
"::3": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 0, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("::3", 0, "success"),
|
||||
("4.4.4.4", 0.1, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv6_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"2.2.2.2": 0,
|
||||
"4.4.4.4": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_no_hosts(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(80, [], expect_error=OSError)
|
||||
assert "no results found" in str(exc)
|
||||
|
||||
|
||||
async def test_cancel(autojump_clock: MockClock) -> None:
|
||||
with trio.move_on_after(5) as cancel_scope:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
("4.4.4.4", 10, "success"),
|
||||
],
|
||||
expect_error=BaseExceptionGroup,
|
||||
)
|
||||
assert isinstance(exc, BaseException)
|
||||
# What comes out should be 1 or more Cancelled errors that all belong
|
||||
# to this cancel_scope; this is the easiest way to check that
|
||||
raise exc
|
||||
assert cancel_scope.cancelled_caught
|
||||
|
||||
assert trio.current_time() == 5
|
||||
|
||||
# This should have been called already, but just to make sure, since the
|
||||
# exception-handling logic in run_scenario is a bit complicated and the
|
||||
# main thing we care about here is that all the sockets were cleaned up.
|
||||
scenario.check(succeeded=None)
|
||||
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio import Path, open_unix_socket
|
||||
from trio._highlevel_open_unix_stream import close_on_error
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform != "win32"
|
||||
|
||||
skip_if_not_unix = pytest.mark.skipif(
|
||||
not hasattr(socket, "AF_UNIX"),
|
||||
reason="Needs unix socket support",
|
||||
)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
def test_close_on_error() -> None:
|
||||
class CloseMe:
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
with close_on_error(CloseMe()) as c:
|
||||
pass
|
||||
assert not c.closed
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_on_error(CloseMe()) as c:
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
@pytest.mark.parametrize("filename", [4, 4.5])
|
||||
async def test_open_with_bad_filename_type(filename: float) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
await open_unix_socket(filename) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_bad_socket() -> None:
|
||||
# mktemp is marked as insecure, but that's okay, we don't want the file to
|
||||
# exist
|
||||
name = tempfile.mktemp()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await open_unix_socket(name)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_unix_socket() -> None:
|
||||
for name_type in [Path, str]:
|
||||
name = tempfile.mktemp()
|
||||
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
with serv_sock:
|
||||
serv_sock.bind(name)
|
||||
try:
|
||||
serv_sock.listen(1)
|
||||
|
||||
# The actual function we're testing
|
||||
unix_socket = await open_unix_socket(name_type(name))
|
||||
|
||||
async with unix_socket:
|
||||
client, _ = serv_sock.accept()
|
||||
with client:
|
||||
await unix_socket.send_all(b"test")
|
||||
assert client.recv(2048) == b"test"
|
||||
|
||||
client.sendall(b"response")
|
||||
received = await unix_socket.receive_some(2048)
|
||||
assert received == b"response"
|
||||
finally:
|
||||
os.unlink(name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms")
|
||||
async def test_error_on_no_unix() -> None:
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="^Unix sockets are not supported on this platform$",
|
||||
):
|
||||
await open_unix_socket("")
|
||||
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn
|
||||
|
||||
import attrs
|
||||
|
||||
import trio
|
||||
from trio import Nursery, StapledStream, TaskStatus
|
||||
from trio.testing import (
|
||||
Matcher,
|
||||
MemoryReceiveStream,
|
||||
MemorySendStream,
|
||||
MockClock,
|
||||
RaisesGroup,
|
||||
memory_stream_pair,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
from trio._channel import MemoryReceiveChannel, MemorySendChannel
|
||||
from trio.abc import Stream
|
||||
|
||||
# types are somewhat tentative - I just bruteforced them until I got something that didn't
|
||||
# give errors
|
||||
StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
|
||||
closed: bool = False
|
||||
accepted_streams: list[trio.abc.Stream] = attrs.Factory(list)
|
||||
queued_streams: tuple[
|
||||
MemorySendChannel[StapledMemoryStream],
|
||||
MemoryReceiveChannel[StapledMemoryStream],
|
||||
] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1))
|
||||
accept_hook: Callable[[], Awaitable[object]] | None = None
|
||||
|
||||
async def connect(self) -> StapledMemoryStream:
|
||||
assert not self.closed
|
||||
client, server = memory_stream_pair()
|
||||
await self.queued_streams[0].send(server)
|
||||
return client
|
||||
|
||||
async def accept(self) -> StapledMemoryStream:
|
||||
await trio.lowlevel.checkpoint()
|
||||
assert not self.closed
|
||||
if self.accept_hook is not None:
|
||||
await self.accept_hook()
|
||||
stream = await self.queued_streams[1].receive()
|
||||
self.accepted_streams.append(stream)
|
||||
return stream
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def test_serve_listeners_basic() -> None:
|
||||
listeners = [MemoryListener(), MemoryListener()]
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
# Make sure this is a forceful close
|
||||
assert trio.current_effective_deadline() == float("-inf")
|
||||
record.append("closed")
|
||||
|
||||
async def handler(stream: StapledMemoryStream) -> None:
|
||||
await stream.send_all(b"123")
|
||||
assert await stream.receive_some(10) == b"456"
|
||||
stream.send_stream.close_hook = close_hook
|
||||
stream.receive_stream.close_hook = close_hook
|
||||
|
||||
async def client(listener: MemoryListener) -> None:
|
||||
s = await listener.connect()
|
||||
assert await s.receive_some(10) == b"123"
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def do_tests(parent_nursery: Nursery) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in listeners:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(client, listener)
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# verifies that all 6 streams x 2 directions each were closed ok
|
||||
assert len(record) == 12
|
||||
|
||||
parent_nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
l2: list[MemoryListener] = await nursery.start(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
listeners,
|
||||
)
|
||||
assert l2 == listeners
|
||||
# This is just split into another function because gh-136 isn't
|
||||
# implemented yet
|
||||
nursery.start_soon(do_tests, nursery)
|
||||
|
||||
for listener in listeners:
|
||||
assert listener.closed
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_unrecognized_error() -> None:
|
||||
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_error() -> NoReturn:
|
||||
raise error # noqa: B023 # Set from loop
|
||||
|
||||
def check_error(e: BaseException) -> bool:
|
||||
return e is error # noqa: B023
|
||||
|
||||
listener.accept_hook = raise_error
|
||||
|
||||
with RaisesGroup(Matcher(check=check_error)):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_capacity_error(
|
||||
autojump_clock: MockClock,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_EMFILE() -> NoReturn:
|
||||
raise OSError(errno.EMFILE, "out of file descriptors")
|
||||
|
||||
listener.accept_hook = raise_EMFILE
|
||||
|
||||
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
|
||||
# = 10 times total
|
||||
with trio.move_on_after(0.950):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
assert len(caplog.records) == 10
|
||||
for record in caplog.records:
|
||||
assert "retrying" in record.msg
|
||||
assert record.exc_info is not None
|
||||
assert isinstance(record.exc_info[1], OSError)
|
||||
assert record.exc_info[1].errno == errno.EMFILE
|
||||
|
||||
|
||||
async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def handler(stream: Stream) -> None:
|
||||
await trio.sleep(1)
|
||||
|
||||
class Done(Exception):
|
||||
pass
|
||||
|
||||
async def connection_watcher(
|
||||
*,
|
||||
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
|
||||
) -> NoReturn:
|
||||
async with trio.open_nursery() as nursery:
|
||||
task_status.started(nursery)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(nursery.child_tasks) == 10
|
||||
raise Done
|
||||
|
||||
# the exception is wrapped twice because we open two nested nurseries
|
||||
with RaisesGroup(RaisesGroup(Done)):
|
||||
async with trio.open_nursery() as nursery:
|
||||
handler_nursery: trio.Nursery = await nursery.start(connection_watcher)
|
||||
await nursery.start(
|
||||
partial(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
[listener],
|
||||
handler_nursery=handler_nursery,
|
||||
),
|
||||
)
|
||||
for _ in range(10):
|
||||
nursery.start_soon(listener.connect)
|
||||
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from typing import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core, socket as tsocket
|
||||
from .._highlevel_socket import *
|
||||
from ..testing import (
|
||||
assert_checkpoints,
|
||||
check_half_closeable_stream,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
from .test_socket import setsockopt_tests
|
||||
|
||||
|
||||
async def test_SocketStream_basics() -> None:
|
||||
# stdlib socket bad (even if connected)
|
||||
stdlib_a, stdlib_b = stdlib_socket.socketpair()
|
||||
with stdlib_a, stdlib_b:
|
||||
with pytest.raises(TypeError):
|
||||
SocketStream(stdlib_a) # type: ignore[arg-type]
|
||||
|
||||
# DGRAM socket bad
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketStream requires a SOCK_STREAM socket$",
|
||||
):
|
||||
# TODO: does not raise an error?
|
||||
SocketStream(sock)
|
||||
|
||||
a, b = tsocket.socketpair()
|
||||
with a, b:
|
||||
s = SocketStream(a)
|
||||
assert s.socket is a
|
||||
|
||||
# Use a real, connected socket to test socket options, because
|
||||
# socketpair() might give us a unix socket that doesn't support any of
|
||||
# these options
|
||||
with tsocket.socket() as listen_sock:
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(1)
|
||||
with tsocket.socket() as client_sock:
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
|
||||
s = SocketStream(client_sock)
|
||||
|
||||
# TCP_NODELAY enabled by default
|
||||
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
# We can disable it though
|
||||
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
|
||||
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
|
||||
res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
|
||||
assert isinstance(res, bytes)
|
||||
|
||||
setsockopt_tests(s)
|
||||
|
||||
|
||||
async def test_SocketStream_send_all() -> None:
|
||||
BIG = 10000000
|
||||
|
||||
a_sock, b_sock = tsocket.socketpair()
|
||||
with a_sock, b_sock:
|
||||
a = SocketStream(a_sock)
|
||||
b = SocketStream(b_sock)
|
||||
|
||||
# Check a send_all that has to be split into multiple parts (on most
|
||||
# platforms... on Windows every send() either succeeds or fails as a
|
||||
# whole)
|
||||
async def sender() -> None:
|
||||
data = bytearray(BIG)
|
||||
await a.send_all(data)
|
||||
# send_all uses memoryviews internally, which temporarily "lock"
|
||||
# the object they view. If it doesn't clean them up properly, then
|
||||
# some bytearray operations might raise an error afterwards, which
|
||||
# would be a pretty weird and annoying side-effect to spring on
|
||||
# users. So test that this doesn't happen, by forcing the
|
||||
# bytearray's underlying buffer to be realloc'ed:
|
||||
data += bytes(BIG)
|
||||
# (Note: the above line of code doesn't do a very good job at
|
||||
# testing anything, because:
|
||||
# - on CPython, the refcount GC generally cleans up memoryviews
|
||||
# for us even if we're sloppy.
|
||||
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
|
||||
# bytearray code conspire so that resizing never fails – if
|
||||
# resizing forces the bytearray's internal buffer to move, then
|
||||
# all memoryview references are automagically updated (!!).
|
||||
# See:
|
||||
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
|
||||
# But I'm leaving the test here in hopes that if this ever changes
|
||||
# and we break our implementation of send_all, then we'll get some
|
||||
# early warning...)
|
||||
|
||||
async def receiver() -> None:
|
||||
# Make sure the sender fills up the kernel buffers and blocks
|
||||
await wait_all_tasks_blocked()
|
||||
nbytes = 0
|
||||
while nbytes < BIG:
|
||||
nbytes += len(await b.receive_some(BIG))
|
||||
assert nbytes == BIG
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
# We know that we received BIG bytes of NULs so far. Make sure that
|
||||
# was all the data in there.
|
||||
await a.send_all(b"e")
|
||||
assert await b.receive_some(10) == b"e"
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
|
||||
async def fill_stream(s: SocketStream) -> None:
|
||||
async def sender() -> None:
|
||||
while True:
|
||||
await s.send_all(b"x" * 10000)
|
||||
|
||||
async def waiter(nursery: _core.Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(waiter, nursery)
|
||||
|
||||
|
||||
async def test_SocketStream_generic() -> None:
|
||||
async def stream_maker() -> tuple[SocketStream, SocketStream]:
|
||||
left, right = tsocket.socketpair()
|
||||
return SocketStream(left), SocketStream(right)
|
||||
|
||||
async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
|
||||
left, right = await stream_maker()
|
||||
await fill_stream(left)
|
||||
await fill_stream(right)
|
||||
return left, right
|
||||
|
||||
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
|
||||
async def test_SocketListener() -> None:
|
||||
# Not a Trio socket
|
||||
with stdlib_socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.listen(10)
|
||||
with pytest.raises(TypeError):
|
||||
SocketListener(s) # type: ignore[arg-type]
|
||||
|
||||
# Not a SOCK_STREAM
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketListener requires a SOCK_STREAM socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*SOCK_STREAM")
|
||||
|
||||
# Didn't call .listen()
|
||||
# macOS has no way to check for this, so skip testing it there.
|
||||
if sys.platform != "darwin":
|
||||
with tsocket.socket() as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketListener requires a listening socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*listen")
|
||||
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
assert listener.socket is listen_sock
|
||||
|
||||
client_sock = tsocket.socket()
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
with assert_checkpoints():
|
||||
server_stream = await listener.accept()
|
||||
assert isinstance(server_stream, SocketStream)
|
||||
assert server_stream.socket.getsockname() == listen_sock.getsockname()
|
||||
assert server_stream.socket.getpeername() == client_sock.getsockname()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
client_sock.close()
|
||||
await server_stream.aclose()
|
||||
|
||||
|
||||
async def test_SocketListener_socket_closed_underfoot() -> None:
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
# Close the socket, not the listener
|
||||
listen_sock.close()
|
||||
|
||||
# SocketListener gives correct error
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
|
||||
async def test_SocketListener_accept_errors() -> None:
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
|
||||
self._events = iter(events)
|
||||
|
||||
type = tsocket.SOCK_STREAM
|
||||
|
||||
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int,
|
||||
) -> bytes: ...
|
||||
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
return True
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def accept(self) -> tuple[SocketType, object]:
|
||||
await _core.checkpoint()
|
||||
event = next(self._events)
|
||||
if isinstance(event, BaseException):
|
||||
raise event
|
||||
else:
|
||||
return event, None
|
||||
|
||||
fake_server_sock = FakeSocket([])
|
||||
|
||||
fake_listen_sock = FakeSocket(
|
||||
[
|
||||
OSError(errno.ECONNABORTED, "Connection aborted"),
|
||||
OSError(errno.EPERM, "Permission denied"),
|
||||
OSError(errno.EPROTO, "Bad protocol"),
|
||||
fake_server_sock,
|
||||
OSError(errno.EMFILE, "Out of file descriptors"),
|
||||
OSError(errno.EFAULT, "attempt to write to read-only memory"),
|
||||
OSError(errno.ENOBUFS, "out of buffers"),
|
||||
fake_server_sock,
|
||||
],
|
||||
)
|
||||
|
||||
listener = SocketListener(fake_listen_sock)
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
for code, match in {
|
||||
errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$",
|
||||
errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$",
|
||||
errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$",
|
||||
}.items():
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(OSError, match=match) as excinfo:
|
||||
await listener.accept()
|
||||
assert excinfo.value.errno == code
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
|
||||
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
|
||||
sock_a, sock_b = tsocket.socketpair()
|
||||
with sock_a, sock_b:
|
||||
await sock_b.send(b"x")
|
||||
sock_b.close()
|
||||
stream = SocketStream(sock_a)
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
assert await stream.receive_some(1) == b""
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
|
||||
|
||||
from .._highlevel_ssl_helpers import (
|
||||
open_ssl_over_tcp_listeners,
|
||||
open_ssl_over_tcp_stream,
|
||||
serve_ssl_over_tcp,
|
||||
)
|
||||
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from socket import AddressFamily, SocketKind
|
||||
from ssl import SSLContext
|
||||
|
||||
from trio.abc import Stream
|
||||
|
||||
from .._highlevel_socket import SocketListener
|
||||
from .._ssl import SSLListener
|
||||
|
||||
|
||||
async def echo_handler(stream: Stream) -> None:
|
||||
async with stream:
|
||||
try:
|
||||
while True:
|
||||
data = await stream.receive_some(10000)
|
||||
if not data:
|
||||
break
|
||||
await stream.send_all(data)
|
||||
except trio.BrokenResourceError:
|
||||
pass
|
||||
|
||||
|
||||
# Resolver that always returns the given sockaddr, no matter what host/port
|
||||
# you ask for.
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(trio.abc.HostnameResolver):
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
|
||||
|
||||
async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
async def test_open_ssl_over_tcp_stream_and_everything_else(
|
||||
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
|
||||
) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# TODO: this function wraps an SSLListener around a SocketListener, this is illegal
|
||||
# according to current type hints, and probably for good reason. But there should
|
||||
# maybe be a different wrapper class/function that could be used instead?
|
||||
res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var]
|
||||
await nursery.start(
|
||||
partial(
|
||||
serve_ssl_over_tcp,
|
||||
echo_handler,
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
),
|
||||
)
|
||||
)
|
||||
(listener,) = res
|
||||
async with listener:
|
||||
# listener.transport_listener is of type Listener[Stream]
|
||||
tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
|
||||
|
||||
sockaddr = tp_listener.socket.getsockname()
|
||||
hostname_resolver = FakeHostnameResolver(sockaddr)
|
||||
trio.socket.set_custom_hostname_resolver(hostname_resolver)
|
||||
|
||||
# We don't have the right trust set up
|
||||
# (checks that ssl_context=None is doing some validation)
|
||||
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# We have the trust but not the hostname
|
||||
# (checks custom ssl_context + hostname checking)
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"xyzzy.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# This one should work!
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
assert isinstance(stream, trio.SSLStream)
|
||||
assert stream.server_hostname == "trio-test-1.example.org"
|
||||
await stream.send_all(b"x")
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
|
||||
# Check https_compatible settings are being passed through
|
||||
assert not stream._https_compatible
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
https_compatible=True,
|
||||
# also, smoke test happy_eyeballs_delay
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
async with stream:
|
||||
assert stream._https_compatible
|
||||
|
||||
# Stop the echo server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_open_ssl_over_tcp_listeners() -> None:
|
||||
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
|
||||
async with listener:
|
||||
assert isinstance(listener, trio.SSLListener)
|
||||
tl = listener.transport_listener
|
||||
assert isinstance(tl, trio.SocketListener)
|
||||
assert tl.socket.getsockname()[0] == "127.0.0.1"
|
||||
|
||||
assert not listener._https_compatible
|
||||
|
||||
(listener,) = await open_ssl_over_tcp_listeners(
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
https_compatible=True,
|
||||
)
|
||||
async with listener:
|
||||
assert listener._https_compatible
|
||||
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from typing import TYPE_CHECKING, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> trio.Path:
|
||||
return trio.Path(tmp_path / "test")
|
||||
|
||||
|
||||
def method_pair(
|
||||
path: str,
|
||||
method_name: str,
|
||||
) -> tuple[Callable[[], object], Callable[[], Awaitable[object]]]:
|
||||
sync_path = pathlib.Path(path)
|
||||
async_path = trio.Path(path)
|
||||
return getattr(sync_path, method_name), getattr(async_path, method_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="OS is not posix")
|
||||
async def test_instantiate_posix() -> None:
|
||||
assert isinstance(trio.Path(), trio.PosixPath)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows")
|
||||
async def test_instantiate_windows() -> None:
|
||||
assert isinstance(trio.Path(), trio.WindowsPath)
|
||||
|
||||
|
||||
async def test_open_is_async_context_manager(path: trio.Path) -> None:
|
||||
async with await path.open("w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_magic() -> None:
|
||||
path = trio.Path("test")
|
||||
|
||||
assert str(path) == "test"
|
||||
assert bytes(path) == b"test"
|
||||
|
||||
|
||||
EitherPathType = Union[Type[trio.Path], Type[pathlib.Path]]
|
||||
PathOrStrType = Union[EitherPathType, Type[str]]
|
||||
cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(pathlib.Path, trio.Path),
|
||||
(trio.Path, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs)
|
||||
async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None:
|
||||
a, b = cls_a(""), cls_b("")
|
||||
assert a == b
|
||||
assert not a != b # noqa: SIM202 # negate-not-equal-op
|
||||
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
assert a < b
|
||||
assert b > a
|
||||
|
||||
# this is intentionally testing equivalence with none, due to the
|
||||
# other=sentinel logic in _forward_magic
|
||||
assert not a == None # noqa
|
||||
assert not b == None # noqa
|
||||
|
||||
|
||||
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
|
||||
# __*div__ does not properly raise NotImplementedError like the other comparison
|
||||
# magic, so trio.Path's implementation does not get dispatched
|
||||
cls_pairs_str: list[tuple[PathOrStrType, PathOrStrType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(trio.Path, trio.Path),
|
||||
(trio.Path, str),
|
||||
(str, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str)
|
||||
async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None:
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
|
||||
result = a / b # type: ignore[operator]
|
||||
# Type checkers think str / str could happen. Check each combo manually in type_tests/.
|
||||
assert isinstance(result, trio.Path)
|
||||
assert str(result) == os.path.join("a", "b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls_a", "cls_b"),
|
||||
[(trio.Path, pathlib.Path), (trio.Path, trio.Path)],
|
||||
)
|
||||
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
|
||||
async def test_hash_magic(
|
||||
cls_a: EitherPathType,
|
||||
cls_b: EitherPathType,
|
||||
path: str,
|
||||
) -> None:
|
||||
a, b = cls_a(path), cls_b(path)
|
||||
assert hash(a) == hash(b)
|
||||
|
||||
|
||||
async def test_forwarded_properties(path: trio.Path) -> None:
|
||||
# use `name` as a representative of forwarded properties
|
||||
|
||||
assert "name" in dir(path)
|
||||
assert path.name == "test"
|
||||
|
||||
|
||||
async def test_async_method_signature(path: trio.Path) -> None:
|
||||
# use `resolve` as a representative of wrapped methods
|
||||
|
||||
assert path.resolve.__name__ == "resolve"
|
||||
assert path.resolve.__qualname__ == "Path.resolve"
|
||||
|
||||
assert path.resolve.__doc__ is not None
|
||||
assert path.resolve.__qualname__ in path.resolve.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
|
||||
async def test_compare_async_stat_methods(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert result == async_result
|
||||
|
||||
|
||||
async def test_invalid_name_not_wrapped(path: trio.Path) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
|
||||
async def test_async_methods_rewrap(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert isinstance(async_result, trio.Path)
|
||||
assert str(result) == str(async_result)
|
||||
|
||||
|
||||
async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None:
|
||||
with_name = path.with_name("foo")
|
||||
with_suffix = path.with_suffix(".py")
|
||||
|
||||
assert isinstance(with_name, trio.Path)
|
||||
assert with_name == tmp_path / "foo"
|
||||
assert isinstance(with_suffix, trio.Path)
|
||||
assert with_suffix == tmp_path / "test.py"
|
||||
|
||||
|
||||
async def test_forward_properties_rewrap(path: trio.Path) -> None:
|
||||
assert isinstance(path.parent, trio.Path)
|
||||
|
||||
|
||||
async def test_forward_methods_without_rewrap(path: trio.Path) -> None:
|
||||
path = await path.parent.resolve()
|
||||
|
||||
assert path.as_uri().startswith("file:///")
|
||||
|
||||
|
||||
async def test_repr() -> None:
|
||||
path = trio.Path(".")
|
||||
|
||||
assert repr(path) == "trio.Path('.')"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
|
||||
async def test_path_wraps_path(
|
||||
path: trio.Path,
|
||||
meth: Callable[[trio.Path, trio.Path], object],
|
||||
) -> None:
|
||||
wrapped = await path.absolute()
|
||||
result = meth(path, wrapped)
|
||||
if result is None:
|
||||
result = path
|
||||
|
||||
assert wrapped == result
|
||||
|
||||
|
||||
async def test_path_nonpath() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.Path(1) # type: ignore
|
||||
|
||||
|
||||
async def test_open_file_can_open_path(path: trio.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert f.name == os.fspath(path)
|
||||
|
||||
|
||||
async def test_globmethods(path: trio.Path) -> None:
|
||||
# Populate a directory tree
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "foo" / "_bar.txt").write_bytes(b"")
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
await (path / "bar.dat").write_bytes(b"")
|
||||
|
||||
# Path.glob
|
||||
for _pattern, _results in {
|
||||
"*.txt": {"bar.txt"},
|
||||
"**/*.txt": {"_bar.txt", "bar.txt"},
|
||||
}.items():
|
||||
entries = set()
|
||||
for entry in await path.glob(_pattern):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == _results
|
||||
|
||||
# Path.rglob
|
||||
entries = set()
|
||||
for entry in await path.rglob("*.txt"):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"_bar.txt", "bar.txt"}
|
||||
|
||||
|
||||
async def test_iterdir(path: trio.Path) -> None:
|
||||
# Populate a directory
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
|
||||
entries = set()
|
||||
for entry in await path.iterdir():
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"bar.txt", "foo"}
|
||||
|
||||
|
||||
async def test_classmethods() -> None:
|
||||
assert isinstance(await trio.Path.home(), trio.Path)
|
||||
|
||||
# pathlib.Path has only two classmethods
|
||||
assert str(await trio.Path.home()) == os.path.expanduser("~")
|
||||
assert str(await trio.Path.cwd()) == os.getcwd()
|
||||
|
||||
# Wrapped method has docstring
|
||||
assert trio.Path.home.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wrapper",
|
||||
[
|
||||
trio._path._wraps_async,
|
||||
trio._path._wrap_method,
|
||||
trio._path._wrap_method_path,
|
||||
trio._path._wrap_method_path_iterable,
|
||||
],
|
||||
)
|
||||
def test_wrapping_without_docstrings(
|
||||
wrapper: Callable[[Callable[[], None]], Callable[[], None]],
|
||||
) -> None:
|
||||
@wrapper
|
||||
def func_without_docstring() -> None: ... # pragma: no cover
|
||||
|
||||
assert func_without_docstring.__doc__ is None
|
||||
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Protocol
|
||||
|
||||
import pytest
|
||||
|
||||
import trio._repl
|
||||
|
||||
|
||||
class RawInput(Protocol):
|
||||
def __call__(self, prompt: str = "") -> str: ...
|
||||
|
||||
|
||||
def build_raw_input(cmds: list[str]) -> RawInput:
|
||||
"""
|
||||
Pass in a list of strings.
|
||||
Returns a callable that returns each string, each time its called
|
||||
When there are not more strings to return, raise EOFError
|
||||
"""
|
||||
cmds_iter = iter(cmds)
|
||||
prompts = []
|
||||
|
||||
def _raw_helper(prompt: str = "") -> str:
|
||||
prompts.append(prompt)
|
||||
try:
|
||||
return next(cmds_iter)
|
||||
except StopIteration:
|
||||
raise EOFError from None
|
||||
|
||||
return _raw_helper
|
||||
|
||||
|
||||
def test_build_raw_input() -> None:
|
||||
"""Quick test of our helper function."""
|
||||
raw_input = build_raw_input(["cmd1"])
|
||||
assert raw_input() == "cmd1"
|
||||
with pytest.raises(EOFError):
|
||||
raw_input()
|
||||
|
||||
|
||||
# In 3.10 or later, types.FunctionType (used internally) will automatically
|
||||
# attach __builtins__ to the function objects. However we need to explicitly
|
||||
# include it for 3.8 & 3.9
|
||||
def build_locals() -> dict[str, object]:
|
||||
return {"__builtins__": __builtins__}
|
||||
|
||||
|
||||
async def test_basic_interaction(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Run some basic commands through the interpreter while capturing stdout.
|
||||
Ensure that the interpreted prints the expected results.
|
||||
"""
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# evaluate simple expression and recall the value
|
||||
"x = 1",
|
||||
"print(f'{x=}')",
|
||||
# Literal gets printed
|
||||
"'hello'",
|
||||
# define and call sync function
|
||||
"def func():",
|
||||
" print(x + 1)",
|
||||
"",
|
||||
"func()",
|
||||
# define and call async function
|
||||
"async def afunc():",
|
||||
" return 4",
|
||||
"",
|
||||
"await afunc()",
|
||||
# import works
|
||||
"import sys",
|
||||
"sys.stdout.write('hello stdout\\n')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
|
||||
|
||||
|
||||
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"raise SystemExit",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
with pytest.raises(SystemExit):
|
||||
await trio._repl.run_repl(console)
|
||||
|
||||
|
||||
async def test_KI_interrupts(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"from trio._util import signal_raise",
|
||||
"import signal, trio, trio.lowlevel",
|
||||
"async def f():",
|
||||
" trio.lowlevel.spawn_system_task("
|
||||
" trio.to_thread.run_sync,"
|
||||
" signal_raise,signal.SIGINT,"
|
||||
" )", # just awaiting this kills the test runner?!
|
||||
" await trio.sleep_forever()",
|
||||
" print('should not see this')",
|
||||
"",
|
||||
"await f()",
|
||||
"print('AFTER KeyboardInterrupt')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "KeyboardInterrupt" in err
|
||||
assert "should" not in out
|
||||
assert "AFTER KeyboardInterrupt" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_nested_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup(",
|
||||
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise BaseException",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
async def test_exc_group_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise ExceptionGroup('', [KeyError()])",
|
||||
"print('AFTER ExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "AFTER ExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_capture_from_coroutine(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"async def async_func_raises_base_exception():",
|
||||
" raise BaseException",
|
||||
"",
|
||||
# This will raise, but the statement after should still
|
||||
# be executed
|
||||
"await async_func_raises_base_exception()",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
def test_main_entrypoint() -> None:
|
||||
"""
|
||||
Basic smoke test when running via the package __main__ entrypoint.
|
||||
"""
|
||||
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
|
||||
assert repl.returncode == 0
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
|
||||
async def scheduler_trace() -> tuple[tuple[str, int], ...]:
|
||||
"""Returns a scheduler-dependent value we can use to check determinism."""
|
||||
trace = []
|
||||
|
||||
async def tracer(name: str) -> None:
|
||||
for i in range(50):
|
||||
trace.append((name, i))
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(tracer, str(i))
|
||||
|
||||
return tuple(trace)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_not_deterministic() -> None:
|
||||
# At least, not yet. See https://github.com/python-trio/trio/issues/32
|
||||
traces = [trio.run(scheduler_trace) for _ in range(10)]
|
||||
assert len(set(traces)) == len(traces)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_deterministic_if_seeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
|
||||
traces = []
|
||||
for _ in range(10):
|
||||
state = trio._core._run._r.getstate()
|
||||
try:
|
||||
trio._core._run._r.seed(0)
|
||||
traces.append(trio.run(scheduler_trace))
|
||||
finally:
|
||||
trio._core._run._r.setstate(state)
|
||||
|
||||
assert len(traces) == 10
|
||||
assert len(set(traces)) == 1
|
||||
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
|
||||
from .._util import signal_raise
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
|
||||
async def test_open_signal_receiver() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
# Raise it a few times, to exercise signal coalescing, both at the
|
||||
# call_soon level and at the SignalQueue level
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
signal_raise(signal.SIGILL)
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
with pytest.raises(RuntimeError):
|
||||
await receiver.__anext__()
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="(signal number out of range|invalid signal value)$",
|
||||
):
|
||||
with open_signal_receiver(signal.SIGILL, 1234567):
|
||||
pass # pragma: no cover
|
||||
# Still restored even if we errored out
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_empty_fail() -> None:
|
||||
with pytest.raises(TypeError, match="No signals were provided"):
|
||||
with open_signal_receiver():
|
||||
pass
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
|
||||
pass
|
||||
# Still restored correctly
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_catch_signals_wrong_thread() -> None:
|
||||
async def naughty() -> None:
|
||||
with open_signal_receiver(signal.SIGINT):
|
||||
pass # pragma: no cover
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.to_thread.run_sync(trio.run, naughty)
|
||||
|
||||
|
||||
async def test_open_signal_receiver_conflict() -> None:
|
||||
with RaisesGroup(trio.BusyResourceError):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
|
||||
|
||||
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
|
||||
# processed.
|
||||
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
|
||||
ev = trio.Event()
|
||||
token = _core.current_trio_token()
|
||||
token.run_sync_soon(ev.set, idempotent=True)
|
||||
await ev.wait()
|
||||
|
||||
|
||||
async def test_open_signal_receiver_no_starvation() -> None:
|
||||
# Set up a situation where there are always 2 pending signals available to
|
||||
# report, and make sure that instead of getting the same signal reported
|
||||
# over and over, it alternates between reporting both of them.
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
try:
|
||||
print(signal.getsignal(signal.SIGILL))
|
||||
previous = None
|
||||
for _ in range(10):
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
if previous is None:
|
||||
previous = await receiver.__anext__()
|
||||
else:
|
||||
got = await receiver.__anext__()
|
||||
assert got in [signal.SIGILL, signal.SIGFPE]
|
||||
assert got != previous
|
||||
previous = got
|
||||
# Clear out the last signal so that it doesn't get redelivered
|
||||
while get_pending_signal_count(receiver) != 0:
|
||||
await receiver.__anext__()
|
||||
except BaseException: # pragma: no cover
|
||||
# If there's an unhandled exception above, then exiting the
|
||||
# open_signal_receiver block might cause the signal to be
|
||||
# redelivered and give us a core dump instead of a traceback...
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def test_catch_signals_race_condition_on_exit() -> None:
|
||||
delivered_directly: set[int] = set()
|
||||
|
||||
def direct_handler(signo: int, frame: FrameType | None) -> None:
|
||||
delivered_directly.add(signo)
|
||||
|
||||
print(1)
|
||||
# Test the version where the call_soon *doesn't* have a chance to run
|
||||
# before we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
print(2)
|
||||
# Test the version where the call_soon *does* have a chance to run before
|
||||
# we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
# Again, but with a SIG_IGN signal:
|
||||
|
||||
print(3)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
print(4)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 1
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
# Check exception chaining if there are multiple exception-raising
|
||||
# handlers
|
||||
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
|
||||
raise RuntimeError(signum)
|
||||
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
exc = excinfo.value
|
||||
signums = {exc.args[0]}
|
||||
assert isinstance(exc.__context__, RuntimeError)
|
||||
signums.add(exc.__context__.args[0])
|
||||
assert signums == {signal.SIGILL, signal.SIGFPE}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,696 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path as SyncPath
|
||||
from signal import Signals
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
NoReturn,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from .. import (
|
||||
Event,
|
||||
Process,
|
||||
_core,
|
||||
fail_after,
|
||||
move_on_after,
|
||||
run_process,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
)
|
||||
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
|
||||
from ..lowlevel import open_process
|
||||
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .._abc import ReceiveStream
|
||||
|
||||
if sys.platform == "win32":
|
||||
SignalType: TypeAlias = None
|
||||
else:
|
||||
SignalType: TypeAlias = Signals
|
||||
|
||||
SIGKILL: SignalType
|
||||
SIGTERM: SignalType
|
||||
SIGUSR1: SignalType
|
||||
|
||||
posix = os.name == "posix"
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
from signal import SIGKILL, SIGTERM, SIGUSR1
|
||||
else:
|
||||
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
|
||||
|
||||
|
||||
# Since Windows has very few command-line utilities generally available,
|
||||
# all of our subprocesses are Python processes running short bits of
|
||||
# (mostly) cross-platform code.
|
||||
def python(code: str) -> list[str]:
|
||||
return [sys.executable, "-u", "-c", "import sys; " + code]
|
||||
|
||||
|
||||
EXIT_TRUE = python("sys.exit(0)")
|
||||
EXIT_FALSE = python("sys.exit(1)")
|
||||
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
|
||||
|
||||
if posix:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return ["sleep", str(seconds)]
|
||||
|
||||
else:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return python(f"import time; time.sleep({seconds})")
|
||||
|
||||
|
||||
def got_signal(proc: Process, sig: SignalType) -> bool:
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
return proc.returncode == -sig
|
||||
else:
|
||||
return proc.returncode != 0
|
||||
|
||||
|
||||
@asynccontextmanager # type: ignore[misc] # Any in decorator
|
||||
async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
|
||||
proc = await open_process(*args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager # type: ignore[misc] # Any in decorator
|
||||
async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
|
||||
async with _core.open_nursery() as nursery:
|
||||
kwargs.setdefault("check", False)
|
||||
proc: Process = await nursery.start(partial(run_process, *args, **kwargs))
|
||||
yield proc
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
background_process_param = pytest.mark.parametrize(
|
||||
"background_process",
|
||||
[open_process_then_kill, run_process_in_nursery],
|
||||
ids=["open_process", "run_process in nursery"],
|
||||
)
|
||||
|
||||
BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]]
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_basic(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(EXIT_TRUE) as proc:
|
||||
await proc.wait()
|
||||
assert isinstance(proc, Process)
|
||||
assert proc._pidfd is None
|
||||
assert proc.returncode == 0
|
||||
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
|
||||
|
||||
async with background_process(EXIT_FALSE) as proc:
|
||||
await proc.wait()
|
||||
assert proc.returncode == 1
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
EXIT_FALSE,
|
||||
"exited with status 1",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_auto_update_returncode(
|
||||
background_process: BackgroundProcessType,
|
||||
) -> None:
|
||||
async with background_process(SLEEP(9999)) as p:
|
||||
assert p.returncode is None
|
||||
assert "running" in repr(p)
|
||||
p.kill()
|
||||
p._proc.wait()
|
||||
assert p.returncode is not None
|
||||
assert "exited" in repr(p)
|
||||
assert p._pidfd is None
|
||||
assert p.returncode is not None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_multi_wait(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(SLEEP(10)) as proc:
|
||||
# Check that wait (including multi-wait) tolerates being cancelled
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# Now try waiting for real
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
|
||||
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
|
||||
"data = sys.stdin.buffer.read(); "
|
||||
"sys.stdout.buffer.write(data); "
|
||||
"sys.stderr.buffer.write(data[::-1])",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_pipes(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
msg = b"the quick brown fox jumps over the lazy dog"
|
||||
|
||||
async def feed_input() -> None:
|
||||
assert proc.stdin is not None
|
||||
await proc.stdin.send_all(msg)
|
||||
await proc.stdin.aclose()
|
||||
|
||||
async def check_output(stream: ReceiveStream, expected: bytes) -> None:
|
||||
seen = bytearray()
|
||||
async for chunk in stream:
|
||||
seen += chunk
|
||||
assert seen == expected
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
# fail eventually if something is broken
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 30.0
|
||||
nursery.start_soon(feed_input)
|
||||
nursery.start_soon(check_output, proc.stdout, msg)
|
||||
nursery.start_soon(check_output, proc.stderr, msg[::-1])
|
||||
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert await proc.wait() == 0
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_interactive(background_process: BackgroundProcessType) -> None:
|
||||
# Test some back-and-forth with a subprocess. This one works like so:
|
||||
# in: 32\n
|
||||
# out: 0000...0000\n (32 zeroes)
|
||||
# err: 1111...1111\n (64 ones)
|
||||
# in: 10\n
|
||||
# out: 2222222222\n (10 twos)
|
||||
# err: 3333....3333\n (20 threes)
|
||||
# in: EOF
|
||||
# out: EOF
|
||||
# err: EOF
|
||||
|
||||
async with background_process(
|
||||
python(
|
||||
"idx = 0\n"
|
||||
"while True:\n"
|
||||
" line = sys.stdin.readline()\n"
|
||||
" if line == '': break\n"
|
||||
" request = int(line.strip())\n"
|
||||
" print(str(idx * 2) * request)\n"
|
||||
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
|
||||
" idx += 1\n",
|
||||
),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
newline = b"\n" if posix else b"\r\n"
|
||||
|
||||
async def expect(idx: int, request: int) -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def drain_one(
|
||||
stream: ReceiveStream,
|
||||
count: int,
|
||||
digit: int,
|
||||
) -> None:
|
||||
while count > 0:
|
||||
result = await stream.receive_some(count)
|
||||
assert result == (f"{digit}".encode() * len(result))
|
||||
count -= len(result)
|
||||
assert count == 0
|
||||
assert await stream.receive_some(len(newline)) == newline
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
|
||||
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
|
||||
|
||||
assert proc.stdin is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
with fail_after(5):
|
||||
await proc.stdin.send_all(b"12")
|
||||
await sleep(0.1)
|
||||
await proc.stdin.send_all(b"345" + newline)
|
||||
await expect(0, 12345)
|
||||
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
|
||||
await expect(1, 100)
|
||||
await expect(2, 200)
|
||||
await proc.stdin.send_all(b"0" + newline)
|
||||
await expect(3, 0)
|
||||
await proc.stdin.send_all(b"999999")
|
||||
with move_on_after(0.1) as scope:
|
||||
await expect(4, 0)
|
||||
assert scope.cancelled_caught
|
||||
await proc.stdin.send_all(newline)
|
||||
await expect(4, 999999)
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.stdout.receive_some(1) == b""
|
||||
assert await proc.stderr.receive_some(1) == b""
|
||||
await proc.wait()
|
||||
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
async def test_run() -> None:
|
||||
data = bytes(random.randint(0, 255) for _ in range(2**18))
|
||||
|
||||
result = await run_process(
|
||||
CAT,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == b""
|
||||
|
||||
result = await run_process(CAT, capture_stdout=True)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b""
|
||||
assert result.stderr is None
|
||||
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == data[::-1]
|
||||
|
||||
# invalid combinations
|
||||
with pytest.raises(UnicodeError):
|
||||
await run_process(CAT, stdin="oh no, it's text")
|
||||
|
||||
pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$"
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdin=subprocess.PIPE)
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdout=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=pipe_stdout_error.replace("stdout", "stderr", 1),
|
||||
):
|
||||
await run_process(CAT, stderr=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^can't specify both stdout and capture_stdout$",
|
||||
):
|
||||
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^can't specify both stderr and capture_stderr$",
|
||||
):
|
||||
await run_process(CAT, capture_stderr=True, stderr=None)
|
||||
|
||||
|
||||
async def test_run_check() -> None:
|
||||
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
|
||||
with pytest.raises(subprocess.CalledProcessError) as excinfo:
|
||||
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
|
||||
assert excinfo.value.cmd == cmd
|
||||
assert excinfo.value.returncode == 1
|
||||
assert excinfo.value.stderr == b"test\n"
|
||||
assert excinfo.value.stdout is None
|
||||
|
||||
result = await run_process(
|
||||
cmd,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
check=False,
|
||||
)
|
||||
assert result.args == cmd
|
||||
assert result.stdout == b""
|
||||
assert result.stderr == b"test\n"
|
||||
assert result.returncode == 1
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_run_with_broken_pipe() -> None:
|
||||
result = await run_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.close()"],
|
||||
stdin=b"x" * 131072,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout is result.stderr is None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdio is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is None
|
||||
await proc.stdio.send_all(b"1234")
|
||||
await proc.stdio.send_eof()
|
||||
|
||||
output = []
|
||||
while True:
|
||||
chunk = await proc.stdio.receive_some(16)
|
||||
if chunk == b"":
|
||||
break
|
||||
output.append(chunk)
|
||||
assert b"".join(output) == b"12344321"
|
||||
assert proc.returncode == 0
|
||||
|
||||
# equivalent test with run_process()
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=b"1234",
|
||||
capture_stdout=True,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b"12344321"
|
||||
assert result.stderr is None
|
||||
|
||||
# this one hits the branch where stderr=STDOUT but stdout
|
||||
# is not redirected
|
||||
async with background_process(
|
||||
CAT,
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.aclose()
|
||||
await proc.wait()
|
||||
assert proc.returncode == 0
|
||||
|
||||
if posix:
|
||||
try:
|
||||
r, w = os.pipe()
|
||||
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=w,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
os.close(w)
|
||||
assert proc.stdio is None
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.send_all(b"1234")
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.wait() == 0
|
||||
assert os.read(r, 4096) == b"12344321"
|
||||
assert os.read(r, 4096) == b""
|
||||
finally:
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_errors() -> None:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
# call-overload on unix, call-arg on windows
|
||||
await open_process(["ls"], encoding="utf-8") # type: ignore
|
||||
assert "unbuffered byte streams" in str(excinfo.value)
|
||||
assert "the 'encoding' option is not supported" in str(excinfo.value)
|
||||
|
||||
if posix:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process(["ls"], shell=True)
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process("ls", shell=False)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_signals(background_process: BackgroundProcessType) -> None:
|
||||
async def test_one_signal(
|
||||
send_it: Callable[[Process], None],
|
||||
signum: signal.Signals | None,
|
||||
) -> None:
|
||||
with move_on_after(1.0) as scope:
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
send_it(proc)
|
||||
await proc.wait()
|
||||
assert not scope.cancelled_caught
|
||||
if posix:
|
||||
assert signum is not None
|
||||
assert proc.returncode == -signum
|
||||
else:
|
||||
assert proc.returncode != 0
|
||||
|
||||
await test_one_signal(Process.kill, SIGKILL)
|
||||
await test_one_signal(Process.terminate, SIGTERM)
|
||||
# Test that we can send arbitrary signals.
|
||||
#
|
||||
# We used to use SIGINT here, but it turns out that the Python interpreter
|
||||
# has race conditions that can cause it to explode in weird ways if it
|
||||
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
|
||||
# to terminate the target process, and Python doesn't try to do anything
|
||||
# clever to handle it.
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="POSIX specific")
|
||||
@background_process_param
|
||||
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
|
||||
if TYPE_CHECKING and sys.platform == "win32":
|
||||
return
|
||||
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
try:
|
||||
# With SIGCHLD disabled, the wait() syscall will wait for the
|
||||
# process to exit but then fail with ECHILD. Make sure we
|
||||
# support this case as the stdlib subprocess module does.
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 1.0
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert proc.returncode == 0 # exit status unknowable, so...
|
||||
finally:
|
||||
signal.signal(signal.SIGCHLD, old_sigchld)
|
||||
|
||||
|
||||
@slow
|
||||
def test_waitid_eintr() -> None:
|
||||
# This only matters on PyPy (where we're coding EINTR handling
|
||||
# ourselves) but the test works on all waitid platforms.
|
||||
from .._subprocess_platform import wait_child_exiting
|
||||
|
||||
if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
|
||||
return
|
||||
|
||||
if not wait_child_exiting.__module__.endswith("waitid"):
|
||||
pytest.skip("waitid only")
|
||||
|
||||
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
|
||||
# this import is still checked on win32&darwin and raises [attr-defined].
|
||||
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
|
||||
from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
|
||||
sync_wait_reapable,
|
||||
)
|
||||
|
||||
got_alarm = False
|
||||
sleeper = subprocess.Popen(["sleep", "3600"])
|
||||
|
||||
def on_alarm(sig: int, frame: FrameType | None) -> None:
|
||||
nonlocal got_alarm
|
||||
got_alarm = True
|
||||
sleeper.kill()
|
||||
|
||||
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
|
||||
try:
|
||||
signal.alarm(1)
|
||||
sync_wait_reapable(sleeper.pid)
|
||||
assert sleeper.wait(timeout=1) == -9
|
||||
finally:
|
||||
if sleeper.returncode is None: # pragma: no cover
|
||||
# We only get here if something fails in the above;
|
||||
# if the test passes, wait() will reap the process
|
||||
sleeper.kill()
|
||||
sleeper.wait()
|
||||
signal.signal(signal.SIGALRM, old_sigalrm)
|
||||
|
||||
|
||||
async def test_custom_deliver_cancel() -> None:
|
||||
custom_deliver_cancel_called = False
|
||||
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
nonlocal custom_deliver_cancel_called
|
||||
custom_deliver_cancel_called = True
|
||||
proc.terminate()
|
||||
# Make sure this does get cancelled when the process exits, and that
|
||||
# the process really exited.
|
||||
try:
|
||||
await sleep_forever()
|
||||
finally:
|
||||
assert proc.returncode is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert custom_deliver_cancel_called
|
||||
|
||||
|
||||
def test_bad_deliver_cancel() -> None:
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
proc.terminate()
|
||||
raise ValueError("foo")
|
||||
|
||||
async def do_stuff() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# double wrap from our nursery + the internal nursery
|
||||
with RaisesGroup(RaisesGroup(Matcher(ValueError, "^foo$"))):
|
||||
_core.run(do_stuff, strict_exception_groups=True)
|
||||
|
||||
|
||||
async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_terminate = Process.terminate
|
||||
|
||||
def broken_terminate(self: Process) -> NoReturn:
|
||||
original_terminate(self)
|
||||
raise OSError("whoops")
|
||||
|
||||
monkeypatch.setattr(Process, "terminate", broken_terminate)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="posix only")
|
||||
async def test_warn_on_cancel_SIGKILL_escalation(
|
||||
autojump_clock: MockClock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(Process, "terminate", lambda *args: None)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
# the background_process_param exercises a lot of run_process cases, but it uses
|
||||
# check=False, so lets have a test that uses check=True as well
|
||||
async def test_run_process_background_fail() -> None:
|
||||
with RaisesGroup(subprocess.CalledProcessError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
proc: Process = await nursery.start(run_process, EXIT_FALSE)
|
||||
assert proc.returncode == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not SyncPath("/dev/fd").exists(),
|
||||
reason="requires a way to iterate through open files",
|
||||
)
|
||||
async def test_for_leaking_fds() -> None:
|
||||
gc.collect() # address possible flakiness on PyPy
|
||||
|
||||
starting_fds = set(SyncPath("/dev/fd").iterdir())
|
||||
await run_process(EXIT_TRUE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
await run_process(EXIT_FALSE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await run_process(["/dev/fd/0"])
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
|
||||
async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# There's probably less extreme ways of triggering errors inside the nursery
|
||||
# in run_process.
|
||||
async def very_broken_open(*args: object, **kwargs: object) -> str:
|
||||
return "oops"
|
||||
|
||||
monkeypatch.setattr(trio._subprocess, "open_process", very_broken_open)
|
||||
with RaisesGroup(AttributeError, AttributeError):
|
||||
await run_process(EXIT_TRUE, capture_stdout=True)
|
||||
|
||||
|
||||
# regression test for #2209
|
||||
async def test_subprocess_pidfd_unnotified() -> None:
|
||||
noticed_exit = None
|
||||
|
||||
async def wait_and_tell(proc: Process) -> None:
|
||||
nonlocal noticed_exit
|
||||
noticed_exit = Event()
|
||||
await proc.wait()
|
||||
noticed_exit.set()
|
||||
|
||||
proc = await open_process(SLEEP(9999))
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_and_tell, proc)
|
||||
await wait_all_tasks_blocked()
|
||||
assert isinstance(noticed_exit, Event)
|
||||
proc.terminate()
|
||||
# without giving trio a chance to do so,
|
||||
with assert_no_checkpoints():
|
||||
# wait until the process has actually exited;
|
||||
proc._proc.wait()
|
||||
# force a call to poll (that closes the pidfd on linux)
|
||||
proc.poll()
|
||||
with move_on_after(5):
|
||||
# Some platforms use threads to wait for exit, so it might take a bit
|
||||
# for everything to notice
|
||||
await noticed_exit.wait()
|
||||
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
|
||||
@@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
|
||||
from .._sync import *
|
||||
from .._timeouts import sleep_forever
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
async def test_Event() -> None:
|
||||
e = Event()
|
||||
assert not e.is_set()
|
||||
assert e.statistics().tasks_waiting == 0
|
||||
|
||||
e.set()
|
||||
assert e.is_set()
|
||||
with assert_checkpoints():
|
||||
await e.wait()
|
||||
|
||||
e = Event()
|
||||
|
||||
record = []
|
||||
|
||||
async def child() -> None:
|
||||
record.append("sleeping")
|
||||
await e.wait()
|
||||
record.append("woken")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
nursery.start_soon(child)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping"]
|
||||
assert e.statistics().tasks_waiting == 2
|
||||
e.set()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping", "woken", "woken"]
|
||||
|
||||
|
||||
async def test_CapacityLimiter() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
CapacityLimiter(1.0)
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
CapacityLimiter(-1)
|
||||
c = CapacityLimiter(2)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == 2
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == 2
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == 1
|
||||
|
||||
stats = c.statistics()
|
||||
assert stats.borrowed_tokens == 1
|
||||
assert stats.total_tokens == 2
|
||||
assert stats.borrowers == [_core.current_task()]
|
||||
assert stats.tasks_waiting == 0
|
||||
|
||||
# Can't re-acquire when we already have it
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
await c.acquire()
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
# We can acquire on behalf of someone else though
|
||||
with assert_checkpoints():
|
||||
await c.acquire_on_behalf_of("someone")
|
||||
|
||||
# But then we've run out of capacity
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_on_behalf_of_nowait("third party")
|
||||
|
||||
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
|
||||
|
||||
# Until we release one
|
||||
c.release_on_behalf_of(_core.current_task())
|
||||
assert c.statistics().borrowers == ["someone"]
|
||||
|
||||
c.release_on_behalf_of("someone")
|
||||
assert c.borrowed_tokens == 0
|
||||
with assert_checkpoints():
|
||||
async with c:
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
await c.acquire_on_behalf_of("value 1")
|
||||
await c.acquire_on_behalf_of("value 2")
|
||||
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.borrowed_tokens == 2
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of("value 2")
|
||||
# Fairness:
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_nowait()
|
||||
|
||||
c.release_on_behalf_of("value 3")
|
||||
c.release_on_behalf_of("value 1")
|
||||
|
||||
|
||||
async def test_CapacityLimiter_inf() -> None:
|
||||
from math import inf
|
||||
|
||||
c = CapacityLimiter(inf)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == inf
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == inf
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == inf
|
||||
|
||||
|
||||
async def test_CapacityLimiter_change_total_tokens() -> None:
|
||||
c = CapacityLimiter(2)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
c.total_tokens = 1.0
|
||||
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
c.total_tokens = 0
|
||||
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
c.total_tokens = -10
|
||||
|
||||
assert c.total_tokens == 2
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(c.acquire_on_behalf_of, i)
|
||||
await wait_all_tasks_blocked()
|
||||
assert set(c.statistics().borrowers) == {0, 1}
|
||||
assert c.statistics().tasks_waiting == 3
|
||||
c.total_tokens += 2
|
||||
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.total_tokens -= 3
|
||||
assert c.borrowed_tokens == 4
|
||||
assert c.total_tokens == 1
|
||||
c.release_on_behalf_of(0)
|
||||
c.release_on_behalf_of(1)
|
||||
c.release_on_behalf_of(2)
|
||||
assert set(c.statistics().borrowers) == {3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of(3)
|
||||
assert set(c.statistics().borrowers) == {4}
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
|
||||
|
||||
# regression test for issue #548
|
||||
async def test_CapacityLimiter_memleak_548() -> None:
|
||||
limiter = CapacityLimiter(total_tokens=1)
|
||||
await limiter.acquire()
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(limiter.acquire)
|
||||
await wait_all_tasks_blocked() # give it a chance to run the task
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
|
||||
# leak memory all the while the limiter is active
|
||||
assert len(limiter._pending_borrowers) == 0
|
||||
|
||||
|
||||
async def test_Semaphore() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="^initial value must be >= 0$"):
|
||||
Semaphore(-1)
|
||||
s = Semaphore(1)
|
||||
repr(s) # smoke test
|
||||
assert s.value == 1
|
||||
assert s.max_value is None
|
||||
s.release()
|
||||
assert s.value == 2
|
||||
assert s.statistics().tasks_waiting == 0
|
||||
s.acquire_nowait()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
await s.acquire()
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
|
||||
s.release()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
async with s:
|
||||
assert s.value == 0
|
||||
assert s.value == 1
|
||||
s.acquire_nowait()
|
||||
|
||||
record = []
|
||||
|
||||
async def do_acquire(s: Semaphore) -> None:
|
||||
record.append("started")
|
||||
await s.acquire()
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_acquire, s)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
assert s.value == 0
|
||||
s.release()
|
||||
# Fairness:
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
assert record == ["started", "finished"]
|
||||
|
||||
|
||||
async def test_Semaphore_bounded() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1, max_value=1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="^max_values must be >= initial_value$"):
|
||||
Semaphore(2, max_value=1)
|
||||
bs = Semaphore(1, max_value=1)
|
||||
assert bs.max_value == 1
|
||||
repr(bs) # smoke test
|
||||
with pytest.raises(ValueError, match="^semaphore released too many times$"):
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
bs.acquire_nowait()
|
||||
assert bs.value == 0
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
|
||||
async def test_Lock_and_StrictFIFOLock(
|
||||
lockcls: type[Lock | StrictFIFOLock],
|
||||
) -> None:
|
||||
l = lockcls() # noqa
|
||||
assert not l.locked()
|
||||
|
||||
# make sure locks can be weakref'ed (gh-331)
|
||||
r = weakref.ref(l)
|
||||
assert r() is l
|
||||
|
||||
repr(l) # smoke test
|
||||
# make sure repr uses the right name for subclasses
|
||||
assert lockcls.__name__ in repr(l)
|
||||
with assert_checkpoints():
|
||||
async with l:
|
||||
assert l.locked()
|
||||
repr(l) # smoke test (repr branches on locked/unlocked)
|
||||
assert not l.locked()
|
||||
l.acquire_nowait()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
with assert_checkpoints():
|
||||
await l.acquire()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
|
||||
l.acquire_nowait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we already own the lock
|
||||
l.acquire_nowait()
|
||||
l.release()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we don't own the lock
|
||||
l.release()
|
||||
|
||||
holder_task = None
|
||||
|
||||
async def holder() -> None:
|
||||
nonlocal holder_task
|
||||
holder_task = _core.current_task()
|
||||
async with l:
|
||||
await sleep_forever()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
assert not l.locked()
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
assert l.locked()
|
||||
# WouldBlock if someone else holds the lock
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
l.acquire_nowait()
|
||||
# Can't release a lock someone else holds
|
||||
with pytest.raises(RuntimeError):
|
||||
l.release()
|
||||
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.locked
|
||||
assert statistics.owner is holder_task
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
statistics = l.statistics()
|
||||
assert not statistics.locked
|
||||
assert statistics.owner is None
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
|
||||
async def test_Condition() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Condition(Semaphore(1)) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
Condition(StrictFIFOLock) # type: ignore[arg-type]
|
||||
l = Lock() # noqa
|
||||
c = Condition(l)
|
||||
assert not l.locked()
|
||||
assert not c.locked()
|
||||
with assert_checkpoints():
|
||||
await c.acquire()
|
||||
assert l.locked()
|
||||
assert c.locked()
|
||||
|
||||
c = Condition()
|
||||
assert not c.locked()
|
||||
c.acquire_nowait()
|
||||
assert c.locked()
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
c.release()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't wait without holding the lock
|
||||
await c.wait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify_all()
|
||||
|
||||
finished_waiters = set()
|
||||
|
||||
async def waiter(i: int) -> None:
|
||||
async with c:
|
||||
await c.wait()
|
||||
finished_waiters.add(i)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify()
|
||||
assert c.locked()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0}
|
||||
async with c:
|
||||
c.notify_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1, 2}
|
||||
|
||||
finished_waiters = set()
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify(2)
|
||||
statistics = c.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
assert statistics.lock_statistics.tasks_waiting == 2
|
||||
# exiting the context manager hands off the lock to the first task
|
||||
assert c.statistics().lock_statistics.tasks_waiting == 1
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1}
|
||||
|
||||
async with c:
|
||||
c.notify_all()
|
||||
|
||||
# After being cancelled still hold the lock (!)
|
||||
# (Note that c.__aexit__ checks that we hold the lock as well)
|
||||
with _core.CancelScope() as scope:
|
||||
async with c:
|
||||
scope.cancel()
|
||||
try:
|
||||
await c.wait()
|
||||
finally:
|
||||
assert c.locked()
|
||||
|
||||
|
||||
from .._channel import open_memory_channel
|
||||
from .._sync import AsyncContextManagerMixin
|
||||
|
||||
# Three ways of implementing a Lock in terms of a channel. Used to let us put
|
||||
# the channel through the generic lock tests.
|
||||
|
||||
|
||||
class ChannelLock1(AsyncContextManagerMixin):
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.s, self.r = open_memory_channel[None](capacity)
|
||||
for _ in range(capacity - 1):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.s.send(None)
|
||||
|
||||
def release(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
|
||||
class ChannelLock2(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](10)
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.r.receive()
|
||||
|
||||
def release(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
|
||||
class ChannelLock3(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](0)
|
||||
# self.acquired is true when one task acquires the lock and
|
||||
# only becomes false when it's released and no tasks are
|
||||
# waiting to acquire.
|
||||
self.acquired = False
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
assert not self.acquired
|
||||
self.acquired = True
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.acquired:
|
||||
await self.s.send(None)
|
||||
else:
|
||||
self.acquired = True
|
||||
await _core.checkpoint()
|
||||
|
||||
def release(self) -> None:
|
||||
try:
|
||||
self.r.receive_nowait()
|
||||
except _core.WouldBlock:
|
||||
assert self.acquired
|
||||
self.acquired = False
|
||||
|
||||
|
||||
lock_factories = [
|
||||
lambda: CapacityLimiter(1),
|
||||
lambda: Semaphore(1),
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
lambda: ChannelLock1(10),
|
||||
lambda: ChannelLock1(1),
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
lock_factory_names = [
|
||||
"CapacityLimiter(1)",
|
||||
"Semaphore(1)",
|
||||
"Lock",
|
||||
"StrictFIFOLock",
|
||||
"ChannelLock1(10)",
|
||||
"ChannelLock1(1)",
|
||||
"ChannelLock2",
|
||||
"ChannelLock3",
|
||||
]
|
||||
|
||||
generic_lock_test = pytest.mark.parametrize(
|
||||
"lock_factory",
|
||||
lock_factories,
|
||||
ids=lock_factory_names,
|
||||
)
|
||||
|
||||
LockLike: TypeAlias = Union[
|
||||
CapacityLimiter,
|
||||
Semaphore,
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
ChannelLock1,
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
LockFactory: TypeAlias = Callable[[], LockLike]
|
||||
|
||||
|
||||
# Spawn a bunch of workers that take a lock and then yield; make sure that
|
||||
# only one worker is ever in the critical section at a time.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
|
||||
LOOPS = 10
|
||||
WORKERS = 5
|
||||
in_critical_section = False
|
||||
acquires = 0
|
||||
|
||||
async def worker(lock_like: LockLike) -> None:
|
||||
nonlocal in_critical_section, acquires
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
acquires += 1
|
||||
assert not in_critical_section
|
||||
in_critical_section = True
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert in_critical_section
|
||||
in_critical_section = False
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like = lock_factory()
|
||||
for _ in range(WORKERS):
|
||||
nursery.start_soon(worker, lock_like)
|
||||
assert not in_critical_section
|
||||
assert acquires == LOOPS * WORKERS
|
||||
|
||||
|
||||
# Several workers queue on the same lock; make sure they each get it, in
|
||||
# order.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
|
||||
initial_order = []
|
||||
record = []
|
||||
LOOPS = 5
|
||||
|
||||
async def loopy(name: int, lock_like: LockLike) -> None:
|
||||
# Record the order each task was initially scheduled in
|
||||
initial_order.append(name)
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
record.append(name)
|
||||
|
||||
lock_like = lock_factory()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(loopy, 1, lock_like)
|
||||
nursery.start_soon(loopy, 2, lock_like)
|
||||
nursery.start_soon(loopy, 3, lock_like)
|
||||
# The first three could be in any order due to scheduling randomness,
|
||||
# but after that they should repeat in the same order
|
||||
for i in range(LOOPS):
|
||||
assert record[3 * i : 3 * (i + 1)] == initial_order
|
||||
|
||||
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_acquire_nowait_blocks_acquire(
|
||||
lock_factory: LockFactory,
|
||||
) -> None:
|
||||
lock_like = lock_factory()
|
||||
|
||||
record = []
|
||||
|
||||
async def lock_taker() -> None:
|
||||
record.append("started")
|
||||
async with lock_like:
|
||||
pass
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like.acquire_nowait()
|
||||
nursery.start_soon(lock_taker)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
lock_like.release()
|
||||
|
||||
|
||||
async def test_lock_acquire_unowned_lock() -> None:
|
||||
"""Test that trying to acquire a lock whose owner has exited raises an error.
|
||||
see https://github.com/python-trio/trio/issues/3035
|
||||
"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
owner_str = re.escape(str(lock._lot.broken_by[0]))
|
||||
with pytest.raises(
|
||||
trio.BrokenResourceError,
|
||||
match=f"^Owner of this lock exited without releasing: {owner_str}$",
|
||||
):
|
||||
await lock.acquire()
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_multiple_acquire() -> None:
|
||||
"""Test for error if awaiting on a lock whose owner exits without releasing.
|
||||
see https://github.com/python-trio/trio/issues/3035"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
with RaisesGroup(
|
||||
Matcher(
|
||||
trio.BrokenResourceError,
|
||||
match="^Owner of this lock exited without releasing: ",
|
||||
),
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
nursery.start_soon(lock.acquire)
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_handover() -> None:
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
child_task: Task | None = None
|
||||
lock = trio.Lock()
|
||||
|
||||
# this task acquires the lock
|
||||
lock.acquire_nowait()
|
||||
assert GLOBAL_PARKING_LOT_BREAKER == {
|
||||
_core.current_task(): [
|
||||
lock._lot,
|
||||
],
|
||||
}
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# hand over the lock to the child task
|
||||
lock.release()
|
||||
|
||||
# check values, and get the identifier out of the dict for later check
|
||||
assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
|
||||
child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
|
||||
assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]
|
||||
|
||||
assert lock._lot.broken_by == [child_task]
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
@@ -0,0 +1,684 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# XX this should get broken up, like testing.py did
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio.testing import RaisesGroup
|
||||
|
||||
from .. import _core, sleep, socket as tsocket
|
||||
from .._core._tests.tutil import can_bind_ipv6
|
||||
from .._highlevel_generic import StapledStream, aclose_forcefully
|
||||
from .._highlevel_socket import SocketListener
|
||||
from ..testing import *
|
||||
from ..testing._check_streams import _assert_raises
|
||||
from ..testing._memory_streams import _UnboundedByteQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio import Nursery
|
||||
from trio.abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked() -> None:
|
||||
record = []
|
||||
|
||||
async def busy_bee() -> None:
|
||||
for _ in range(10):
|
||||
await _core.checkpoint()
|
||||
record.append("busy bee exhausted")
|
||||
|
||||
async def waiting_for_bee_to_leave() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("quiet at last!")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(busy_bee)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
|
||||
# check cancellation
|
||||
record = []
|
||||
|
||||
async def cancelled_while_waiting() -> None:
|
||||
try:
|
||||
await wait_all_tasks_blocked()
|
||||
except _core.Cancelled:
|
||||
record.append("ok")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancelled_while_waiting)
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["ok"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
|
||||
record = []
|
||||
|
||||
async def timeout_task() -> None:
|
||||
record.append("tt start")
|
||||
await sleep(5)
|
||||
record.append("tt finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(timeout_task)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start"]
|
||||
mock_clock.jump(10)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start", "tt finished"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_cushion() -> None:
|
||||
record = []
|
||||
|
||||
async def blink() -> None:
|
||||
record.append("blink start")
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
record.append("blink end")
|
||||
|
||||
async def wait_no_cushion() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("wait_no_cushion end")
|
||||
|
||||
async def wait_small_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.02)
|
||||
record.append("wait_small_cushion end")
|
||||
|
||||
async def wait_big_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.03)
|
||||
record.append("wait_big_cushion end")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(blink)
|
||||
nursery.start_soon(wait_no_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_big_cushion)
|
||||
|
||||
assert record == [
|
||||
"blink start",
|
||||
"wait_no_cushion end",
|
||||
"blink end",
|
||||
"wait_small_cushion end",
|
||||
"wait_small_cushion end",
|
||||
"wait_big_cushion end",
|
||||
]
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that's not a checkpoint.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# But both together count as a checkpoint
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_no_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that doesn't make *either* version of assert_{no_,}yields happy.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# And both together also count as a checkpoint
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_Sequencer() -> None:
|
||||
record = []
|
||||
|
||||
def t(val: object) -> None:
|
||||
print(val)
|
||||
record.append(val)
|
||||
|
||||
async def f1(seq: Sequencer) -> None:
|
||||
async with seq(1):
|
||||
t(("f1", 1))
|
||||
async with seq(3):
|
||||
t(("f1", 3))
|
||||
async with seq(4):
|
||||
t(("f1", 4))
|
||||
|
||||
async def f2(seq: Sequencer) -> None:
|
||||
async with seq(0):
|
||||
t(("f2", 0))
|
||||
async with seq(2):
|
||||
t(("f2", 2))
|
||||
|
||||
seq = Sequencer()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(f1, seq)
|
||||
nursery.start_soon(f2, seq)
|
||||
async with seq(5):
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
|
||||
|
||||
seq = Sequencer()
|
||||
# Catches us if we try to reuse a sequence point:
|
||||
async with seq(0):
|
||||
pass
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_Sequencer_cancel() -> None:
|
||||
# Killing a blocked task makes everything blow up
|
||||
record = []
|
||||
seq = Sequencer()
|
||||
|
||||
async def child(i: int) -> None:
|
||||
with _core.CancelScope() as scope:
|
||||
if i == 1:
|
||||
scope.cancel()
|
||||
try:
|
||||
async with seq(i):
|
||||
pass # pragma: no cover
|
||||
except RuntimeError:
|
||||
record.append(f"seq({i}) RuntimeError")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, 1)
|
||||
nursery.start_soon(child, 2)
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
|
||||
|
||||
# Late arrivals also get errors
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(3):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
################################################################
|
||||
async def test__assert_raises() -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
with _assert_raises(RuntimeError):
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with _assert_raises(RuntimeError):
|
||||
"foo" + 1 # type: ignore[operator] # noqa: B018 # "useless expression"
|
||||
|
||||
with _assert_raises(RuntimeError):
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
# This is a private implementation detail, but it's complex enough to be worth
|
||||
# testing directly
|
||||
async def test__UnboundeByteQueue() -> None:
|
||||
ubq = _UnboundedByteQueue()
|
||||
|
||||
ubq.put(b"123")
|
||||
ubq.put(b"456")
|
||||
assert ubq.get_nowait(1) == b"1"
|
||||
assert ubq.get_nowait(10) == b"23456"
|
||||
ubq.put(b"789")
|
||||
assert ubq.get_nowait() == b"789"
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait(10)
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ubq.put("string") # type: ignore[arg-type]
|
||||
|
||||
ubq.put(b"abc")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(10) == b"abc"
|
||||
ubq.put(b"def")
|
||||
ubq.put(b"ghi")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(1) == b"d"
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == b"efghi"
|
||||
|
||||
async def putter(data: bytes) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq.put(data)
|
||||
|
||||
async def getter(expect: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == expect
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"xyz")
|
||||
nursery.start_soon(putter, b"xyz")
|
||||
|
||||
# Two gets at the same time -> BusyResourceError
|
||||
with RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
|
||||
# Closing
|
||||
|
||||
ubq.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
ubq.put(b"---")
|
||||
|
||||
assert ubq.get_nowait(10) == b""
|
||||
assert ubq.get_nowait() == b""
|
||||
assert await ubq.get(10) == b""
|
||||
assert await ubq.get() == b""
|
||||
|
||||
# close is idempotent
|
||||
ubq.close()
|
||||
|
||||
# close wakes up blocked getters
|
||||
ubq2 = _UnboundedByteQueue()
|
||||
|
||||
async def closer() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq2.close()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"")
|
||||
nursery.start_soon(closer)
|
||||
|
||||
|
||||
async def test_MemorySendStream() -> None:
|
||||
mss = MemorySendStream()
|
||||
|
||||
async def do_send_all(data: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
await mss.send_all(data)
|
||||
|
||||
await do_send_all(b"123")
|
||||
assert mss.get_data_nowait(1) == b"1"
|
||||
assert mss.get_data_nowait() == b"23"
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.wait_send_all_might_not_block()
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait()
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait(10)
|
||||
|
||||
await do_send_all(b"456")
|
||||
with assert_checkpoints():
|
||||
assert await mss.get_data() == b"456"
|
||||
|
||||
# Call send_all twice at once; one should get BusyResourceError and one
|
||||
# should succeed. But we can't let the error propagate, because it might
|
||||
# cause the other to be cancelled before it can finish doing its thing,
|
||||
# and we don't know which one will get the error.
|
||||
resource_busy_count = 0
|
||||
|
||||
async def do_send_all_count_resourcebusy() -> None:
|
||||
nonlocal resource_busy_count
|
||||
try:
|
||||
await do_send_all(b"xxx")
|
||||
except _core.BusyResourceError:
|
||||
resource_busy_count += 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
|
||||
assert resource_busy_count == 1
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.aclose()
|
||||
|
||||
assert await mss.get_data() == b"xxx"
|
||||
assert await mss.get_data() == b""
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"---")
|
||||
|
||||
# hooks
|
||||
|
||||
assert mss.send_all_hook is None
|
||||
assert mss.wait_send_all_might_not_block_hook is None
|
||||
assert mss.close_hook is None
|
||||
|
||||
record = []
|
||||
|
||||
async def send_all_hook() -> None:
|
||||
# hook runs after send_all does its work (can pull data out)
|
||||
assert mss2.get_data_nowait() == b"abc"
|
||||
record.append("send_all_hook")
|
||||
|
||||
async def wait_send_all_might_not_block_hook() -> None:
|
||||
record.append("wait_send_all_might_not_block_hook")
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("close_hook")
|
||||
|
||||
mss2 = MemorySendStream(
|
||||
send_all_hook,
|
||||
wait_send_all_might_not_block_hook,
|
||||
close_hook,
|
||||
)
|
||||
|
||||
assert mss2.send_all_hook is send_all_hook
|
||||
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
|
||||
assert mss2.close_hook is close_hook
|
||||
|
||||
await mss2.send_all(b"abc")
|
||||
await mss2.wait_send_all_might_not_block()
|
||||
await aclose_forcefully(mss2)
|
||||
mss2.close()
|
||||
|
||||
assert record == [
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
"close_hook",
|
||||
"close_hook",
|
||||
]
|
||||
|
||||
|
||||
async def test_MemoryReceiveStream() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
async def do_receive_some(max_bytes: int | None) -> bytes:
|
||||
with assert_checkpoints():
|
||||
return await mrs.receive_some(max_bytes)
|
||||
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(1) == b"a"
|
||||
assert await do_receive_some(10) == b"bc"
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(None) == b"abc"
|
||||
|
||||
with RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
|
||||
assert mrs.receive_some_hook is None
|
||||
|
||||
mrs.put_data(b"def")
|
||||
mrs.put_eof()
|
||||
mrs.put_eof()
|
||||
|
||||
assert await do_receive_some(10) == b"def"
|
||||
assert await do_receive_some(10) == b""
|
||||
assert await do_receive_some(10) == b""
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"---")
|
||||
|
||||
async def receive_some_hook() -> None:
|
||||
mrs2.put_data(b"xxx")
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("closed")
|
||||
|
||||
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
|
||||
assert mrs2.receive_some_hook is receive_some_hook
|
||||
assert mrs2.close_hook is close_hook
|
||||
|
||||
mrs2.put_data(b"yyy")
|
||||
assert await mrs2.receive_some(10) == b"yyyxxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
|
||||
mrs2.put_data(b"zzz")
|
||||
mrs2.receive_some_hook = None
|
||||
assert await mrs2.receive_some(10) == b"zzz"
|
||||
|
||||
mrs2.put_data(b"lost on close")
|
||||
with assert_checkpoints():
|
||||
await mrs2.aclose()
|
||||
assert record == ["closed"]
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_MemoryRecvStream_closing() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
# close with no pending data
|
||||
mrs.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
assert await mrs.receive_some(10) == b""
|
||||
# repeated closes ok
|
||||
mrs.close()
|
||||
# put_data now fails
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"123")
|
||||
|
||||
mrs2 = MemoryReceiveStream()
|
||||
# close with pending data
|
||||
mrs2.put_data(b"xyz")
|
||||
mrs2.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_memory_stream_pump() -> None:
|
||||
mss = MemorySendStream()
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
# no-op if no data present
|
||||
memory_stream_pump(mss, mrs)
|
||||
|
||||
await mss.send_all(b"123")
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b"123"
|
||||
|
||||
await mss.send_all(b"456")
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"4"
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert not memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"56"
|
||||
|
||||
mss.close()
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b""
|
||||
|
||||
|
||||
async def test_memory_stream_one_way_pair() -> None:
|
||||
s, r = memory_stream_one_way_pair()
|
||||
assert s.send_all_hook is not None
|
||||
assert s.wait_send_all_might_not_block_hook is None
|
||||
assert s.close_hook is not None
|
||||
assert r.receive_some_hook is None
|
||||
await s.send_all(b"123")
|
||||
assert await r.receive_some(10) == b"123"
|
||||
|
||||
async def receiver(expected: bytes) -> None:
|
||||
assert await r.receive_some(10) == expected
|
||||
|
||||
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"abc")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.send_all(b"abc")
|
||||
|
||||
# And this fails if we don't pump from close_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
old = s.send_all_hook
|
||||
s.send_all_hook = None
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def cancel_after_idle(nursery: Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def check_for_cancel() -> None:
|
||||
with pytest.raises(_core.Cancelled):
|
||||
# This should block forever... or until cancelled. Even though we
|
||||
# sent some data on the send stream.
|
||||
await r.receive_some(10)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancel_after_idle, nursery)
|
||||
nursery.start_soon(check_for_cancel)
|
||||
|
||||
s.send_all_hook = old
|
||||
await s.send_all(b"789")
|
||||
assert await r.receive_some(10) == b"456789"
|
||||
|
||||
|
||||
async def test_memory_stream_pair() -> None:
|
||||
a, b = memory_stream_pair()
|
||||
await a.send_all(b"123")
|
||||
await b.send_all(b"abc")
|
||||
assert await b.receive_some(10) == b"123"
|
||||
assert await a.receive_some(10) == b"abc"
|
||||
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
async def sender() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
await b.send_all(b"xyz")
|
||||
|
||||
async def receiver() -> None:
|
||||
assert await a.receive_some(10) == b"xyz"
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver)
|
||||
nursery.start_soon(sender)
|
||||
|
||||
|
||||
async def test_memory_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
|
||||
return memory_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, None)
|
||||
|
||||
async def half_closeable_stream_maker() -> tuple[
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
]:
|
||||
return memory_stream_pair()
|
||||
|
||||
await check_half_closeable_stream(half_closeable_stream_maker, None)
|
||||
|
||||
|
||||
async def test_lockstep_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
|
||||
return lockstep_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
|
||||
|
||||
async def two_way_stream_maker() -> tuple[
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
]:
|
||||
return lockstep_stream_pair()
|
||||
|
||||
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
|
||||
|
||||
|
||||
async def test_open_stream_to_socket_listener() -> None:
|
||||
async def check(listener: SocketListener) -> None:
|
||||
async with listener:
|
||||
client_stream = await open_stream_to_socket_listener(listener)
|
||||
async with client_stream:
|
||||
server_stream = await listener.accept()
|
||||
async with server_stream:
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_stream.receive_some(1) == b"x"
|
||||
|
||||
# Listener bound to localhost
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# Listener bound to IPv4 wildcard (needs special handling)
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("0.0.0.0", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# true on all CI systems
|
||||
if can_bind_ipv6: # pragma: no branch
|
||||
# Listener bound to IPv6 wildcard (needs special handling)
|
||||
sock = tsocket.socket(family=tsocket.AF_INET6)
|
||||
await sock.bind(("::", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
if hasattr(tsocket, "AF_UNIX"):
|
||||
# Listener bound to Unix-domain socket
|
||||
sock = tsocket.socket(family=tsocket.AF_UNIX)
|
||||
# can't use pytest's tmpdir; if we try then macOS says "OSError:
|
||||
# AF_UNIX path too long"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/sock"
|
||||
await sock.bind(path)
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
|
||||
def test_trio_test() -> None:
|
||||
async def busy_kitchen(
|
||||
*,
|
||||
mock_clock: object,
|
||||
autojump_clock: object,
|
||||
) -> None: ... # pragma: no cover
|
||||
|
||||
with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"):
|
||||
trio_test(busy_kitchen)(
|
||||
mock_clock=MockClock(),
|
||||
autojump_clock=MockClock(autojump_threshold=0),
|
||||
)
|
||||
@@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
|
||||
def wrap_escape(s: str) -> str:
|
||||
return "^" + re.escape(s) + "$"
|
||||
|
||||
|
||||
def test_raises_group() -> None:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=wrap_escape(
|
||||
f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.',
|
||||
),
|
||||
):
|
||||
RaisesGroup(TypeError())
|
||||
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
with RaisesGroup(SyntaxError):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (SyntaxError(),))
|
||||
|
||||
# multiple exceptions
|
||||
with RaisesGroup(ValueError, SyntaxError):
|
||||
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
|
||||
|
||||
# order doesn't matter
|
||||
with RaisesGroup(SyntaxError, ValueError):
|
||||
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
|
||||
|
||||
# nested exceptions
|
||||
with RaisesGroup(RaisesGroup(ValueError)):
|
||||
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
|
||||
|
||||
with RaisesGroup(
|
||||
SyntaxError,
|
||||
RaisesGroup(ValueError),
|
||||
RaisesGroup(RuntimeError),
|
||||
):
|
||||
raise ExceptionGroup(
|
||||
"foo",
|
||||
(
|
||||
SyntaxError(),
|
||||
ExceptionGroup("bar", (ValueError(),)),
|
||||
ExceptionGroup("", (RuntimeError(),)),
|
||||
),
|
||||
)
|
||||
|
||||
# will error if there's excess exceptions
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (ValueError(), ValueError()))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (RuntimeError(), ValueError()))
|
||||
|
||||
# will error if there's missing exceptions
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, ValueError):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, SyntaxError):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_flatten_subgroups() -> None:
|
||||
# loose semantics, as with expect*
|
||||
with RaisesGroup(ValueError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
|
||||
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),))
|
||||
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()])
|
||||
|
||||
# mixed loose is possible if you want it to be at least N deep
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup(
|
||||
"",
|
||||
(ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),),
|
||||
)
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
# but not the other way around
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You cannot specify a nested structure inside a RaisesGroup with",
|
||||
):
|
||||
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def test_catch_unwrapped_exceptions() -> None:
|
||||
# Catches lone exceptions with strict=False
|
||||
# just as except* would
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise ValueError
|
||||
|
||||
# expecting multiple unwrapped exceptions is not possible
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You cannot specify multiple exceptions with",
|
||||
):
|
||||
RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload]
|
||||
# if users want one of several exception types they need to use a Matcher
|
||||
# (which the error message suggests)
|
||||
with RaisesGroup(
|
||||
Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))),
|
||||
allow_unwrapped=True,
|
||||
):
|
||||
raise ValueError
|
||||
|
||||
# Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error.
|
||||
with pytest.raises(ValueError, match="has no effect when expecting"):
|
||||
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload]
|
||||
|
||||
# But it *can* be used to check for nesting level +- 1 if they move it to
|
||||
# the nested RaisesGroup. Users should probably use `Matcher`s instead though.
|
||||
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
|
||||
raise ExceptionGroup("", [ValueError()])
|
||||
|
||||
# with allow_unwrapped=False (default) it will not be caught
|
||||
with pytest.raises(ValueError, match="^value error text$"):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ValueError("value error text")
|
||||
|
||||
# allow_unwrapped on it's own won't match against nested groups
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
|
||||
# for that you need both allow_unwrapped and flatten_subgroups
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
|
||||
# code coverage
|
||||
with pytest.raises(TypeError):
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise TypeError
|
||||
|
||||
|
||||
def test_match() -> None:
|
||||
# supports match string
|
||||
with RaisesGroup(ValueError, match="bar"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
# now also works with ^$
|
||||
with RaisesGroup(ValueError, match="^bar$"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
# it also includes notes
|
||||
with RaisesGroup(ValueError, match="my note"):
|
||||
e = ExceptionGroup("bar", (ValueError(),))
|
||||
e.add_note("my note")
|
||||
raise e
|
||||
|
||||
# and technically you can match it all with ^$
|
||||
# but you're probably better off using a Matcher at that point
|
||||
with RaisesGroup(ValueError, match="^bar\nmy note$"):
|
||||
e = ExceptionGroup("bar", (ValueError(),))
|
||||
e.add_note("my note")
|
||||
raise e
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, match="foo"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
|
||||
def test_check() -> None:
|
||||
exc = ExceptionGroup("", (ValueError(),))
|
||||
with RaisesGroup(ValueError, check=lambda x: x is exc):
|
||||
raise exc
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, check=lambda x: x is exc):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_unwrapped_match_check() -> None:
|
||||
def my_check(e: object) -> bool: # pragma: no cover
|
||||
return True
|
||||
|
||||
msg = (
|
||||
"`allow_unwrapped=True` bypasses the `match` and `check` parameters"
|
||||
" if the exception is unwrapped. If you intended to match/check the"
|
||||
" exception you should use a `Matcher` object. If you want to match/check"
|
||||
" the exceptiongroup when the exception *is* wrapped you need to"
|
||||
" do e.g. `if isinstance(exc.value, ExceptionGroup):"
|
||||
" assert RaisesGroup(...).matches(exc.value)` afterwards."
|
||||
)
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload]
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload]
|
||||
|
||||
# Users should instead use a Matcher
|
||||
rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True)
|
||||
with rg:
|
||||
raise ValueError("foo")
|
||||
with rg:
|
||||
raise ExceptionGroup("", [ValueError("foo")])
|
||||
|
||||
# or if they wanted to match/check the group, do a conditional `.matches()`
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True) as exc:
|
||||
raise ExceptionGroup("bar", [ValueError("foo")])
|
||||
if isinstance(exc.value, ExceptionGroup): # pragma: no branch
|
||||
assert RaisesGroup(ValueError, match="bar").matches(exc.value)
|
||||
|
||||
|
||||
def test_RaisesGroup_matches() -> None:
|
||||
rg = RaisesGroup(ValueError)
|
||||
assert not rg.matches(None)
|
||||
assert not rg.matches(ValueError())
|
||||
assert rg.matches(ExceptionGroup("", (ValueError(),)))
|
||||
|
||||
|
||||
def test_message() -> None:
|
||||
def check_message(message: str, body: RaisesGroup[Any]) -> None:
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$",
|
||||
):
|
||||
with body:
|
||||
...
|
||||
|
||||
# basic
|
||||
check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError))
|
||||
# multiple exceptions
|
||||
check_message(
|
||||
"ExceptionGroup(ValueError, ValueError)",
|
||||
RaisesGroup(ValueError, ValueError),
|
||||
)
|
||||
# nested
|
||||
check_message(
|
||||
"ExceptionGroup(ExceptionGroup(ValueError))",
|
||||
RaisesGroup(RaisesGroup(ValueError)),
|
||||
)
|
||||
|
||||
# Matcher
|
||||
check_message(
|
||||
"ExceptionGroup(Matcher(ValueError, match='my_str'))",
|
||||
RaisesGroup(Matcher(ValueError, "my_str")),
|
||||
)
|
||||
check_message(
|
||||
"ExceptionGroup(Matcher(match='my_str'))",
|
||||
RaisesGroup(Matcher(match="my_str")),
|
||||
)
|
||||
|
||||
# BaseExceptionGroup
|
||||
check_message(
|
||||
"BaseExceptionGroup(KeyboardInterrupt)",
|
||||
RaisesGroup(KeyboardInterrupt),
|
||||
)
|
||||
# BaseExceptionGroup with type inside Matcher
|
||||
check_message(
|
||||
"BaseExceptionGroup(Matcher(KeyboardInterrupt))",
|
||||
RaisesGroup(Matcher(KeyboardInterrupt)),
|
||||
)
|
||||
# Base-ness transfers to parent containers
|
||||
check_message(
|
||||
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))",
|
||||
RaisesGroup(RaisesGroup(KeyboardInterrupt)),
|
||||
)
|
||||
# but not to child containers
|
||||
check_message(
|
||||
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))",
|
||||
RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)),
|
||||
)
|
||||
|
||||
|
||||
def test_matcher() -> None:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You must specify at least one parameter to match on.$",
|
||||
):
|
||||
Matcher() # type: ignore[call-overload]
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$",
|
||||
):
|
||||
Matcher(object) # type: ignore[type-var]
|
||||
|
||||
with RaisesGroup(Matcher(ValueError)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(TypeError)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_matcher_match() -> None:
|
||||
with RaisesGroup(Matcher(ValueError, "foo")):
|
||||
raise ExceptionGroup("", (ValueError("foo"),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(ValueError, "foo")):
|
||||
raise ExceptionGroup("", (ValueError("bar"),))
|
||||
|
||||
# Can be used without specifying the type
|
||||
with RaisesGroup(Matcher(match="foo")):
|
||||
raise ExceptionGroup("", (ValueError("foo"),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(match="foo")):
|
||||
raise ExceptionGroup("", (ValueError("bar"),))
|
||||
|
||||
# check ^$
|
||||
with RaisesGroup(Matcher(ValueError, match="^bar$")):
|
||||
raise ExceptionGroup("", [ValueError("bar")])
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(ValueError, match="^bar$")):
|
||||
raise ExceptionGroup("", [ValueError("barr")])
|
||||
|
||||
|
||||
def test_Matcher_check() -> None:
|
||||
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
|
||||
return isinstance(e, OSError) and e.errno == 5
|
||||
|
||||
with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(5, ""),))
|
||||
|
||||
# specifying exception_type narrows the parameter type to the callable
|
||||
def check_errno_is_5(e: OSError) -> bool:
|
||||
return e.errno == 5
|
||||
|
||||
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(5, ""),))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(6, ""),))
|
||||
|
||||
|
||||
def test_matcher_tostring() -> None:
|
||||
assert str(Matcher(ValueError)) == "Matcher(ValueError)"
|
||||
assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')"
|
||||
pattern_no_flags = re.compile("noflag", 0)
|
||||
assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')"
|
||||
pattern_flags = re.compile("noflag", re.IGNORECASE)
|
||||
assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})"
|
||||
assert (
|
||||
str(Matcher(ValueError, match="re", check=bool))
|
||||
== f"Matcher(ValueError, match='re', check={bool!r})"
|
||||
)
|
||||
|
||||
|
||||
def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
trio.testing._raises_group,
|
||||
"ExceptionInfo",
|
||||
trio.testing._raises_group._ExceptionInfo,
|
||||
)
|
||||
with trio.testing.RaisesGroup(ValueError) as excinfo:
|
||||
raise ExceptionGroup("", (ValueError("hello"),))
|
||||
assert excinfo.type is ExceptionGroup
|
||||
assert excinfo.value.exceptions[0].args == ("hello",)
|
||||
assert isinstance(excinfo.tb, TracebackType)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,272 @@
|
||||
import time
|
||||
from typing import Awaitable, Callable, Protocol, TypeVar
|
||||
|
||||
import outcome
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import slow
|
||||
from .._timeouts import (
|
||||
TooSlowError,
|
||||
fail_after,
|
||||
fail_at,
|
||||
move_on_after,
|
||||
move_on_at,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
sleep_until,
|
||||
)
|
||||
from ..testing import assert_checkpoints
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
|
||||
start = time.perf_counter()
|
||||
result = await outcome.acapture(f)
|
||||
dur = time.perf_counter() - start
|
||||
print(dur / expected_dur)
|
||||
# 1.5 is an arbitrary fudge factor because there's always some delay
|
||||
# between when we become eligible to wake up and when we actually do. We
|
||||
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
|
||||
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
|
||||
# now we bumped up the sleep to 1 second, marked the tests as slow, and
|
||||
# hopefully now the proportional error will be less huge.
|
||||
#
|
||||
# We also also for durations that are a hair shorter than expected. For
|
||||
# example, here's a run on Windows where a 1.0 second sleep was measured
|
||||
# to take 0.9999999999999858 seconds:
|
||||
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
|
||||
# I believe that what happened here is that Windows's low clock resolution
|
||||
# meant that our calls to time.monotonic() returned exactly the same
|
||||
# values as the calls inside the actual run loop, but the two subtractions
|
||||
# returned slightly different values because the run loop's clock adds a
|
||||
# random floating point offset to both times, which should cancel out, but
|
||||
# lol floating point we got slightly different rounding errors. (That
|
||||
# value above is exactly 128 ULPs below 1.0, which would make sense if it
|
||||
# started as a 1 ULP error at a different dynamic range.)
|
||||
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
|
||||
|
||||
return result.unwrap()
|
||||
|
||||
|
||||
# How long to (attempt to) sleep for when testing. Smaller numbers make the
|
||||
# test suite go faster.
|
||||
TARGET = 1.0
|
||||
|
||||
|
||||
@slow
|
||||
async def test_sleep() -> None:
|
||||
async def sleep_1() -> None:
|
||||
await sleep_until(_core.current_time() + TARGET)
|
||||
|
||||
await check_takes_about(sleep_1, TARGET)
|
||||
|
||||
async def sleep_2() -> None:
|
||||
await sleep(TARGET)
|
||||
|
||||
await check_takes_about(sleep_2, TARGET)
|
||||
|
||||
with assert_checkpoints():
|
||||
await sleep(0)
|
||||
# This also serves as a test of the trivial move_on_at
|
||||
with move_on_at(_core.current_time()):
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after() -> None:
|
||||
async def sleep_3() -> None:
|
||||
with move_on_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
await check_takes_about(sleep_3, TARGET)
|
||||
|
||||
|
||||
async def test_cannot_wake_sleep_forever() -> None:
|
||||
# Test an error occurs if you manually wake sleep_forever().
|
||||
task = trio.lowlevel.current_task()
|
||||
|
||||
async def wake_task() -> None:
|
||||
await trio.lowlevel.checkpoint()
|
||||
trio.lowlevel.reschedule(task, outcome.Value(None))
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(wake_task)
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
class TimeoutScope(Protocol):
|
||||
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scope", [move_on_after, fail_after])
|
||||
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
|
||||
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
|
||||
outer.cancel()
|
||||
try:
|
||||
await trio.lowlevel.checkpoint()
|
||||
except trio.Cancelled:
|
||||
pytest.fail("shield didn't work")
|
||||
inner.shield = False
|
||||
with pytest.raises(trio.Cancelled):
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after_moves_on_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail_after_fails_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
|
||||
TARGET,
|
||||
shield=True,
|
||||
):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail() -> None:
|
||||
async def sleep_4() -> None:
|
||||
with fail_at(_core.current_time() + TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_4, TARGET)
|
||||
|
||||
with fail_at(_core.current_time() + 100):
|
||||
await sleep(0)
|
||||
|
||||
async def sleep_5() -> None:
|
||||
with fail_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_5, TARGET)
|
||||
|
||||
with fail_after(100):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
async def test_timeouts_raise_value_error() -> None:
|
||||
# deadlines are allowed to be negative, but not delays.
|
||||
# neither delays nor deadlines are allowed to be NaN
|
||||
|
||||
nan = float("nan")
|
||||
|
||||
for fun, val in (
|
||||
(sleep, -1),
|
||||
(sleep, nan),
|
||||
(sleep_until, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
await fun(val)
|
||||
|
||||
for cm, val in (
|
||||
(fail_after, -1),
|
||||
(fail_after, nan),
|
||||
(fail_at, nan),
|
||||
(move_on_after, -1),
|
||||
(move_on_after, nan),
|
||||
(move_on_at, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
with cm(val):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None:
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.relative_deadline == 5
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.is_relative is None
|
||||
|
||||
# This would previously be start+2
|
||||
assert cs.deadline == start + 5
|
||||
assert cs.relative_deadline == 5
|
||||
|
||||
cs.deadline = start + 3
|
||||
assert cs.deadline == start + 3
|
||||
assert cs.relative_deadline == 3
|
||||
|
||||
cs.relative_deadline = 4
|
||||
assert cs.deadline == start + 4
|
||||
assert cs.relative_deadline == 4
|
||||
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.shield is False
|
||||
rcs.shield = True
|
||||
assert rcs.shield is True
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.deadline == start + 5
|
||||
|
||||
assert rcs is cs
|
||||
|
||||
|
||||
async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None:
|
||||
cs = move_on_after(5)
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
|
||||
match_str = "^unentered relative cancel scope does not have an absolute deadline"
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 5
|
||||
mock_clock.jump(1)
|
||||
# this is hella sketchy, but they *have* been warned
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 6
|
||||
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
cs.deadline = 7
|
||||
# now transformed into absolute
|
||||
assert cs.deadline == 7
|
||||
assert not cs.is_relative
|
||||
|
||||
cs = move_on_at(5)
|
||||
|
||||
match_str = (
|
||||
"^unentered non-relative cancel scope does not have a relative deadline$"
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
assert cs.relative_deadline
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
cs.relative_deadline = 7
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="not implemented")
|
||||
async def test_fail_access_before_entering() -> None: # pragma: no cover
|
||||
my_fail_at = fail_at(5)
|
||||
assert my_fail_at.deadline # type: ignore[attr-defined]
|
||||
my_fail_after = fail_after(5)
|
||||
assert my_fail_after.relative_deadline # type: ignore[attr-defined]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user