Updated script that can be controled by Nodejs web app

This commit is contained in:
mac OS
2024-11-25 12:24:18 +07:00
parent c440eda1f4
commit 8b0ab2bd3a
8662 changed files with 1803808 additions and 34 deletions
@@ -0,0 +1,138 @@
"""Trio - A friendly Python library for async concurrency and I/O
"""
from __future__ import annotations
from typing import TYPE_CHECKING
# General layout:
#
# trio/_core/... is the self-contained core library. It does various
# shenanigans to export a consistent "core API", but parts of the core API are
# too low-level to be recommended for regular use.
#
# trio/*.py define a set of more usable tools on top of this. They import from
# trio._core and from each other.
#
# This file pulls together the friendly public API, by re-exporting the more
# innocuous bits of the _core API + the higher-level tools from trio/*.py.
#
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
#
# must be imported early to avoid circular import
from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split
# Submodules imported by default
from . import abc, from_thread, lowlevel, socket, to_thread
from ._channel import (
MemoryChannelStatistics as MemoryChannelStatistics,
MemoryReceiveChannel as MemoryReceiveChannel,
MemorySendChannel as MemorySendChannel,
open_memory_channel as open_memory_channel,
)
from ._core import (
BrokenResourceError as BrokenResourceError,
BusyResourceError as BusyResourceError,
Cancelled as Cancelled,
CancelScope as CancelScope,
ClosedResourceError as ClosedResourceError,
EndOfChannel as EndOfChannel,
Nursery as Nursery,
RunFinishedError as RunFinishedError,
TaskStatus as TaskStatus,
TrioInternalError as TrioInternalError,
WouldBlock as WouldBlock,
current_effective_deadline as current_effective_deadline,
current_time as current_time,
open_nursery as open_nursery,
run as run,
)
from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning
from ._dtls import (
DTLSChannel as DTLSChannel,
DTLSChannelStatistics as DTLSChannelStatistics,
DTLSEndpoint as DTLSEndpoint,
)
from ._file_io import open_file as open_file, wrap_file as wrap_file
from ._highlevel_generic import (
StapledStream as StapledStream,
aclose_forcefully as aclose_forcefully,
)
from ._highlevel_open_tcp_listeners import (
open_tcp_listeners as open_tcp_listeners,
serve_tcp as serve_tcp,
)
from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream
from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket
from ._highlevel_serve_listeners import serve_listeners as serve_listeners
from ._highlevel_socket import (
SocketListener as SocketListener,
SocketStream as SocketStream,
)
from ._highlevel_ssl_helpers import (
open_ssl_over_tcp_listeners as open_ssl_over_tcp_listeners,
open_ssl_over_tcp_stream as open_ssl_over_tcp_stream,
serve_ssl_over_tcp as serve_ssl_over_tcp,
)
from ._path import Path as Path, PosixPath as PosixPath, WindowsPath as WindowsPath
from ._signals import open_signal_receiver as open_signal_receiver
from ._ssl import (
NeedHandshakeError as NeedHandshakeError,
SSLListener as SSLListener,
SSLStream as SSLStream,
)
from ._subprocess import Process as Process, run_process as run_process
from ._sync import (
CapacityLimiter as CapacityLimiter,
CapacityLimiterStatistics as CapacityLimiterStatistics,
Condition as Condition,
ConditionStatistics as ConditionStatistics,
Event as Event,
EventStatistics as EventStatistics,
Lock as Lock,
LockStatistics as LockStatistics,
Semaphore as Semaphore,
StrictFIFOLock as StrictFIFOLock,
)
from ._timeouts import (
TooSlowError as TooSlowError,
fail_after as fail_after,
fail_at as fail_at,
move_on_after as move_on_after,
move_on_at as move_on_at,
sleep as sleep,
sleep_forever as sleep_forever,
sleep_until as sleep_until,
)
# pyright explicitly does not care about `__version__`
# see https://github.com/microsoft/pyright/blob/main/docs/typed-libraries.md#type-completeness
from ._version import __version__
# Not imported by default, but mentioned here so static analysis tools like
# pylint will know that it exists.
if TYPE_CHECKING:
from . import testing
from . import _deprecate as _deprecate
_deprecate.enable_attribute_deprecations(__name__)
__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = {}
# Having the public path in .__module__ attributes is important for:
# - exception names in printed tracebacks
# - sphinx :show-inheritance:
# - deprecation warnings
# - pickle
# - probably other stuff
from ._util import fixup_module_metadata
fixup_module_metadata(__name__, globals())
fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__)
fixup_module_metadata(socket.__name__, socket.__dict__)
fixup_module_metadata(abc.__name__, abc.__dict__)
fixup_module_metadata(from_thread.__name__, from_thread.__dict__)
fixup_module_metadata(to_thread.__name__, to_thread.__dict__)
del fixup_module_metadata
del TYPE_CHECKING
@@ -0,0 +1,3 @@
from trio._repl import main
main(locals())
+716
View File
@@ -0,0 +1,716 @@
from __future__ import annotations
import socket
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar
import trio
if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import Self
# both of these introduce circular imports if outside a TYPE_CHECKING guard
from ._socket import SocketType
from .lowlevel import Task
class Clock(ABC):
"""The interface for custom run loop clocks."""
__slots__ = ()
@abstractmethod
def start_clock(self) -> None:
"""Do any setup this clock might need.
Called at the beginning of the run.
"""
@abstractmethod
def current_time(self) -> float:
"""Return the current time, according to this clock.
This is used to implement functions like :func:`trio.current_time` and
:func:`trio.move_on_after`.
Returns:
float: The current time.
"""
@abstractmethod
def deadline_to_sleep_time(self, deadline: float) -> float:
"""Compute the real time until the given deadline.
This is called before we enter a system-specific wait function like
:func:`select.select`, to get the timeout to pass.
For a clock using wall-time, this should be something like::
return deadline - self.current_time()
but of course it may be different if you're implementing some kind of
virtual clock.
Args:
deadline (float): The absolute time of the next deadline,
according to this clock.
Returns:
float: The number of real seconds to sleep until the given
deadline. May be :data:`math.inf`.
"""
class Instrument(ABC): # noqa: B024 # conceptually is ABC
"""The interface for run loop instrumentation.
Instruments don't have to inherit from this abstract base class, and all
of these methods are optional. This class serves mostly as documentation.
"""
__slots__ = ()
def before_run(self) -> None:
"""Called at the beginning of :func:`trio.run`."""
return
def after_run(self) -> None:
"""Called just before :func:`trio.run` returns."""
return
def task_spawned(self, task: Task) -> None:
"""Called when the given task is created.
Args:
task (trio.lowlevel.Task): The new task.
"""
return
def task_scheduled(self, task: Task) -> None:
"""Called when the given task becomes runnable.
It may still be some time before it actually runs, if there are other
runnable tasks ahead of it.
Args:
task (trio.lowlevel.Task): The task that became runnable.
"""
return
def before_task_step(self, task: Task) -> None:
"""Called immediately before we resume running the given task.
Args:
task (trio.lowlevel.Task): The task that is about to run.
"""
return
def after_task_step(self, task: Task) -> None:
"""Called when we return to the main run loop after a task has yielded.
Args:
task (trio.lowlevel.Task): The task that just ran.
"""
return
def task_exited(self, task: Task) -> None:
"""Called when the given task exits.
Args:
task (trio.lowlevel.Task): The finished task.
"""
return
def before_io_wait(self, timeout: float) -> None:
"""Called before blocking to wait for I/O readiness.
Args:
timeout (float): The number of seconds we are willing to wait.
"""
return
def after_io_wait(self, timeout: float) -> None:
"""Called after handling pending I/O.
Args:
timeout (float): The number of seconds we were willing to
wait. This much time may or may not have elapsed, depending on
whether any I/O was ready.
"""
return
class HostnameResolver(ABC):
"""If you have a custom hostname resolver, then implementing
:class:`HostnameResolver` allows you to register this to be used by Trio.
See :func:`trio.socket.set_custom_hostname_resolver`.
"""
__slots__ = ()
@abstractmethod
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
"""A custom implementation of :func:`~trio.socket.getaddrinfo`.
Called by :func:`trio.socket.getaddrinfo`.
If ``host`` is given as a numeric IP address, then
:func:`~trio.socket.getaddrinfo` may handle the request itself rather
than calling this method.
Any required IDNA encoding is handled before calling this function;
your implementation can assume that it will never see U-labels like
``"café.com"``, and only needs to handle A-labels like
``b"xn--caf-dma.com"``.
"""
@abstractmethod
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
"""A custom implementation of :func:`~trio.socket.getnameinfo`.
Called by :func:`trio.socket.getnameinfo`.
"""
class SocketFactory(ABC):
"""If you write a custom class implementing the Trio socket interface,
then you can use a :class:`SocketFactory` to get Trio to use it.
See :func:`trio.socket.set_custom_socket_factory`.
"""
__slots__ = ()
@abstractmethod
def socket(
self,
family: socket.AddressFamily | int = socket.AF_INET,
type: socket.SocketKind | int = socket.SOCK_STREAM,
proto: int = 0,
) -> SocketType:
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
which is an empty class whose only purpose is to "mark" which classes
should be considered valid Trio sockets.
Called by :func:`trio.socket.socket`.
Note that unlike :func:`trio.socket.socket`, this does not take a
``fileno=`` argument. If a ``fileno=`` is specified, then
:func:`trio.socket.socket` returns a regular Trio socket object
instead of calling this method.
"""
class AsyncResource(ABC):
"""A standard interface for resources that needs to be cleaned up, and
where that cleanup may require blocking operations.
This class distinguishes between "graceful" closes, which may perform I/O
and thus block, and a "forceful" close, which cannot. For example, cleanly
shutting down a TLS-encrypted connection requires sending a "goodbye"
message; but if a peer has become non-responsive, then sending this
message might block forever, so we may want to just drop the connection
instead. Therefore the :meth:`aclose` method is unusual in that it
should always close the connection (or at least make its best attempt)
*even if it fails*; failure indicates a failure to achieve grace, not a
failure to close the connection.
Objects that implement this interface can be used as async context
managers, i.e., you can write::
async with create_resource() as some_async_resource:
...
Entering the context manager is synchronous (not a checkpoint); exiting it
calls :meth:`aclose`. The default implementations of
``__aenter__`` and ``__aexit__`` should be adequate for all subclasses.
"""
__slots__ = ()
@abstractmethod
async def aclose(self) -> None:
"""Close this resource, possibly blocking.
IMPORTANT: This method may block in order to perform a "graceful"
shutdown. But, if this fails, then it still *must* close any
underlying resources before returning. An error from this method
indicates a failure to achieve grace, *not* a failure to close the
connection.
For example, suppose we call :meth:`aclose` on a TLS-encrypted
connection. This requires sending a "goodbye" message; but if the peer
has become non-responsive, then our attempt to send this message might
block forever, and eventually time out and be cancelled. In this case
the :meth:`aclose` method on :class:`~trio.SSLStream` will
immediately close the underlying transport stream using
:func:`trio.aclose_forcefully` before raising :exc:`~trio.Cancelled`.
If the resource is already closed, then this method should silently
succeed.
Once this method completes, any other pending or future operations on
this resource should generally raise :exc:`~trio.ClosedResourceError`,
unless there's a good reason to do otherwise.
See also: :func:`trio.aclose_forcefully`.
"""
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self.aclose()
class SendStream(AsyncResource):
"""A standard interface for sending data on a byte stream.
The underlying stream may be unidirectional, or bidirectional. If it's
bidirectional, then you probably want to also implement
:class:`ReceiveStream`, which makes your object a :class:`Stream`.
:class:`SendStream` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
If you want to send Python objects rather than raw bytes, see
:class:`SendChannel`.
"""
__slots__ = ()
@abstractmethod
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Sends the given data through the stream, blocking if necessary.
Args:
data (bytes, bytearray, or memoryview): The data to send.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
:meth:`HalfCloseableStream.send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`send_all` is running.
Most low-level operations in Trio provide a guarantee: if they raise
:exc:`trio.Cancelled`, this means that they had no effect, so the
system remains in a known state. This is **not true** for
:meth:`send_all`. If this operation raises :exc:`trio.Cancelled` (or
any other exception for that matter), then it may have sent some, all,
or none of the requested data, and there is no way to know which.
"""
@abstractmethod
async def wait_send_all_might_not_block(self) -> None:
"""Block until it's possible that :meth:`send_all` might not block.
This method may return early: it's possible that after it returns,
:meth:`send_all` will still block. (In the worst case, if no better
implementation is available, then it might always return immediately
without blocking. It's nice to do better than that when possible,
though.)
This method **must not** return *late*: if it's possible for
:meth:`send_all` to complete without blocking, then it must
return. When implementing it, err on the side of returning early.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
:meth:`HalfCloseableStream.send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`wait_send_all_might_not_block` is running.
Note:
This method is intended to aid in implementing protocols that want
to delay choosing which data to send until the last moment. E.g.,
suppose you're working on an implementation of a remote display server
like `VNC
<https://en.wikipedia.org/wiki/Virtual_Network_Computing>`__, and
the network connection is currently backed up so that if you call
:meth:`send_all` now then it will sit for 0.5 seconds before actually
sending anything. In this case it doesn't make sense to take a
screenshot, then wait 0.5 seconds, and then send it, because the
screen will keep changing while you wait; it's better to wait 0.5
seconds, then take the screenshot, and then send it, because this
way the data you deliver will be more
up-to-date. Using :meth:`wait_send_all_might_not_block` makes it
possible to implement the better strategy.
If you use this method, you might also want to read up on
``TCP_NOTSENT_LOWAT``.
Further reading:
* `Prioritization Only Works When There's Pending Data to Prioritize
<https://insouciant.org/tech/prioritization-only-works-when-theres-pending-data-to-prioritize/>`__
* WWDC 2015: Your App and Next Generation Networks: `slides
<http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1>`__,
`video and transcript
<https://developer.apple.com/videos/play/wwdc2015/719/>`__
"""
class ReceiveStream(AsyncResource):
"""A standard interface for receiving data on a byte stream.
The underlying stream may be unidirectional, or bidirectional. If it's
bidirectional, then you probably want to also implement
:class:`SendStream`, which makes your object a :class:`Stream`.
:class:`ReceiveStream` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
If you want to receive Python objects rather than raw bytes, see
:class:`ReceiveChannel`.
`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
will produce an arbitrary sized chunk of bytes, like calling
`receive_some` with no arguments. Every chunk will contain at least one
byte, and the loop automatically exits when reaching end-of-file.
"""
__slots__ = ()
@abstractmethod
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""Wait until there is data available on this stream, and then return
some of it.
A return value of ``b""`` (an empty bytestring) indicates that the
stream has reached end-of-file. Implementations should be careful that
they return ``b""`` if, and only if, the stream has reached
end-of-file!
Args:
max_bytes (int): The maximum number of bytes to return. Must be
greater than zero. Optional; if omitted, then the stream object
is free to pick a reasonable default.
Returns:
bytes or bytearray: The data received.
Raises:
trio.BusyResourceError: if two tasks attempt to call
:meth:`receive_some` on the same stream at the same time.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`receive_some` is running.
"""
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> bytes | bytearray:
data = await self.receive_some()
if not data:
raise StopAsyncIteration
return data
class Stream(SendStream, ReceiveStream):
"""A standard interface for interacting with bidirectional byte streams.
A :class:`Stream` is an object that implements both the
:class:`SendStream` and :class:`ReceiveStream` interfaces.
If implementing this interface, you should consider whether you can go one
step further and implement :class:`HalfCloseableStream`.
"""
__slots__ = ()
class HalfCloseableStream(Stream):
"""This interface extends :class:`Stream` to also allow closing the send
part of the stream without closing the receive part.
"""
__slots__ = ()
@abstractmethod
async def send_eof(self) -> None:
"""Send an end-of-file indication on this stream, if possible.
The difference between :meth:`send_eof` and
:meth:`~AsyncResource.aclose` is that :meth:`send_eof` is a
*unidirectional* end-of-file indication. After you call this method,
you shouldn't try sending any more data on this stream, and your
remote peer should receive an end-of-file indication (eventually,
after receiving all the data you sent before that). But, they may
continue to send data to you, and you can continue to receive it by
calling :meth:`~ReceiveStream.receive_some`. You can think of it as
calling :meth:`~AsyncResource.aclose` on just the
:class:`SendStream` "half" of the stream object (and in fact that's
literally how :class:`trio.StapledStream` implements it).
Examples:
* On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man
page <https://linux.die.net/man/2/shutdown>`__).
* The SSH protocol provides the ability to multiplex bidirectional
"channels" on top of a single encrypted connection. A Trio
implementation of SSH could expose these channels as
:class:`HalfCloseableStream` objects, and calling :meth:`send_eof`
would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3
<https://tools.ietf.org/html/rfc4254#section-5.3>`__).
* On an SSL/TLS-encrypted connection, the protocol doesn't provide any
way to do a unidirectional shutdown without closing the connection
entirely, so :class:`~trio.SSLStream` implements
:class:`Stream`, not :class:`HalfCloseableStream`.
If an EOF has already been sent, then this method should silently
succeed.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`~SendStream.send_all`,
:meth:`~SendStream.wait_send_all_might_not_block`, or
:meth:`send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`send_eof` is running.
"""
# A regular invariant generic type
T = TypeVar("T")
# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)
# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)
# The type of object produced by a Listener (covariant plus must be
# an AsyncResource)
T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True)
class Listener(AsyncResource, Generic[T_resource]):
"""A standard interface for listening for incoming connections.
:class:`Listener` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
"""
__slots__ = ()
@abstractmethod
async def accept(self) -> T_resource:
"""Wait until an incoming connection arrives, and then return it.
Returns:
AsyncResource: An object representing the incoming connection. In
practice this is generally some kind of :class:`Stream`,
but in principle you could also define a :class:`Listener` that
returned, say, channel objects.
Raises:
trio.BusyResourceError: if two tasks attempt to call
:meth:`accept` on the same listener at the same time.
trio.ClosedResourceError: if you previously closed this listener
object, or if another task closes this listener object while
:meth:`accept` is running.
Listeners don't generally raise :exc:`~trio.BrokenResourceError`,
because for listeners there is no general condition of "the
network/remote peer broke the connection" that can be handled in a
generic way, like there is for streams. Other errors *can* occur and
be raised from :meth:`accept` for example, if you run out of file
descriptors then you might get an :class:`OSError` with its errno set
to ``EMFILE``.
"""
class SendChannel(AsyncResource, Generic[SendType]):
"""A standard interface for sending Python objects to some receiver.
`SendChannel` objects also implement the `AsyncResource` interface, so
they can be closed by calling `~AsyncResource.aclose` or using an ``async
with`` block.
If you want to send raw bytes rather than Python objects, see
`SendStream`.
"""
__slots__ = ()
@abstractmethod
async def send(self, value: SendType) -> None:
"""Attempt to send an object through the channel, blocking if necessary.
Args:
value (object): The object to send.
Raises:
trio.BrokenResourceError: if something has gone wrong, and the
channel is broken. For example, you may get this if the receiver
has already been closed.
trio.ClosedResourceError: if you previously closed this
:class:`SendChannel` object, or if another task closes it while
:meth:`send` is running.
trio.BusyResourceError: some channels allow multiple tasks to call
`send` at the same time, but others don't. If you try to call
`send` simultaneously from multiple tasks on a channel that
doesn't support it, then you can get `~trio.BusyResourceError`.
"""
class ReceiveChannel(AsyncResource, Generic[ReceiveType]):
"""A standard interface for receiving Python objects from some sender.
You can iterate over a :class:`ReceiveChannel` using an ``async for``
loop::
async for value in receive_channel:
...
This is equivalent to calling :meth:`receive` repeatedly. The loop exits
without error when `receive` raises `~trio.EndOfChannel`.
`ReceiveChannel` objects also implement the `AsyncResource` interface, so
they can be closed by calling `~AsyncResource.aclose` or using an ``async
with`` block.
If you want to receive raw bytes rather than Python objects, see
`ReceiveStream`.
"""
__slots__ = ()
@abstractmethod
async def receive(self) -> ReceiveType:
"""Attempt to receive an incoming object, blocking if necessary.
Returns:
object: Whatever object was received.
Raises:
trio.EndOfChannel: if the sender has been closed cleanly, and no
more objects are coming. This is not an error condition.
trio.ClosedResourceError: if you previously closed this
:class:`ReceiveChannel` object.
trio.BrokenResourceError: if something has gone wrong, and the
channel is broken.
trio.BusyResourceError: some channels allow multiple tasks to call
`receive` at the same time, but others don't. If you try to call
`receive` simultaneously from multiple tasks on a channel that
doesn't support it, then you can get `~trio.BusyResourceError`.
"""
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> ReceiveType:
try:
return await self.receive()
except trio.EndOfChannel:
raise StopAsyncIteration from None
# these are necessary for Sphinx's :show-inheritance: with type args.
# (this should be removed if possible)
# see: https://github.com/python/cpython/issues/123250
SendChannel.__module__ = SendChannel.__module__.replace("_abc", "abc")
ReceiveChannel.__module__ = ReceiveChannel.__module__.replace("_abc", "abc")
Listener.__module__ = Listener.__module__.replace("_abc", "abc")
class Channel(SendChannel[T], ReceiveChannel[T]):
"""A standard interface for interacting with bidirectional channels.
A `Channel` is an object that implements both the `SendChannel` and
`ReceiveChannel` interfaces, so you can both send and receive objects.
"""
__slots__ = ()
# see above
Channel.__module__ = Channel.__module__.replace("_abc", "abc")
@@ -0,0 +1,444 @@
from __future__ import annotations
from collections import OrderedDict, deque
from math import inf
from typing import (
TYPE_CHECKING,
Generic,
Tuple, # only needed for typechecking on <3.9
)
import attrs
from outcome import Error, Value
import trio
from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
from ._util import NoPublicConstructor, final, generic_function
if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import Self
def _open_memory_channel(
max_buffer_size: int | float, # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Memory channels are lightweight, cheap to allocate, and entirely
in-memory. They don't involve any operating-system resources, or any kind
of serialization. They just pass Python objects directly between tasks
(with a possible stop in an internal buffer along the way).
Channel objects can be closed by calling `~trio.abc.AsyncResource.aclose`
or using ``async with``. They are *not* automatically closed when garbage
collected. Closing memory channels isn't mandatory, but it is generally a
good idea, because it helps avoid situations where tasks get stuck waiting
on a channel when there's no-one on the other side. See
:ref:`channel-shutdown` for details.
Memory channel operations are all atomic with respect to
cancellation, either `~trio.abc.ReceiveChannel.receive` will
successfully return an object, or it will raise :exc:`Cancelled`
while leaving the channel unchanged.
Args:
max_buffer_size (int or math.inf): The maximum number of items that can
be buffered in the channel before :meth:`~trio.abc.SendChannel.send`
blocks. Choosing a sensible value here is important to ensure that
backpressure is communicated promptly and avoid unnecessary latency;
see :ref:`channel-buffering` for more details. If in doubt, use 0.
Returns:
A pair ``(send_channel, receive_channel)``. If you have
trouble remembering which order these go in, remember: data
flows from left → right.
In addition to the standard channel methods, all memory channel objects
provide a ``statistics()`` method, which returns an object with the
following fields:
* ``current_buffer_used``: The number of items currently stored in the
channel buffer.
* ``max_buffer_size``: The maximum number of items allowed in the buffer,
as passed to :func:`open_memory_channel`.
* ``open_send_channels``: The number of open
:class:`MemorySendChannel` endpoints pointing to this channel.
Initially 1, but can be increased by
:meth:`MemorySendChannel.clone`.
* ``open_receive_channels``: Likewise, but for open
:class:`MemoryReceiveChannel` endpoints.
* ``tasks_waiting_send``: The number of tasks blocked in ``send`` on this
channel (summing over all clones).
* ``tasks_waiting_receive``: The number of tasks blocked in ``receive`` on
this channel (summing over all clones).
"""
if max_buffer_size != inf and not isinstance(max_buffer_size, int):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel[T]._create(state),
MemoryReceiveChannel[T]._create(state),
)
# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
# written as a class so you can say open_memory_channel[int](5)
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls,
max_buffer_size: int | float, # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)
def __init__(self, max_buffer_size: int | float): # noqa: PYI041
...
else:
# apply the generic_function decorator to make open_memory_channel indexable
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
open_memory_channel = generic_function(_open_memory_channel)
@attrs.frozen
class MemoryChannelStatistics:
current_buffer_used: int
max_buffer_size: int | float
open_send_channels: int
open_receive_channels: int
tasks_waiting_send: int
tasks_waiting_receive: int
@attrs.define
class MemoryChannelState(Generic[T]):
max_buffer_size: int | float
data: deque[T] = attrs.Factory(deque)
# Counts of open endpoints using this state
open_send_channels: int = 0
open_receive_channels: int = 0
# {task: value}
send_tasks: OrderedDict[Task, T] = attrs.Factory(OrderedDict)
# {task: None}
receive_tasks: OrderedDict[Task, None] = attrs.Factory(OrderedDict)
def statistics(self) -> MemoryChannelStatistics:
return MemoryChannelStatistics(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
open_send_channels=self.open_send_channels,
open_receive_channels=self.open_receive_channels,
tasks_waiting_send=len(self.send_tasks),
tasks_waiting_receive=len(self.receive_tasks),
)
@final
@attrs.define(eq=False, repr=False, slots=False)
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType]
_closed: bool = False
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks: set[Task] = attrs.Factory(set)
def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1
def __repr__(self) -> str:
return f"<send channel at {id(self):#x}, using buffer at {id(self._state):#x}>"
def statistics(self) -> MemoryChannelStatistics:
"""Returns a `MemoryChannelStatistics` for the memory channel this is
associated with."""
# XX should we also report statistics specific to this object?
return self._state.statistics()
@enable_ki_protection
def send_nowait(self, value: SendType) -> None:
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.
"""
if self._closed:
raise trio.ClosedResourceError
if self._state.open_receive_channels == 0:
raise trio.BrokenResourceError
if self._state.receive_tasks:
assert not self._state.data
task, _ = self._state.receive_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Value(value))
elif len(self._state.data) < self._state.max_buffer_size:
self._state.data.append(value)
else:
raise trio.WouldBlock
@enable_ki_protection
async def send(self, value: SendType) -> None:
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
Memory channels allow multiple tasks to call `send` at the same time.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.send_nowait(value)
except trio.WouldBlock:
pass
else:
await trio.lowlevel.cancel_shielded_checkpoint()
return
task = trio.lowlevel.current_task()
self._tasks.add(task)
self._state.send_tasks[task] = value
task.custom_sleep_data = self
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED
await trio.lowlevel.wait_task_rescheduled(abort_fn)
# Return type must be stringified or use a TypeVar
@enable_ki_protection
def clone(self) -> MemorySendChannel[SendType]:
"""Clone this send channel object.
This returns a new `MemorySendChannel` object, which acts as a
duplicate of the original: sending on the new object does exactly the
same thing as sending on the old object. (If you're familiar with
`os.dup`, then this is a similar idea.)
However, closing one of the objects does not close the other, and
receivers don't get `EndOfChannel` until *all* clones have been
closed.
This is useful for communication patterns that involve multiple
producers all sending objects to the same destination. If you give
each producer its own clone of the `MemorySendChannel`, and then make
sure to close each `MemorySendChannel` when it's finished, receivers
will automatically get notified when all producers are finished. See
:ref:`channel-mpmc` for examples.
Raises:
trio.ClosedResourceError: if you already closed this
`MemorySendChannel` object.
"""
if self._closed:
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()
@enable_ki_protection
def close(self) -> None:
"""Close this send channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Memory channels can also be closed synchronously. This has the same
effect on the channel and other tasks using it, but `close` is not a
trio checkpoint. This simplifies cleaning up in cancelled tasks.
Using ``with send_channel:`` will close the channel object on leaving
the with block.
"""
if self._closed:
return
self._closed = True
for task in self._tasks:
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
del self._state.send_tasks[task]
self._tasks.clear()
self._state.open_send_channels -= 1
if self._state.open_send_channels == 0:
assert not self._state.send_tasks
for task in self._state.receive_tasks:
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.EndOfChannel()))
self._state.receive_tasks.clear()
@enable_ki_protection
async def aclose(self) -> None:
"""Close this send channel object asynchronously.
See `MemorySendChannel.close`."""
self.close()
await trio.lowlevel.checkpoint()
@final
@attrs.define(eq=False, repr=False, slots=False)
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[ReceiveType]
_closed: bool = False
_tasks: set[trio._core._run.Task] = attrs.Factory(set)
def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1
def statistics(self) -> MemoryChannelStatistics:
"""Returns a `MemoryChannelStatistics` for the memory channel this is
associated with."""
return self._state.statistics()
def __repr__(self) -> str:
return (
f"<receive channel at {id(self):#x}, using buffer at {id(self._state):#x}>"
)
@enable_ki_protection
def receive_nowait(self) -> ReceiveType:
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.
"""
if self._closed:
raise trio.ClosedResourceError
if self._state.send_tasks:
task, value = self._state.send_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task)
self._state.data.append(value)
# Fall through
if self._state.data:
return self._state.data.popleft()
if not self._state.open_send_channels:
raise trio.EndOfChannel
raise trio.WouldBlock
@enable_ki_protection
async def receive(self) -> ReceiveType:
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
Memory channels allow multiple tasks to call `receive` at the same
time. The first task will get the first item sent, the second task
will get the second item sent, and so on.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
value = self.receive_nowait()
except trio.WouldBlock:
pass
else:
await trio.lowlevel.cancel_shielded_checkpoint()
return value
task = trio.lowlevel.current_task()
self._tasks.add(task)
self._state.receive_tasks[task] = None
task.custom_sleep_data = self
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED
# Not strictly guaranteed to return ReceiveType, but will do so unless
# you intentionally reschedule with a bad value.
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return]
@enable_ki_protection
def clone(self) -> MemoryReceiveChannel[ReceiveType]:
"""Clone this receive channel object.
This returns a new `MemoryReceiveChannel` object, which acts as a
duplicate of the original: receiving on the new object does exactly
the same thing as receiving on the old object.
However, closing one of the objects does not close the other, and the
underlying channel is not closed until all clones are closed. (If
you're familiar with `os.dup`, then this is a similar idea.)
This is useful for communication patterns that involve multiple
consumers all receiving objects from the same underlying channel. See
:ref:`channel-mpmc` for examples.
.. warning:: The clones all share the same underlying channel.
Whenever a clone :meth:`receive`\\s a value, it is removed from the
channel and the other clones do *not* receive that value. If you
want to send multiple copies of the same stream of values to
multiple destinations, like :func:`itertools.tee`, then you need to
find some other solution; this method does *not* do that.
Raises:
trio.ClosedResourceError: if you already closed this
`MemoryReceiveChannel` object.
"""
if self._closed:
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()
@enable_ki_protection
def close(self) -> None:
"""Close this receive channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Memory channels can also be closed synchronously. This has the same
effect on the channel and other tasks using it, but `close` is not a
trio checkpoint. This simplifies cleaning up in cancelled tasks.
Using ``with receive_channel:`` will close the channel object on
leaving the with block.
"""
if self._closed:
return
self._closed = True
for task in self._tasks:
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
del self._state.receive_tasks[task]
self._tasks.clear()
self._state.open_receive_channels -= 1
if self._state.open_receive_channels == 0:
assert not self._state.receive_tasks
for task in self._state.send_tasks:
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError()))
self._state.send_tasks.clear()
self._state.data.clear()
@enable_ki_protection
async def aclose(self) -> None:
"""Close this receive channel object asynchronously.
See `MemoryReceiveChannel.close`."""
self.close()
await trio.lowlevel.checkpoint()
@@ -0,0 +1,87 @@
"""
This namespace represents the core functionality that has to be built-in
and deal with private internal data structures. Things in this namespace
are publicly available in either trio, trio.lowlevel, or trio.testing.
"""
import sys
from ._entry_queue import TrioToken
from ._exceptions import (
BrokenResourceError,
BusyResourceError,
Cancelled,
ClosedResourceError,
EndOfChannel,
RunFinishedError,
TrioInternalError,
WouldBlock,
)
from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection
from ._local import RunVar, RunVarToken
from ._mock_clock import MockClock
from ._parking_lot import (
ParkingLot,
ParkingLotStatistics,
add_parking_lot_breaker,
remove_parking_lot_breaker,
)
# Imports that always exist
from ._run import (
TASK_STATUS_IGNORED,
CancelScope,
Nursery,
RunStatistics,
Task,
TaskStatus,
add_instrument,
checkpoint,
checkpoint_if_cancelled,
current_clock,
current_effective_deadline,
current_root_task,
current_statistics,
current_task,
current_time,
current_trio_token,
notify_closing,
open_nursery,
remove_instrument,
reschedule,
run,
spawn_system_task,
start_guest_run,
wait_all_tasks_blocked,
wait_readable,
wait_writable,
)
from ._thread_cache import start_thread_soon
# Has to come after _run to resolve a circular import
from ._traps import (
Abort,
RaiseCancelT,
cancel_shielded_checkpoint,
permanently_detach_coroutine_object,
reattach_detached_coroutine_object,
temporarily_detach_coroutine_object,
wait_task_rescheduled,
)
from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics
# Windows imports
if sys.platform == "win32":
from ._run import (
current_iocp,
monitor_completion_key,
readinto_overlapped,
register_with_iocp,
wait_overlapped,
write_overlapped,
)
# Kqueue imports
elif sys.platform != "linux" and sys.platform != "win32":
from ._run import current_kqueue, monitor_kevent, wait_kevent
del sys # It would be better to import sys as _sys, but mypy does not understand it
@@ -0,0 +1,216 @@
from __future__ import annotations
import logging
import sys
import warnings
import weakref
from typing import TYPE_CHECKING, NoReturn
import attrs
from .. import _core
from .._util import name_asyncgen
from . import _run
# Used to log exceptions in async generator finalizers
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
if TYPE_CHECKING:
from types import AsyncGeneratorType
from typing import Set
_WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
_ASYNC_GEN_SET = Set[AsyncGeneratorType[object, NoReturn]]
else:
_WEAK_ASYNC_GEN_SET = weakref.WeakSet
_ASYNC_GEN_SET = set
@attrs.define(eq=False)
class AsyncGenerators:
# Async generators are added to this set when first iterated. Any
# left after the main task exits will be closed before trio.run()
# returns. During most of the run, this is a WeakSet so GC works.
# During shutdown, when we're finalizing all the remaining
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET)
# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
trailing_needs_finalize: _ASYNC_GEN_SET = attrs.Factory(_ASYNC_GEN_SET)
prev_hooks: sys._asyncgen_hooks = attrs.field(init=False)
def install_hooks(self, runner: _run.Runner) -> None:
def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
self.alive.add(agen)
else:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)
def finalize_in_trio_context(
agen: AsyncGeneratorType[object, NoReturn],
agen_name: str,
) -> None:
try:
runner.spawn_system_task(
self._finalize_one,
agen,
agen_name,
name=f"close asyncgen {agen_name} (abandoned)",
)
except RuntimeError:
# There is a one-tick window where the system nursery
# is closed but the init task hasn't yet made
# self.asyncgens a strong set to disable GC. We seem to
# have hit it.
self.trailing_needs_finalize.add(agen)
def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
is_ours = True
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context,
agen,
agen_name,
)
# Do this last, because it might raise an exception
# depending on the user's warnings filter. (That
# exception will be printed to the terminal and
# ignored, since we're running in GC context.)
warnings.warn(
f"Async generator {agen_name!r} was garbage collected before it "
"had been exhausted. Surround its use in 'async with "
"aclosing(...):' to ensure that it gets cleaned up as soon as "
"you're done using it.",
ResourceWarning,
stacklevel=2,
source=agen,
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
# GeneratorExit, step once, raise RuntimeError if
# it doesn't exit.
closer = agen.aclose()
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
except StopIteration:
pass
else:
# If the next thing is an await, we get here. Give a nicer
# error than the default "async generator ignored GeneratorExit"
raise RuntimeError(
f"Non-Trio async generator {agen_name!r} awaited something "
"during finalization; install a finalization hook to "
"support this, or wrap it in 'async with aclosing(...):'",
)
self.prev_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) # type: ignore[arg-type] # Finalizer doesn't use AsyncGeneratorType
async def finalize_remaining(self, runner: _run.Runner) -> None:
# This is called from init after shutting down the system nursery.
# The only tasks running at this point are init and
# the run_sync_soon task, and since the system nursery is closed,
# there's no way for user code to spawn more.
assert _core.current_task() is runner.init_task
assert len(runner.tasks) == 2
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
self.alive = set(self.alive)
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task)
await _core.wait_task_rescheduled(
lambda _: _core.Abort.FAILED, # pragma: no cover
)
self.alive.update(self.trailing_needs_finalize)
self.trailing_needs_finalize.clear()
# None of the still-living tasks use async generators, so
# every async generator must be suspended at a yield point --
# there's no one to be doing the iteration. That's good,
# because aclose() only works on an asyncgen that's suspended
# at a yield point. (If it's suspended at an event loop trap,
# because someone is in the middle of iterating it, then you
# get a RuntimeError on 3.8+, and a nasty surprise on earlier
# versions due to https://bugs.python.org/issue32526.)
#
# However, once we start aclose() of one async generator, it
# might start fetching the next value from another, thus
# preventing us from closing that other (at least until
# aclose() of the first one is complete). This constraint
# effectively requires us to finalize the remaining asyncgens
# in arbitrary order, rather than doing all of them at the
# same time. On 3.8+ we could defer any generator with
# ag_running=True to a later batch, but that only catches
# the case where our aclose() starts after the user's
# asend()/etc. If our aclose() starts first, then the
# user's asend()/etc will raise RuntimeError, since they're
# probably not checking ag_running.
#
# It might be possible to allow some parallelized cleanup if
# we can determine that a certain set of asyncgens have no
# interdependencies, using gc.get_referents() and such.
# But just doing one at a time will typically work well enough
# (since each aclose() executes in a cancelled scope) and
# is much easier to reason about.
# It's possible that that cleanup code will itself create
# more async generators, so we iterate repeatedly until
# all are gone.
while self.alive:
batch = self.alive
self.alive = _ASYNC_GEN_SET()
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
def close(self) -> None:
sys.set_asyncgen_hooks(*self.prev_hooks)
async def _finalize_one(
self,
agen: AsyncGeneratorType[object, NoReturn],
name: object,
) -> None:
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
# is cancelled so there's no deadlock risk.
with _core.CancelScope(shield=True) as cancel_scope:
cancel_scope.cancel()
await agen.aclose()
except BaseException:
ASYNCGEN_LOGGER.exception(
"Exception ignored during finalization of async generator %r -- "
"surround your use of the generator in 'async with aclosing(...):' "
"to raise exceptions like this in the context where they're generated",
name,
)
@@ -0,0 +1,130 @@
from __future__ import annotations
from types import TracebackType
from typing import Any, ClassVar, cast
################################################################
# concat_tb
################################################################
# We need to compute a new traceback that is the concatenation of two existing
# tracebacks. This requires copying the entries in 'head' and then pointing
# the final tb_next to 'tail'.
#
# NB: 'tail' might be None, which requires some special handling in the ctypes
# version.
#
# The complication here is that Python doesn't actually support copying or
# modifying traceback objects, so we have to get creative...
#
# On CPython, we use ctypes. On PyPy, we use "transparent proxies".
#
# Jinja2 is a useful source of inspiration:
# https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py
try:
import tputil
except ImportError:
# ctypes it is
# How to handle refcounting? I don't want to use ctypes.py_object because
# I don't understand or trust it, and I don't want to use
# ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code
# that also tries to use them but with different types. So private _ctypes
# APIs it is!
import _ctypes
import ctypes
class CTraceback(ctypes.Structure):
_fields_: ClassVar = [
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
("tb_next", ctypes.c_void_p),
("tb_frame", ctypes.c_void_p),
("tb_lasti", ctypes.c_int),
("tb_lineno", ctypes.c_int),
]
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
except ValueError as exc:
new_tb = exc.__traceback__
assert new_tb is not None
c_new_tb = CTraceback.from_address(id(new_tb))
# At the C level, tb_next either points to the next traceback or is
# NULL. c_void_p and the .tb_next accessor both convert NULL to None,
# but we shouldn't DECREF None just because we assigned to a NULL
# pointer! Here we know that our new traceback has only 1 frame in it,
# so we can assume the tb_next field is NULL.
assert c_new_tb.tb_next is None
# If tb_next is None, then we want to set c_new_tb.tb_next to NULL,
# which it already is, so we're done. Otherwise, we have to actually
# do some work:
if tb_next is not None:
_ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined]
c_new_tb.tb_next = id(tb_next)
assert c_new_tb.tb_frame is not None
_ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined]
old_tb_frame = new_tb.tb_frame
c_new_tb.tb_frame = id(base_tb.tb_frame)
_ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined]
c_new_tb.tb_lasti = base_tb.tb_lasti
c_new_tb.tb_lineno = base_tb.tb_lineno
try:
return new_tb
finally:
# delete references from locals to avoid creating cycles
# see test_cancel_scope_exit_doesnt_create_cyclic_garbage
del new_tb, old_tb_frame
else:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# tputil.ProxyOperation is PyPy-only, and there's no way to specify
# cpython/pypy in current type checkers.
def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported]
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if (
operation.opname
in {
"__getattribute__",
"__getattr__",
}
and operation.args[0] == "tb_next"
): # pragma: no cover
return tb_next
return operation.delegate() # Delegate is reverting to original behaviour
return cast(
TracebackType,
tputil.make_proxy(controller, type(base_tb), base_tb),
) # Returns proxy to traceback
# this is used for collapsing single-exception ExceptionGroups when using
# `strict_exception_groups=False`. Once that is retired this function and its helper can
# be removed as well.
def concat_tb(
head: TracebackType | None,
tail: TracebackType | None,
) -> TracebackType | None:
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
head_tbs = []
pointer = head
while pointer is not None:
head_tbs.append(pointer)
pointer = pointer.tb_next
current_head = tail
for head_tb in reversed(head_tbs):
current_head = copy_tb(head_tb, tb_next=current_head)
return current_head
@@ -0,0 +1,220 @@
from __future__ import annotations
import threading
from collections import deque
from typing import TYPE_CHECKING, Callable, NoReturn, Tuple
import attrs
from .. import _core
from .._util import NoPublicConstructor, final
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeVarTuple, Unpack
PosArgsT = TypeVarTuple("PosArgsT")
Function = Callable[..., object]
Job = Tuple[Function, Tuple[object, ...]]
@attrs.define
class EntryQueue:
# This used to use a queue.Queue. but that was broken, because Queues are
# implemented in Python, and not reentrant -- so it was thread-safe, but
# not signal-safe. deque is implemented in C, so each operation is atomic
# WRT threads (and this is guaranteed in the docs), AND each operation is
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
queue: deque[Job] = attrs.Factory(deque)
idempotent_queue: dict[Job, None] = attrs.Factory(dict)
wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
done: bool = False
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
# way we use the lock this is OK though, because when
# run_sync_soon is called from a signal it's atomic WRT the
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
lock: threading.RLock = attrs.Factory(threading.RLock)
async def task(self) -> None:
assert _core.currently_ki_protected()
# RLock has two implementations: a signal-safe version in _thread, and
# and signal-UNsafe version in threading. We need the signal safe
# version. Python 3.2 and later should always use this anyway, but,
# since the symptoms if this goes wrong are just "weird rare
# deadlocks", then let's make a little check.
# See:
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
def run_cb(job: Job) -> None:
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
# TrioInternalError and cause everything to shut down).
sync_fn, args = job
try:
sync_fn(*args)
except BaseException as exc:
async def kill_everything(exc: BaseException) -> NoReturn:
raise exc
try:
_core.spawn_system_task(kill_everything, exc)
except RuntimeError:
# We're quite late in the shutdown process and the
# system nursery is already closed.
# TODO(2020-06): this is a gross hack and should
# be fixed soon when we address #1607.
parent_nursery = _core.current_task().parent_nursery
if parent_nursery is None:
raise AssertionError(
"Internal error: `parent_nursery` should never be `None`",
) from exc # pragma: no cover
parent_nursery.start_soon(kill_everything, exc)
# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
# each pass:
def run_all_bounded() -> None:
for _ in range(len(self.queue)):
run_cb(self.queue.popleft())
for job in list(self.idempotent_queue):
del self.idempotent_queue[job]
run_cb(job)
try:
while True:
run_all_bounded()
if not self.queue and not self.idempotent_queue:
await self.wakeup.wait_woken()
else:
await _core.checkpoint()
except _core.Cancelled:
# Keep the work done with this lock held as minimal as possible,
# because it doesn't protect us against concurrent signal delivery
# (see the comment above). Notice that this code would still be
# correct if written like:
# self.done = True
# with self.lock:
# pass
# because all we want is to force run_sync_soon
# to either be completely before or completely after the write to
# done. That's why we don't need the lock to protect
# against signal handlers.
with self.lock:
self.done = True
# No more jobs will be submitted, so just clear out any residual
# ones:
run_all_bounded()
assert not self.queue
assert not self.idempotent_queue
def close(self) -> None:
self.wakeup.close()
def size(self) -> int:
return len(self.queue) + len(self.idempotent_queue)
def run_sync_soon(
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
with self.lock:
if self.done:
raise _core.RunFinishedError("run() has exited")
# We have to hold the lock all the way through here, because
# otherwise the main thread might exit *while* we're doing these
# calls, and then our queue item might not be processed, or the
# wakeup call might trigger an OSError b/c the IO manager has
# already been shut down.
if idempotent:
self.idempotent_queue[(sync_fn, args)] = None
else:
self.queue.append((sync_fn, args))
self.wakeup.wakeup_thread_and_signal_safe()
@final
@attrs.define(eq=False)
class TrioToken(metaclass=NoPublicConstructor):
"""An opaque object representing a single call to :func:`trio.run`.
It has no public constructor; instead, see :func:`current_trio_token`.
This object has two uses:
1. It lets you re-enter the Trio run loop from external threads or signal
handlers. This is the low-level primitive that :func:`trio.to_thread`
and `trio.from_thread` use to communicate with worker threads, that
`trio.open_signal_receiver` uses to receive notifications about
signals, and so forth.
2. Each call to :func:`trio.run` has exactly one associated
:class:`TrioToken` object, so you can use it to identify a particular
call.
"""
_reentry_queue: EntryQueue
def run_sync_soon(
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
This is safe to call from the main thread, from other threads, and
from signal handlers. This is the fundamental primitive used to
re-enter the Trio run loop from outside of it.
The call will happen "soon", but there's no guarantee about exactly
when, and no mechanism provided for finding out when it's happened.
If you need this, you'll have to build your own.
The call is effectively run as part of a system task (see
:func:`~trio.lowlevel.spawn_system_task`). In particular this means
that:
* :exc:`KeyboardInterrupt` protection is *enabled* by default; if
you want ``sync_fn`` to be interruptible by control-C, then you
need to use :func:`~trio.lowlevel.disable_ki_protection`
explicitly.
* If ``sync_fn`` raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You
should be careful that ``sync_fn`` doesn't crash.
All calls with ``idempotent=False`` are processed in strict
first-in first-out order.
If ``idempotent=True``, then ``sync_fn`` and ``args`` must be
hashable, and Trio will make a best-effort attempt to discard any
call submission which is equal to an already-pending call. Trio
will process these in first-in first-out order.
Any ordering guarantees apply separately to ``idempotent=False``
and ``idempotent=True`` calls; there's no rule for how calls in the
different categories are ordered with respect to each other.
:raises trio.RunFinishedError:
if the associated call to :func:`trio.run`
has already exited. (Any call that *doesn't* raise this error
is guaranteed to be fully processed before :func:`trio.run`
exits.)
"""
self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent)
@@ -0,0 +1,113 @@
from trio._util import NoPublicConstructor, final
class TrioInternalError(Exception):
"""Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a
misuse of one of the low-level :mod:`trio.lowlevel` APIs.
This should never happen! If you get this error, please file a bug.
Unfortunately, if you get this error it also means that all bets are off
Trio doesn't know what is going on and its normal invariants may be void.
(For example, we might have "lost track" of a task. Or lost track of all
tasks.) Again, though, this shouldn't happen.
"""
class RunFinishedError(RuntimeError):
"""Raised by `trio.from_thread.run` and similar functions if the
corresponding call to :func:`trio.run` has already finished.
"""
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X`` would block."""
@final
class Cancelled(BaseException, metaclass=NoPublicConstructor):
"""Raised by blocking calls if the surrounding scope has been cancelled.
You should let this exception propagate, to be caught by the relevant
cancel scope. To remind you of this, it inherits from :exc:`BaseException`
instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and
:exc:`SystemExit` do. This means that if you write something like::
try:
...
except Exception:
...
then this *won't* catch a :exc:`Cancelled` exception.
You cannot raise :exc:`Cancelled` yourself. Attempting to do so
will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel()
<trio.CancelScope.cancel>` instead.
.. note::
In the US it's also common to see this word spelled "canceled", with
only one "l". This is a `recent
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
and `US-specific
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
innovation, and even in the US both forms are still commonly used. So
for consistency with the rest of the world and with "cancellation"
(which always has two "l"s), Trio uses the two "l" spelling
everywhere.
"""
def __str__(self) -> str:
return "Cancelled"
class BusyResourceError(Exception):
"""Raised when a task attempts to use a resource that some other task is
already using, and this would lead to bugs and nonsense.
For example, if two tasks try to send data through the same socket at the
same time, Trio will raise :class:`BusyResourceError` instead of letting
the data get scrambled.
"""
class ClosedResourceError(Exception):
"""Raised when attempting to use a resource after it has been closed.
Note that "closed" here means that *your* code closed the resource,
generally by calling a method with a name like ``close`` or ``aclose``, or
by exiting a context manager. If a problem arises elsewhere for example,
because of a network failure, or because a remote peer closed their end of
a connection then that should be indicated by a different exception
class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass.
"""
class BrokenResourceError(Exception):
"""Raised when an attempt to use a resource fails due to external
circumstances.
For example, you might get this if you try to send data on a stream where
the remote side has already closed the connection.
You *don't* get this error if *you* closed the resource in that case you
get :class:`ClosedResourceError`.
This exception's ``__cause__`` attribute will often contain more
information about the underlying error.
"""
class EndOfChannel(Exception):
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
that has no more data to receive.
This is analogous to an "end-of-file" condition, but for channels.
"""
@@ -0,0 +1,51 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from ._instrumentation import Instrument
__all__ = ["add_instrument", "remove_instrument"]
def add_instrument(instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def remove_instrument(instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context") from None
@@ -0,0 +1,98 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from .._file_io import _HasFileNo
assert not TYPE_CHECKING or sys.platform == "linux"
__all__ = ["notify_closing", "wait_readable", "wait_writable"]
async def wait_readable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
@@ -0,0 +1,156 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Callable, ContextManager
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
import select
from .. import _core
from .._file_io import _HasFileNo
from ._traps import Abort, RaiseCancelT
assert not TYPE_CHECKING or sys.platform == "darwin"
__all__ = [
"current_kqueue",
"monitor_kevent",
"notify_closing",
"wait_kevent",
"wait_readable",
"wait_writable",
]
def current_kqueue() -> select.kqueue:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def monitor_kevent(
ident: int,
filter: int,
) -> ContextManager[_core.UnboundedQueue[select.kevent]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_kevent(
ident: int,
filter: int,
abort_func: Callable[[RaiseCancelT], Abort],
) -> Abort:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(
ident,
filter,
abort_func,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_readable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
@@ -0,0 +1,209 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, ContextManager
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from typing_extensions import Buffer
from .._file_io import _HasFileNo
from ._unbounded_queue import UnboundedQueue
from ._windows_cffi import CData, Handle
assert not TYPE_CHECKING or sys.platform == "win32"
__all__ = [
"current_iocp",
"monitor_completion_key",
"notify_closing",
"readinto_overlapped",
"register_with_iocp",
"wait_overlapped",
"wait_readable",
"wait_writable",
"write_overlapped",
]
async def wait_readable(sock: _HasFileNo | int) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``sock`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``sock`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(sock: _HasFileNo | int) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``sock``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(handle: Handle | int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def register_with_iocp(handle: int | CData) -> None:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(
handle_,
lpOverlapped,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def write_overlapped(
handle: int | CData,
data: Buffer,
file_offset: int = 0,
) -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(
handle,
data,
file_offset,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def readinto_overlapped(
handle: int | CData,
buffer: Buffer,
file_offset: int = 0,
) -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(
handle,
buffer,
file_offset,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_iocp() -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
except AttributeError:
raise RuntimeError("must be called from async context") from None
@@ -0,0 +1,273 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task
if TYPE_CHECKING:
import contextvars
from collections.abc import Awaitable, Callable
from outcome import Outcome
from typing_extensions import Unpack
from .._abc import Clock
from ._entry_queue import TrioToken
from ._run import PosArgT
__all__ = [
"current_clock",
"current_root_task",
"current_statistics",
"current_time",
"current_trio_token",
"reschedule",
"spawn_system_task",
"wait_all_tasks_blocked",
]
def current_statistics() -> RunStatistics:
"""Returns ``RunStatistics``, which contains run-loop-level debugging information.
Currently, the following fields are defined:
* ``tasks_living`` (int): The number of tasks that have been spawned
and not yet exited.
* ``tasks_runnable`` (int): The number of tasks that are currently
queued on the run queue (as opposed to blocked waiting for something
to happen).
* ``seconds_to_next_deadline`` (float): The time until the next
pending cancel scope deadline. May be negative if the deadline has
expired but we haven't yet processed cancellations. May be
:data:`~math.inf` if there are no pending deadlines.
* ``run_sync_soon_queue_size`` (int): The number of
unprocessed callbacks queued via
:meth:`trio.lowlevel.TrioToken.run_sync_soon`.
* ``io_statistics`` (object): Some statistics from Trio's I/O
backend. This always has an attribute ``backend`` which is a string
naming which operating-system-specific I/O backend is in use; the
other attributes vary between backends.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_statistics()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_time() -> float:
"""Returns the current time according to Trio's internal clock.
Returns:
float: The current time.
Raises:
RuntimeError: if not inside a call to :func:`trio.run`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_time()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_clock() -> Clock:
"""Returns the current :class:`~trio.abc.Clock`."""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_clock()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_root_task() -> Task | None:
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_root_task()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
See :func:`wait_task_rescheduled` for the gory details.
There must be exactly one call to :func:`reschedule` for every call to
:func:`wait_task_rescheduled`. (And when counting, keep in mind that
returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
to calling :func:`reschedule` once.)
Args:
task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
in a call to :func:`wait_task_rescheduled`.
next_send (outcome.Outcome): the value (or error) to return (or
raise) from :func:`wait_task_rescheduled`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def spawn_system_task(
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
*args: Unpack[PosArgT],
name: object = None,
context: contextvars.Context | None = None,
) -> Task:
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
* They don't need an explicit nursery; instead they go into the
internal "system nursery".
* If a system task raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
write a system task, you should be careful to make sure it doesn't
crash.
* System tasks are automatically cancelled when the main task exits.
* By default, system tasks have :exc:`KeyboardInterrupt` protection
*enabled*. If you want your task to be interruptible by control-C,
then you need to use :func:`disable_ki_protection` explicitly (and
come up with some plan for what to do with a
:exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
raise exceptions).
* System tasks do not inherit context variables from their creator.
Towards the end of a call to :meth:`trio.run`, after the main
task and all system tasks have exited, the system nursery
becomes closed. At this point, new calls to
:func:`spawn_system_task` will raise ``RuntimeError("Nursery
is closed to new arrivals")`` instead of creating a system
task. It's possible to encounter this state either in
a ``finally`` block in an async generator, or in a callback
passed to :meth:`TrioToken.run_sync_soon` at the right moment.
Args:
async_fn: An async callable.
args: Positional arguments for ``async_fn``. If you want to pass
keyword arguments, use :func:`functools.partial`.
name: The name for this task. Only used for debugging/introspection
(e.g. ``repr(task_obj)``). If this isn't a string,
:func:`spawn_system_task` will try to make it one. A common use
case is if you're wrapping a function before spawning a new
task, you might pass the original function as the ``name=`` to
make debugging easier.
context: An optional ``contextvars.Context`` object with context variables
to use for this task. You would normally get a copy of the current
context with ``context = contextvars.copy_context()`` and then you would
pass that ``context`` object here.
Returns:
Task: the newly spawned task
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(
async_fn,
*args,
name=name,
context=context,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_trio_token() -> TrioToken:
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_all_tasks_blocked(cushion: float = 0.0) -> None:
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
chance to "settle down". The calling task is blocked, and doesn't wake
up until all other tasks are also blocked for at least ``cushion``
seconds. (Setting a non-zero ``cushion`` is intended to handle cases
like two tasks talking to each other over a local socket, where we
want to ignore the potential brief moment between a send and receive
when all tasks are blocked.)
Note that ``cushion`` is measured in *real* time, not the Trio clock
time.
If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
then the one with the shortest ``cushion`` is the one woken (and
this task becoming unblocked resets the timers for the remaining
tasks). If there are multiple tasks that have exactly the same
``cushion``, then all are woken.
You should also consider :class:`trio.testing.Sequencer`, which
provides a more explicit way to control execution ordering within a
test, and will often produce more readable tests.
Example:
Here's an example of one way to test that Trio's locks are fair: we
take the lock in the parent, start a child, wait for the child to be
blocked waiting for the lock (!), and then check that we can't
release and immediately re-acquire the lock::
async def lock_taker(lock):
await lock.acquire()
lock.release()
async def test_lock_fairness():
lock = trio.Lock()
await lock.acquire()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock_taker, lock)
# child hasn't run yet, we have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
await trio.testing.wait_all_tasks_blocked()
# now the child has run and is blocked on lock.acquire(), we
# still have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
lock.release()
try:
# The child has a prior claim, so we can't have it
lock.acquire_nowait()
except trio.WouldBlock:
assert lock._owner is not trio.lowlevel.current_task()
print("PASS")
else:
print("FAIL")
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
except AttributeError:
raise RuntimeError("must be called from async context") from None
@@ -0,0 +1,108 @@
import logging
import types
from typing import Any, Callable, Dict, Sequence, TypeVar
from .._abc import Instrument
# Used to log exceptions in instruments
INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument")
F = TypeVar("F", bound=Callable[..., Any])
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn: F) -> F:
return fn
class Instruments(Dict[str, Dict[Instrument, None]]):
"""A collection of `trio.abc.Instrument` organized by hook.
Instrumentation calls are rather expensive, and we don't want a
rarely-used instrument (like before_run()) to slow down hot
operations (like before_task_step()). Thus, we cache the set of
instruments to be called for each hook, and skip the instrumentation
call if there's nothing currently installed for that hook.
"""
__slots__ = ()
def __init__(self, incoming: Sequence[Instrument]):
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
@_public
def add_instrument(self, instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
if instrument in self["_all"]:
return
self["_all"][instrument] = None
try:
for name in dir(instrument):
if name.startswith("_"):
continue
try:
prototype = getattr(Instrument, name)
except AttributeError:
continue
impl = getattr(instrument, name)
if isinstance(impl, types.MethodType) and impl.__func__ is prototype:
# Inherited unchanged from _abc.Instrument
continue
self.setdefault(name, {})[instrument] = None
except:
self.remove_instrument(instrument)
raise
@_public
def remove_instrument(self, instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
# If instrument isn't present, the KeyError propagates out
self["_all"].pop(instrument)
for hookname, instruments in list(self.items()):
if instrument in instruments:
del instruments[instrument]
if not instruments:
del self[hookname]
def call(self, hookname: str, *args: Any) -> None:
"""Call hookname(*args) on each applicable instrument.
You must first check whether there are any instruments installed for
that hook, e.g.::
if "before_task_step" in instruments:
instruments.call("before_task_step", task)
"""
for instrument in list(self[hookname]):
try:
getattr(instrument, hookname)(*args)
except BaseException:
self.remove_instrument(instrument)
INSTRUMENT_LOGGER.exception(
"Exception raised when calling %r on instrument %r. "
"Instrument has been disabled.",
hookname,
instrument,
)
@@ -0,0 +1,31 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING
import outcome
from .. import _core
if TYPE_CHECKING:
from ._io_epoll import EpollWaiters
from ._io_windows import AFDWaiters
# Utility function shared between _io_epoll and _io_windows
def wake_all(waiters: EpollWaiters | AFDWaiters, exc: BaseException) -> None:
try:
current_task = _core.current_task()
except RuntimeError:
current_task = None
raise_at_end = False
for attr_name in ["read_task", "write_task"]:
task = getattr(waiters, attr_name)
if task is not None:
if task is current_task:
raise_at_end = True
else:
_core.reschedule(task, outcome.Error(copy.copy(exc)))
setattr(waiters, attr_name, None)
if raise_at_end:
raise exc
@@ -0,0 +1,387 @@
from __future__ import annotations
import contextlib
import select
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Literal
import attrs
from .. import _core
from ._io_common import wake_all
from ._run import Task, _public
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT
from .._file_io import _HasFileNo
@attrs.define(eq=False)
class EpollWaiters:
read_task: Task | None = None
write_task: Task | None = None
current_flags: int = 0
assert not TYPE_CHECKING or sys.platform == "linux"
EventResult: TypeAlias = "list[tuple[int, int]]"
@attrs.frozen(eq=False)
class _EpollStatistics:
tasks_waiting_read: int
tasks_waiting_write: int
backend: Literal["epoll"] = attrs.field(init=False, default="epoll")
# Some facts about epoll
# ----------------------
#
# Internally, an epoll object is sort of like a WeakKeyDictionary where the
# keys are tuples of (fd number, file object). When you call epoll_ctl, you
# pass in an fd; that gets converted to an (fd number, file object) tuple by
# looking up the fd in the process's fd table at the time of the call. When an
# event happens on the file object, epoll_wait drops the file object part, and
# just returns the fd number in its event. So from the outside it looks like
# it's keeping a table of fds, but really it's a bit more complicated. This
# has some subtle consequences.
#
# In general, file objects inside the kernel are reference counted. Each entry
# in a process's fd table holds a strong reference to the corresponding file
# object, and most operations that use file objects take a temporary strong
# reference while they're working. So when you call close() on an fd, that
# might or might not cause the file object to be deallocated -- it depends on
# whether there are any other references to that file object. Some common ways
# this can happen:
#
# - after calling dup(), you have two fds in the same process referring to the
# same file object. Even if you close one fd (= remove that entry from the
# fd table), the file object will be kept alive by the other fd.
# - when calling fork(), the child inherits a copy of the parent's fd table,
# so all the file objects get another reference. (But if the fork() is
# followed by exec(), then all of the child's fds that have the CLOEXEC flag
# set will be closed at that point.)
# - most syscalls that work on fds take a strong reference to the underlying
# file object while they're using it. So there's one thread blocked in
# read(fd), and then another thread calls close() on the last fd referring
# to that object, the underlying file won't actually be closed until
# after read() returns.
#
# However, epoll does *not* take a reference to any of the file objects in its
# interest set (that's what makes it similar to a WeakKeyDictionary). File
# objects inside an epoll interest set will be deallocated if all *other*
# references to them are closed. And when that happens, the epoll object will
# automatically deregister that file object and stop reporting events on it.
# So that's quite handy.
#
# But, what happens if we do this?
#
# fd1 = open(...)
# epoll_ctl(EPOLL_CTL_ADD, fd1, ...)
# fd2 = dup(fd1)
# close(fd1)
#
# In this case, the dup() keeps the underlying file object alive, so it
# remains registered in the epoll object's interest set, as the tuple (fd1,
# file object). But, fd1 no longer refers to this file object! You might think
# there was some magic to handle this, but unfortunately no; the consequences
# are totally predictable from what I said above:
#
# If any events occur on the file object, then epoll will report them as
# happening on fd1, even though that doesn't make sense.
#
# Perhaps we would like to deregister fd1 to stop getting nonsensical events.
# But how? When we call epoll_ctl, we have to pass an fd number, which will
# get expanded to an (fd number, file object) tuple. We can't pass fd1,
# because when epoll_ctl tries to look it up, it won't find our file object.
# And we can't pass fd2, because that will get expanded to (fd2, file object),
# which is a different lookup key. In fact, it's *impossible* to de-register
# this fd!
#
# We could even have fd1 get assigned to another file object, and then we can
# have multiple keys registered simultaneously using the same fd number, like:
# (fd1, file object 1), (fd1, file object 2). And if events happen on either
# file object, then epoll will happily report that something happened to
# "fd1".
#
# Now here's what makes this especially nasty: suppose the old file object
# becomes, say, readable. That means that every time we call epoll_wait, it
# will return immediately to tell us that "fd1" is readable. Normally, we
# would handle this by de-registering fd1, waking up the corresponding call to
# wait_readable, then the user will call read() or recv() or something, and
# we're fine. But if this happens on a stale fd where we can't remove the
# registration, then we might get stuck in a state where epoll_wait *always*
# returns immediately, so our event loop becomes unable to sleep, and now our
# program is burning 100% of the CPU doing nothing, with no way out.
#
#
# What does this mean for Trio?
# -----------------------------
#
# Since we don't control the user's code, we have no way to guarantee that we
# don't get stuck with stale fd's in our epoll interest set. For example, a
# user could call wait_readable(fd) in one task, and then while that's
# running, they might close(fd) from another task. In this situation, they're
# *supposed* to call notify_closing(fd) to let us know what's happening, so we
# can interrupt the wait_readable() call and avoid getting into this mess. And
# that's the only thing that can possibly work correctly in all cases. But
# sometimes user code has bugs. So if this does happen, we'd like to degrade
# gracefully, and survive without corrupting Trio's internal state or
# otherwise causing the whole program to explode messily.
#
# Our solution: we always use EPOLLONESHOT. This way, we might get *one*
# spurious event on a stale fd, but then epoll will automatically silence it
# until we explicitly say that we want more events... and if we have a stale
# fd, then we actually can't re-enable it! So we can't get stuck in an
# infinite busy-loop. If there's a stale fd hanging around, then it might
# cause a spurious `BusyResourceError`, or cause one wait_* call to return
# before it should have... but in general, the wait_* functions are allowed to
# have some spurious wakeups; the user code will just attempt the operation,
# get EWOULDBLOCK, and call wait_* again. And the program as a whole will
# survive, any exceptions will propagate, etc.
#
# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds
# on the normal wakeup path, so it's a bit more efficient in general.
#
# However, EPOLLONESHOT has a few trade-offs to consider:
#
# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad
# in one somewhat rare case: if you have a multi-process server where a group
# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be
# used to avoid "thundering herd" problems when a new connection comes in. But
# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us
# anyway:
#
# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering
#
# And it's not clear that EPOLLEXCLUSIVE is a great approach either:
#
# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/
#
# And if we do need to support this, we could always add support through some
# more-specialized API in the future. So this isn't a blocker to using
# EPOLLONESHOT.
#
# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering
# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but
# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to
# all-zeros. So we could still end up with an fd hanging around in the
# interest set for a long time, even if we're not using it.
#
# Fortunately, this isn't a problem, because it's only a weak reference if
# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a
# tiny bit of kernel memory remembering this fd that can never be revived, but
# when the underlying file object is eventually closed, that memory will be
# reclaimed. So that's OK.
#
# The other issue is that when someone calls wait_*, using EPOLLONESHOT means
# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD
# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do
# we know which one to use? There's no reasonable way to track which fds are
# currently registered -- remember, we're assuming the user might have gone
# and rearranged their fds without telling us!
#
# Fortunately, this also has a simple solution: if we wait on a socket or
# other fd once, then we'll probably wait on it lots of times. And the epoll
# object itself knows which fds it already has registered. So when an fd comes
# in, we optimistically assume that it's been waited on before, and try doing
# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again
# with EPOLL_CTL_ADD.
#
# So that's why this code is the way it is. And now you know more than you
# wanted to about how epoll works.
@attrs.define(eq=False)
class EpollIOManager:
# Using lambda here because otherwise crash on import with gevent monkey patching
# See https://github.com/python-trio/trio/issues/2848
_epoll: select.epoll = attrs.Factory(lambda: select.epoll())
# {fd: EpollWaiters}
_registered: defaultdict[int, EpollWaiters] = attrs.Factory(
lambda: defaultdict(EpollWaiters),
)
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
_force_wakeup_fd: int | None = None
def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self) -> _EpollStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _EpollStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
)
def close(self) -> None:
self._epoll.close()
self._force_wakeup.close()
def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
def get_events(self, timeout: float) -> EventResult:
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
def process_events(self, events: EventResult) -> None:
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
waiters = self._registered[fd]
# EPOLLONESHOT always clears the flags when an event is delivered
waiters.current_flags = 0
# Clever hack stolen from selectors.EpollSelector: an event
# with EPOLLHUP or EPOLLERR flags wakes both readers and
# writers.
if flags & ~select.EPOLLIN and waiters.write_task is not None:
_core.reschedule(waiters.write_task)
waiters.write_task = None
if flags & ~select.EPOLLOUT and waiters.read_task is not None:
_core.reschedule(waiters.read_task)
waiters.read_task = None
self._update_registrations(fd)
def _update_registrations(self, fd: int) -> None:
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
wanted_flags |= select.EPOLLIN
if waiters.write_task is not None:
wanted_flags |= select.EPOLLOUT
if wanted_flags != waiters.current_flags:
try:
try:
# First try EPOLL_CTL_MOD
self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT)
except OSError:
# If that fails, it might be a new fd; try EPOLL_CTL_ADD
self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT)
waiters.current_flags = wanted_flags
except OSError as exc:
# If everything fails, probably it's a bad fd, e.g. because
# the fd was closed behind our back. In this case we don't
# want to try to unregister the fd, because that will probably
# fail too. Just clear our state and wake everyone up.
del self._registered[fd]
# This could raise (in case we're calling this inside one of
# the to-be-woken tasks), so we have to do it last.
wake_all(waiters, exc)
return
if not wanted_flags:
del self._registered[fd]
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
if getattr(waiters, attr_name) is not None:
raise _core.BusyResourceError(
"another task is already reading / writing this fd",
)
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
def abort(_: RaiseCancelT) -> Abort:
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort)
@_public
async def wait_readable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._epoll_wait(fd, "read_task")
@_public
async def wait_writable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._epoll_wait(fd, "write_task")
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
self._registered[fd],
_core.ClosedResourceError("another task closed this fd"),
)
del self._registered[fd]
with contextlib.suppress(OSError, ValueError):
self._epoll.unregister(fd)
@@ -0,0 +1,292 @@
from __future__ import annotations
import errno
import select
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Iterator, Literal
import attrs
import outcome
from .. import _core
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
from .._file_io import _HasFileNo
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
EventResult: TypeAlias = "list[select.kevent]"
@attrs.frozen(eq=False)
class _KqueueStatistics:
tasks_waiting: int
monitors: int
backend: Literal["kqueue"] = attrs.field(init=False, default="kqueue")
@attrs.define(eq=False)
class KqueueIOManager:
_kqueue: select.kqueue = attrs.Factory(select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
_registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = (
attrs.Factory(dict)
)
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
_force_wakeup_fd: int | None = None
def __attrs_post_init__(self) -> None:
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock,
select.KQ_FILTER_READ,
select.KQ_EV_ADD,
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self) -> _KqueueStatistics:
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
if type(receiver) is _core.Task:
tasks_waiting += 1
else:
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
def close(self) -> None:
self._kqueue.close()
self._force_wakeup.close()
def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
def get_events(self, timeout: float) -> EventResult:
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
# all the events in just 1 call.
max_events = len(self._registered) + 1
events = []
while True:
batch = self._kqueue.control([], max_events, timeout)
events += batch
if len(batch) < max_events:
break
else:
timeout = 0
# and loop back to the start
return events
def process_events(self, events: EventResult) -> None:
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
if isinstance(receiver, _core.Task):
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
# kevent registration is complicated -- e.g. aio submission can
# implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will
# automatically register filters for child processes. So our lowlevel
# API is *very* low-level: we expose the kqueue itself for adding
# events or sticking into AIO submission structs, and split waiting
# off into separate methods. It's your responsibility to make sure
# that handle_io never receives an event without a corresponding
# registration! This may be challenging if you want to be careful
# about e.g. KeyboardInterrupt. Possibly this API could be improved to
# be more ergonomic...
@_public
def current_kqueue(self) -> select.kqueue:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
return self._kqueue
@contextmanager
@_public
def monitor_kevent(
self,
ident: int,
filter: int,
) -> Iterator[_core.UnboundedQueue[select.kevent]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair",
)
q = _core.UnboundedQueue[select.kevent]()
self._registered[key] = q
try:
yield q
finally:
del self._registered[key]
@_public
async def wait_kevent(
self,
ident: int,
filter: int,
abort_func: Callable[[RaiseCancelT], Abort],
) -> Abort:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair",
)
self._registered[key] = _core.current_task()
def abort(raise_cancel: RaiseCancelT) -> Abort:
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
return r
# wait_task_rescheduled does not have its return type typed
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
async def _wait_common(
self,
fd: int | _HasFileNo,
filter: int,
) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
def abort(_: RaiseCancelT) -> Abort:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
except OSError as exc:
# kqueue tracks individual fds (*not* the underlying file
# object, see _io_epoll.py for a long discussion of why this
# distinction matters), and automatically deregisters an event
# if the fd is closed. So if kqueue.control says that it
# doesn't know about this event, then probably it's because
# the fd was closed behind our backs. (Too bad we can't ask it
# to wake us up when this happens, versus discovering it after
# the fact... oh well, you can't have everything.)
#
# FreeBSD reports this using EBADF. macOS uses ENOENT.
if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch
pass
else: # pragma: no cover
# As far as we know, this branch can't happen.
raise
return _core.Abort.SUCCEEDED
await self.wait_kevent(fd, filter, abort)
@_public
async def wait_readable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
async def wait_writable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
if not isinstance(fd, int):
fd = fd.fileno()
for filter_ in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]:
key = (fd, filter_)
receiver = self._registered.get(key)
if receiver is None:
continue
if type(receiver) is _core.Task:
event = select.kevent(fd, filter_, select.KQ_EV_DELETE)
self._kqueue.control([event], 0)
exc = _core.ClosedResourceError("another task closed this fd")
_core.reschedule(receiver, outcome.Error(exc))
del self._registered[key]
else:
# XX this is an interesting example of a case where being able
# to close a queue would be useful...
raise NotImplementedError(
"can't close an fd that monitor_kevent is using",
)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,237 @@
from __future__ import annotations
import inspect
import signal
import sys
from functools import wraps
from typing import TYPE_CHECKING, Final, Protocol, TypeVar
import attrs
from .._util import is_main_thread
CallableT = TypeVar("CallableT", bound="Callable[..., object]")
RetT = TypeVar("RetT")
if TYPE_CHECKING:
import types
from collections.abc import Callable
from typing_extensions import ParamSpec, TypeGuard
ArgsT = ParamSpec("ArgsT")
# In ordinary single-threaded Python code, when you hit control-C, it raises
# an exception and automatically does all the regular unwinding stuff.
#
# In Trio code, we would like hitting control-C to raise an exception and
# automatically do all the regular unwinding stuff. In particular, we would
# like to maintain our invariant that all tasks always run to completion (one
# way or another), by unwinding all of them.
#
# But it's basically impossible to write the core task running code in such a
# way that it can maintain this invariant in the face of KeyboardInterrupt
# exceptions arising at arbitrary bytecode positions. Similarly, if a
# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
# inter-task synchronization or I/O primitives, then the system state could
# get corrupted and prevent our being able to clean up properly.
#
# So, we need a way to defer KeyboardInterrupt processing from these critical
# sections.
#
# Things that don't work:
#
# - Listen for SIGINT and process it in a system task: works fine for
# well-behaved programs that regularly pass through the event loop, but if
# user-code goes into an infinite loop then it can't be interrupted. Which
# is unfortunate, since dealing with infinite loops is what
# KeyboardInterrupt is for!
#
# - Use pthread_sigmask to disable signal delivery during critical section:
# (a) windows has no pthread_sigmask, (b) python threads start with all
# signals unblocked, so if there are any threads around they'll receive the
# signal and then tell the main thread to run the handler, even if the main
# thread has that signal blocked.
#
# - Install a signal handler which checks a global variable to decide whether
# to raise the exception immediately (if we're in a non-critical section),
# or to schedule it on the event loop (if we're in a critical section). The
# problem here is that it's impossible to transition safely out of user code:
#
# with keyboard_interrupt_enabled:
# msg = coro.send(value)
#
# If this raises a KeyboardInterrupt, it might be because the coroutine got
# interrupted and has unwound... or it might be the KeyboardInterrupt
# arrived just *after* 'send' returned, so the coroutine is still running,
# but we just lost the message it sent. (And worse, in our actual task
# runner, the send is hidden inside a utility function etc.)
#
# Solution:
#
# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
# the signal handler check which kind of frame we're currently in when
# deciding whether to raise or schedule the exception.
#
# There are still some cases where this can fail, like if someone hits
# control-C while the process is in the event loop, and then it immediately
# enters an infinite loop in user code. In this case the user has to hit
# control-C a second time. And of course if the user code is written so that
# it doesn't actually exit after a task crashes and everything gets cancelled,
# then there's not much to be done. (Hitting control-C repeatedly might help,
# but in general the solution is to kill the process some other way, just like
# for any Python program that's written to catch and ignore
# KeyboardInterrupt.)
# We use this special string as a unique key into the frame locals dictionary.
# The @ ensures it is not a valid identifier and can't clash with any possible
# real local name. See: https://github.com/python-trio/trio/issues/469
LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED"
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
def ki_protection_enabled(frame: types.FrameType | None) -> bool:
while frame is not None:
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED])
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
def currently_ki_protected() -> bool:
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
enabled.
It's surprisingly easy to think that one's :exc:`KeyboardInterrupt`
protection is enabled when it isn't, or vice-versa. This function tells
you what Trio thinks of the matter, which makes it useful for ``assert``\s
and unit tests.
Returns:
bool: True if protection is enabled, and False otherwise.
"""
return ki_protection_enabled(sys._getframe())
# This is to support the async_generator package necessary for aclosing on <3.10
# functions decorated @async_generator are given this magic property that's a
# reference to the object itself
# see python-trio/async_generator/async_generator/_impl.py
def legacy_isasyncgenfunction(
obj: object,
) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]:
return getattr(obj, "_async_gen_function", None) == id(obj)
def _ki_protection_decorator(
enabled: bool,
) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]:
# The "ignore[return-value]" below is because the inspect functions cast away the
# original return type of fn, making it just CoroutineType[Any, Any, Any] etc.
# ignore[misc] is because @wraps() is passed a callable with Any in the return type.
def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]:
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# See the comment for regular generators below
coro = fn(*args, **kwargs)
assert coro.cr_frame is not None, "Coroutine frame should exist"
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return coro # type: ignore[return-value]
return wrapper
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
# throw()n into, then it may magically pop to the top of the
# stack. And @contextmanager generators in particular are a
# case where we often want KI protection, and which are often
# thrown into! See:
# https://bugs.python.org/issue29590
gen = fn(*args, **kwargs)
gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return gen # type: ignore[return-value]
return wrapper
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
@wraps(fn) # type: ignore[arg-type]
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return agen # type: ignore[return-value]
return wrapper
else:
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
return wrapper
return decorator
# pyright workaround: https://github.com/microsoft/pyright/issues/5866
class KIProtectionSignature(Protocol):
__name__: str
def __call__(self, f: CallableT, /) -> CallableT:
pass
# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads
enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment]
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment]
disable_ki_protection.__name__ = "disable_ki_protection"
@attrs.define(slots=False)
class KIManager:
handler: Callable[[int, types.FrameType | None], None] | None = None
def install(
self,
deliver_cb: Callable[[], object],
restrict_keyboard_interrupt_to_checkpoints: bool,
) -> None:
assert self.handler is None
if (
not is_main_thread()
or signal.getsignal(signal.SIGINT) != signal.default_int_handler
):
return
def handler(signum: int, frame: types.FrameType | None) -> None:
assert signum == signal.SIGINT
protection_enabled = ki_protection_enabled(frame)
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
deliver_cb()
else:
raise KeyboardInterrupt
self.handler = handler
signal.signal(signal.SIGINT, handler)
def close(self) -> None:
if self.handler is not None:
if signal.getsignal(signal.SIGINT) is self.handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
self.handler = None
@@ -0,0 +1,104 @@
from __future__ import annotations
from typing import Generic, TypeVar, cast
# Runvar implementations
import attrs
from .._util import NoPublicConstructor, final
from . import _run
T = TypeVar("T")
@final
class _NoValue: ...
@final
@attrs.define(eq=False)
class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T]
previous_value: T | type[_NoValue] = _NoValue
redeemed: bool = attrs.field(default=False, init=False)
@classmethod
def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
return cls._create(var)
@final
@attrs.define(eq=False, repr=False)
class RunVar(Generic[T]):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
except that they are shared across a single call to :func:`trio.run`
rather than a single task.
"""
_name: str
_default: T | type[_NoValue] = _NoValue
def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self])
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
if default is not _NoValue:
return default # type: ignore[return-value]
if self._default is not _NoValue:
return self._default # type: ignore[return-value]
raise LookupError(self) from None
def set(self, value: T) -> RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
"""
try:
old_value = self.get()
except LookupError:
token = RunVarToken._empty(self)
else:
token = RunVarToken[T]._create(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token: RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
"""
if token is None:
raise TypeError("token must not be none")
if token.redeemed:
raise ValueError("token has already been used")
if token._var is not self:
raise ValueError("token is not for us")
previous = token.previous_value
try:
if previous is _NoValue:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
token.redeemed = True
def __repr__(self) -> str:
return f"<RunVar name={self._name!r}>"
@@ -0,0 +1,164 @@
import time
from math import inf
from .. import _core
from .._abc import Clock
from .._util import final
from ._run import GLOBAL_RUN_CONTEXT
################################################################
# The glorious MockClock
################################################################
# Prior art:
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
# https://github.com/ztellman/manifold/issues/57
@final
class MockClock(Clock):
"""A user-controllable clock suitable for writing tests.
Args:
rate (float): the initial :attr:`rate`.
autojump_threshold (float): the initial :attr:`autojump_threshold`.
.. attribute:: rate
How many seconds of clock time pass per second of real time. Default is
0.0, i.e. the clock only advances through manuals calls to :meth:`jump`
or when the :attr:`autojump_threshold` is triggered. You can assign to
this attribute to change it.
.. attribute:: autojump_threshold
The clock keeps an eye on the run loop, and if at any point it detects
that all tasks have been blocked for this many real seconds (i.e.,
according to the actual clock, not this clock), then the clock
automatically jumps ahead to the run loop's next scheduled
timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
assign to this attribute to change it.
Basically the idea is that if you have code or tests that use sleeps
and timeouts, you can use this to make it run much faster, totally
automatically. (At least, as long as those sleeps/timeouts are
happening inside Trio; if your test involves talking to external
service and waiting for it to timeout then obviously we can't help you
there.)
You should set this to the smallest value that lets you reliably avoid
"false alarms" where some I/O is in flight (e.g. between two halves of
a socketpair) but the threshold gets triggered and time gets advanced
anyway. This will depend on the details of your tests and test
environment. If you aren't doing any I/O (like in our sleeping example
above) then just set it to zero, and the clock will jump whenever all
tasks are blocked.
.. note:: If you use ``autojump_threshold`` and
`wait_all_tasks_blocked` at the same time, then you might wonder how
they interact, since they both cause things to happen after the run
loop goes idle for some time. The answer is:
`wait_all_tasks_blocked` takes priority. If there's a task blocked
in `wait_all_tasks_blocked`, then the autojump feature treats that
as active task and does *not* jump the clock.
"""
def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
self._real_base = 0.0
self._virtual_base = 0.0
self._rate = 0.0
self._autojump_threshold = 0.0
# kept as an attribute so that our tests can monkeypatch it
self._real_clock = time.perf_counter
# use the property update logic to set initial values
self.rate = rate
self.autojump_threshold = autojump_threshold
def __repr__(self) -> str:
return f"<MockClock, time={self.current_time():.7f}, rate={self._rate} @ {id(self):#x}>"
@property
def rate(self) -> float:
return self._rate
@rate.setter
def rate(self, new_rate: float) -> None:
if new_rate < 0:
raise ValueError("rate must be >= 0")
else:
real = self._real_clock()
virtual = self._real_to_virtual(real)
self._virtual_base = virtual
self._real_base = real
self._rate = float(new_rate)
@property
def autojump_threshold(self) -> float:
return self._autojump_threshold
@autojump_threshold.setter
def autojump_threshold(self, new_autojump_threshold: float) -> None:
self._autojump_threshold = float(new_autojump_threshold)
self._try_resync_autojump_threshold()
# runner.clock_autojump_threshold is an internal API that isn't easily
# usable by custom third-party Clock objects. If you need access to this
# functionality, let us know, and we'll figure out how to make a public
# API. Discussion:
#
# https://github.com/python-trio/trio/issues/1587
def _try_resync_autojump_threshold(self) -> None:
try:
runner = GLOBAL_RUN_CONTEXT.runner
if runner.is_guest:
runner.force_guest_tick_asap()
except AttributeError:
pass
else:
runner.clock_autojump_threshold = self._autojump_threshold
# Invoked by the run loop when runner.clock_autojump_threshold is
# exceeded.
def _autojump(self) -> None:
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if 0 < jump < inf:
self.jump(jump)
def _real_to_virtual(self, real: float) -> float:
real_offset = real - self._real_base
virtual_offset = self._rate * real_offset
return self._virtual_base + virtual_offset
def start_clock(self) -> None:
self._try_resync_autojump_threshold()
def current_time(self) -> float:
return self._real_to_virtual(self._real_clock())
def deadline_to_sleep_time(self, deadline: float) -> float:
virtual_timeout = deadline - self.current_time()
if virtual_timeout <= 0:
return 0
elif self._rate > 0:
return virtual_timeout / self._rate
else:
return 999999999
def jump(self, seconds: float) -> None:
"""Manually advance the clock by the given number of seconds.
Args:
seconds (float): the number of seconds to jump the clock forward.
Raises:
ValueError: if you try to pass a negative value for ``seconds``.
"""
if seconds < 0:
raise ValueError("time can't go backwards")
self._virtual_base += seconds
@@ -0,0 +1,317 @@
# ParkingLot provides an abstraction for a fair waitqueue with cancellation
# and requeueing support. Inspiration:
#
# https://webkit.org/blog/6161/locking-in-webkit/
# https://amanieu.github.io/parking_lot/
#
# which were in turn heavily influenced by
#
# http://gee.cs.oswego.edu/dl/papers/aqs.pdf
#
# Compared to these, our use of cooperative scheduling allows some
# simplifications (no need for internal locking). On the other hand, the need
# to support Trio's strong cancellation semantics adds some complications
# (tasks need to know where they're queued so they can cancel). Also, in the
# above work, the ParkingLot is a global structure that holds a collection of
# waitqueues keyed by lock address, and which are opportunistically allocated
# and destroyed as contention arises; this allows the worst-case memory usage
# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue
# for each synchronization object, so we're O(#objects + #tasks). This isn't
# *so* bad since compared to our synchronization objects are heavier than
# theirs and our tasks are lighter, so for us #objects is smaller and #tasks
# is larger.
#
# This is in the core because for two reasons. First, it's used by
# UnboundedQueue, and UnboundedQueue is used for a number of things in the
# core. And second, it's responsible for providing fairness to all of our
# high-level synchronization primitives (locks, queues, etc.). For now with
# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue),
# but in the future we ever start support task priorities or fair scheduling
#
# https://github.com/python-trio/trio/issues/32
#
# then all we'll have to do is update this. (Well, full-fledged task
# priorities might also require priority inheritance, which would require more
# work.)
#
# For discussion of data structures to use here, see:
#
# https://github.com/dabeaz/curio/issues/136
#
# (and also the articles above). Currently we use a SortedDict ordered by a
# global monotonic counter that ensures FIFO ordering. The main advantage of
# this is that it's easy to implement :-). An intrusive doubly-linked list
# would also be a natural approach, so long as we only handle FIFO ordering.
#
# XX: should we switch to the shared global ParkingLot approach?
#
# XX: we should probably add support for "parking tokens" to allow for
# task-fair RWlock (basically: when parking a task needs to be able to mark
# itself as a reader or a writer, and then a task-fair wakeup policy is, wake
# the next task, and if it's a reader than keep waking tasks so long as they
# are readers). Without this I think you can implement write-biased or
# read-biased RWlocks (by using two parking lots and drawing from whichever is
# preferred), but not task-fair -- and task-fair plays much more nicely with
# WFQ. (Consider what happens in the two-lot implementation if you're
# write-biased but all the pending writers are blocked at the scheduler level
# by the WFQ logic...)
# ...alternatively, "phase-fair" RWlocks are pretty interesting:
# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf
# Useful summary:
# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html
#
# XX: if we do add WFQ, then we might have to drop the current feature where
# unpark returns the tasks that were unparked. Rationale: suppose that at the
# time we call unpark, the next task is deprioritized... and then, before it
# becomes runnable, a new task parks which *is* runnable. Ideally we should
# immediately wake the new task, and leave the old task on the queue for
# later. But this means we can't commit to which task we are unparking when
# unpark is called.
#
# See: https://github.com/python-trio/trio/issues/53
from __future__ import annotations
import inspect
import math
from collections import OrderedDict
from typing import TYPE_CHECKING
import attrs
import outcome
from .. import _core
from .._util import final
if TYPE_CHECKING:
from collections.abc import Iterator
from ._run import Task
GLOBAL_PARKING_LOT_BREAKER: dict[Task, list[ParkingLot]] = {}
def add_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
"""Register a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`.
raises:
trio.BrokenResourceError: if the task has already exited.
"""
if inspect.getcoroutinestate(task.coro) == inspect.CORO_CLOSED:
raise _core._exceptions.BrokenResourceError(
"Attempted to add already exited task as lot breaker.",
)
if task not in GLOBAL_PARKING_LOT_BREAKER:
GLOBAL_PARKING_LOT_BREAKER[task] = [lot]
else:
GLOBAL_PARKING_LOT_BREAKER[task].append(lot)
def remove_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
"""Deregister a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`"""
try:
GLOBAL_PARKING_LOT_BREAKER[task].remove(lot)
except (KeyError, ValueError):
raise RuntimeError(
"Attempted to remove task as breaker for a lot it is not registered for",
) from None
if not GLOBAL_PARKING_LOT_BREAKER[task]:
del GLOBAL_PARKING_LOT_BREAKER[task]
@attrs.frozen
class ParkingLotStatistics:
"""An object containing debugging information for a ParkingLot.
Currently, the following fields are defined:
* ``tasks_waiting`` (int): The number of tasks blocked on this lot's
:meth:`trio.lowlevel.ParkingLot.park` method.
"""
tasks_waiting: int
@final
@attrs.define(eq=False)
class ParkingLot:
"""A fair wait queue with cancellation and requeueing.
This class encapsulates the tricky parts of implementing a wait
queue. It's useful for implementing higher-level synchronization
primitives like queues and locks.
In addition to the methods below, you can use ``len(parking_lot)`` to get
the number of parked tasks, and ``if parking_lot: ...`` to check whether
there are any parked tasks.
"""
# {task: None}, we just want a deque where we can quickly delete random
# items
_parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False)
broken_by: list[Task] = attrs.field(factory=list, init=False)
def __len__(self) -> int:
"""Returns the number of parked tasks."""
return len(self._parked)
def __bool__(self) -> bool:
"""True if there are parked tasks, False otherwise."""
return bool(self._parked)
# XX this currently returns None
# if we ever add the ability to repark while one's resuming place in
# line (for false wakeups), then we could have it return a ticket that
# abstracts the "place in line" concept.
@_core.enable_ki_protection
async def park(self) -> None:
"""Park the current task until woken by a call to :meth:`unpark` or
:meth:`unpark_all`.
Raises:
BrokenResourceError: if attempting to park in a broken lot, or the lot
breaks before we get to unpark.
"""
if self.broken_by:
raise _core.BrokenResourceError(
f"Attempted to park in parking lot broken by {self.broken_by}",
)
task = _core.current_task()
self._parked[task] = None
task.custom_sleep_data = self
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
del task.custom_sleep_data._parked[task]
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def _pop_several(self, count: int | float) -> Iterator[Task]: # noqa: PYI041
if isinstance(count, float):
if math.isinf(count):
count = len(self._parked)
else:
raise ValueError("Cannot pop a non-integer number of tasks.")
else:
count = min(count, len(self._parked))
for _ in range(count):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
def unpark(self, *, count: int | float = 1) -> list[Task]: # noqa: PYI041
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
there are fewer than ``count`` tasks parked, then wakes as many tasks
are available and then returns successfully.
Args:
count (int | math.inf): the number of tasks to unpark.
"""
tasks = list(self._pop_several(count))
for task in tasks:
_core.reschedule(task)
return tasks
def unpark_all(self) -> list[Task]:
"""Unpark all parked tasks."""
return self.unpark(count=len(self))
@_core.enable_ki_protection
def repark(
self,
new_lot: ParkingLot,
*,
count: int | float = 1, # noqa: PYI041
) -> None:
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
another, preserving order. For example::
async def parker(lot):
print("sleeping")
await lot.park()
print("woken")
async def main():
lot1 = trio.lowlevel.ParkingLot()
lot2 = trio.lowlevel.ParkingLot()
async with trio.open_nursery() as nursery:
nursery.start_soon(parker, lot1)
await trio.testing.wait_all_tasks_blocked()
assert len(lot1) == 1
assert len(lot2) == 0
lot1.repark(lot2)
assert len(lot1) == 0
assert len(lot2) == 1
# This wakes up the task that was originally parked in lot1
lot2.unpark()
If there are fewer than ``count`` tasks parked, then reparks as many
tasks as are available and then returns successfully.
Args:
new_lot (ParkingLot): the parking lot to move tasks to.
count (int|math.inf): the number of tasks to move.
"""
if not isinstance(new_lot, ParkingLot):
raise TypeError("new_lot must be a ParkingLot")
for task in self._pop_several(count):
new_lot._parked[task] = None
task.custom_sleep_data = new_lot
def repark_all(self, new_lot: ParkingLot) -> None:
"""Move all parked tasks from one :class:`ParkingLot` object to
another.
See :meth:`repark` for details.
"""
return self.repark(new_lot, count=len(self))
def break_lot(self, task: Task | None = None) -> None:
"""Break this lot, with ``task`` noted as the task that broke it.
This causes all parked tasks to raise an error, and any
future tasks attempting to park to error. Unpark & repark become no-ops as the
parking lot is empty.
The error raised contains a reference to the task sent as a parameter. The task
is also saved in the parking lot in the ``broken_by`` attribute.
"""
if task is None:
task = _core.current_task()
# if lot is already broken, just mark this as another breaker and return
if self.broken_by:
self.broken_by.append(task)
return
self.broken_by.append(task)
for parked_task in self._parked:
_core.reschedule(
parked_task,
outcome.Error(
_core.BrokenResourceError(f"Parking lot broken by {task}"),
),
)
self._parked.clear()
def statistics(self) -> ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this lot's
:meth:`park` method.
"""
return ParkingLotStatistics(tasks_waiting=len(self._parked))
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,333 @@
from __future__ import annotations
import contextlib
import sys
import weakref
from math import inf
from typing import TYPE_CHECKING, NoReturn
import pytest
from ... import _core
from .tutil import gc_collect_harder, restore_unraisablehook
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
def test_asyncgen_basics() -> None:
collected = []
async def example(cause: str) -> AsyncGenerator[int, None]:
try:
with contextlib.suppress(GeneratorExit):
yield 42
await _core.checkpoint()
except _core.Cancelled:
assert "exhausted" not in cause
task_name = _core.current_task().name
assert cause in task_name or task_name == "<init>"
assert _core.current_effective_deadline() == -inf
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
collected.append(cause)
else:
assert "async_main" in _core.current_task().name
assert "exhausted" in cause
assert _core.current_effective_deadline() == inf
await _core.checkpoint()
collected.append(cause)
saved = []
async def async_main() -> None:
# GC'ed before exhausted
with pytest.warns(
ResourceWarning,
match="Async generator.*collected before.*exhausted",
):
assert await example("abandoned").asend(None) == 42
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert collected.pop() == "abandoned"
aiter_ = example("exhausted 1")
try:
assert await aiter_.asend(None) == 42
finally:
await aiter_.aclose()
assert collected.pop() == "exhausted 1"
# Also fine if you exhaust it at point of use
async for val in example("exhausted 2"):
assert val == 42
assert collected.pop() == "exhausted 2"
gc_collect_harder()
# No problems saving the geniter when using either of these patterns
aiter_ = example("exhausted 3")
try:
saved.append(aiter_)
assert await aiter_.asend(None) == 42
finally:
await aiter_.aclose()
assert collected.pop() == "exhausted 3"
# Also fine if you exhaust it at point of use
saved.append(example("exhausted 4"))
async for val in saved[-1]:
assert val == 42
assert collected.pop() == "exhausted 4"
# Leave one referenced-but-unexhausted and make sure it gets cleaned up
saved.append(example("outlived run"))
assert await saved[-1].asend(None) == 42
assert collected == []
_core.run(async_main)
assert collected.pop() == "outlived run"
for agen in saved:
assert agen.ag_frame is None # all should now be exhausted
async def test_asyncgen_throws_during_finalization(
caplog: pytest.LogCaptureFixture,
) -> None:
record = []
async def agen() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("crashing")
raise ValueError("oops")
with restore_unraisablehook():
await agen().asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert record == ["crashing"]
# Following type ignore is because typing for LogCaptureFixture is wrong
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info # type: ignore[misc]
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "during finalization of async generator" in caplog.records[0].message
def test_firstiter_after_closing() -> None:
saved = []
record = []
async def funky_agen() -> AsyncGenerator[int, None]:
try:
yield 1
except GeneratorExit:
record.append("cleanup 1")
raise
try:
yield 2
finally:
record.append("cleanup 2")
await funky_agen().asend(None)
async def async_main() -> None:
aiter_ = funky_agen()
saved.append(aiter_)
assert await aiter_.asend(None) == 1
assert await aiter_.asend(None) == 2
_core.run(async_main)
assert record == ["cleanup 2", "cleanup 1"]
def test_interdependent_asyncgen_cleanup_order() -> None:
saved: list[AsyncGenerator[int, None]] = []
record: list[int | str] = []
async def innermost() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("innermost")
async def agen(
label: int,
inner: AsyncGenerator[int, None],
) -> AsyncGenerator[int, None]:
try:
yield await inner.asend(None)
finally:
# Either `inner` has already been cleaned up, or
# we're about to exhaust it. Either way, we wind
# up with `record` containing the labels in
# innermost-to-outermost order.
with pytest.raises(StopAsyncIteration):
await inner.asend(None)
record.append(label)
async def async_main() -> None:
# This makes a chain of 101 interdependent asyncgens:
# agen(99)'s cleanup will iterate agen(98)'s will iterate
# ... agen(0)'s will iterate innermost()'s
ag_chain = innermost()
for idx in range(100):
ag_chain = agen(idx, ag_chain)
saved.append(ag_chain)
assert await ag_chain.asend(None) == 1
assert record == []
_core.run(async_main)
assert record == ["innermost", *range(100)]
@restore_unraisablehook()
def test_last_minute_gc_edge_case() -> None:
saved: list[AsyncGenerator[int, None]] = []
record = []
needs_retry = True
async def agen() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
record.append("cleaned up")
def collect_at_opportune_moment(token: _core._entry_queue.TrioToken) -> None:
runner = _core._run.GLOBAL_RUN_CONTEXT.runner
assert runner.system_nursery is not None
if runner.system_nursery._closed and isinstance(
runner.asyncgens.alive,
weakref.WeakSet,
):
saved.clear()
record.append("final collection")
gc_collect_harder()
record.append("done")
else:
try:
token.run_sync_soon(collect_at_opportune_moment, token)
except _core.RunFinishedError: # pragma: no cover
nonlocal needs_retry
needs_retry = True
async def async_main() -> None:
token = _core.current_trio_token()
token.run_sync_soon(collect_at_opportune_moment, token)
saved.append(agen())
await saved[-1].asend(None)
# Actually running into the edge case requires that the run_sync_soon task
# execute in between the system nursery's closure and the strong-ification
# of runner.asyncgens. There's about a 25% chance that it doesn't
# (if the run_sync_soon task runs before init on one tick and after init
# on the next tick); if we try enough times, we can make the chance of
# failure as small as we want.
for _attempt in range(50):
needs_retry = False
del record[:]
del saved[:]
_core.run(async_main)
if needs_retry: # pragma: no cover
assert record == ["cleaned up"]
else:
assert record == ["final collection", "done", "cleaned up"]
break
else: # pragma: no cover
pytest.fail(
"Didn't manage to hit the trailing_finalizer_asyncgens case "
f"despite trying {_attempt} times",
)
async def step_outside_async_context(aiter_: AsyncGenerator[int, None]) -> None:
# abort_fns run outside of task context, at least if they're
# triggered by a deadline expiry rather than a direct
# cancellation. Thus, an asyncgen first iterated inside one
# will appear non-Trio, and since no other hooks were installed,
# will use the last-ditch fallback handling (that tries to mimic
# CPython's behavior with no hooks).
#
# NB: the strangeness with aiter being an attribute of abort_fn is
# to make it as easy as possible to ensure we don't hang onto a
# reference to aiter inside the guts of the run loop.
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
with pytest.raises(StopIteration, match="42"):
abort_fn.aiter.asend(None).send(None) # type: ignore[attr-defined] # Callables don't have attribute "aiter"
del abort_fn.aiter # type: ignore[attr-defined]
return _core.Abort.SUCCEEDED
abort_fn.aiter = aiter_ # type: ignore[attr-defined]
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_task_rescheduled, abort_fn)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.deadline = _core.current_time()
async def test_fallback_when_no_hook_claims_it(
capsys: pytest.CaptureFixture[str],
) -> None:
async def well_behaved() -> AsyncGenerator[int, None]:
yield 42
async def yields_after_yield() -> AsyncGenerator[int, None]:
with pytest.raises(GeneratorExit):
yield 42
yield 100
async def awaits_after_yield() -> AsyncGenerator[int, None]:
with pytest.raises(GeneratorExit):
yield 42
await _core.cancel_shielded_checkpoint()
with restore_unraisablehook():
await step_outside_async_context(well_behaved())
gc_collect_harder()
assert capsys.readouterr().err == ""
await step_outside_async_context(yields_after_yield())
gc_collect_harder()
assert "ignored GeneratorExit" in capsys.readouterr().err
await step_outside_async_context(awaits_after_yield())
gc_collect_harder()
assert "awaited something during finalization" in capsys.readouterr().err
def test_delegation_to_existing_hooks() -> None:
record = []
def my_firstiter(agen: AsyncGenerator[object, NoReturn]) -> None:
record.append("firstiter " + agen.ag_frame.f_locals["arg"])
def my_finalizer(agen: AsyncGenerator[object, NoReturn]) -> None:
record.append("finalizer " + agen.ag_frame.f_locals["arg"])
async def example(arg: str) -> AsyncGenerator[int, None]:
try:
yield 42
finally:
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
record.append("trio collected " + arg)
async def async_main() -> None:
await step_outside_async_context(example("theirs"))
assert await example("ours").asend(None) == 42
gc_collect_harder()
assert record == ["firstiter theirs", "finalizer theirs"]
record[:] = []
await _core.wait_all_tasks_blocked()
assert record == ["trio collected ours"]
with restore_unraisablehook():
old_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(my_firstiter, my_finalizer)
try:
_core.run(async_main)
finally:
assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer)
sys.set_asyncgen_hooks(*old_hooks)
@@ -0,0 +1,102 @@
from __future__ import annotations
import gc
import sys
from traceback import extract_tb
from typing import TYPE_CHECKING, Callable, NoReturn
import pytest
from .._concat_tb import concat_tb
if TYPE_CHECKING:
from types import TracebackType
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def raiser1() -> NoReturn:
raiser1_2()
def raiser1_2() -> NoReturn:
raiser1_3()
def raiser1_3() -> NoReturn:
raise ValueError("raiser1_string")
def raiser2() -> NoReturn:
raiser2_2()
def raiser2_2() -> NoReturn:
raise KeyError("raiser2_string")
def get_exc(raiser: Callable[[], NoReturn]) -> Exception:
try:
raiser()
except Exception as exc:
return exc
raise AssertionError("raiser should always raise") # pragma: no cover
def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
return get_exc(raiser).__traceback__
def test_concat_tb() -> None:
tb1 = get_tb(raiser1)
tb2 = get_tb(raiser2)
# These return a list of (filename, lineno, fn name, text) tuples
# https://docs.python.org/3/library/traceback.html#traceback.extract_tb
entries1 = extract_tb(tb1)
entries2 = extract_tb(tb2)
tb12 = concat_tb(tb1, tb2)
assert extract_tb(tb12) == entries1 + entries2
tb21 = concat_tb(tb2, tb1)
assert extract_tb(tb21) == entries2 + entries1
# Check degenerate cases
assert extract_tb(concat_tb(None, tb1)) == entries1
assert extract_tb(concat_tb(tb1, None)) == entries1
assert concat_tb(None, None) is None
# Make sure the original tracebacks didn't get mutated by mistake
assert extract_tb(get_tb(raiser1)) == entries1
assert extract_tb(get_tb(raiser2)) == entries2
# Unclear if this can still fail, removing the `del` from _concat_tb.copy_tb does not seem
# to trigger it (on a platform where the `del` is executed)
@pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="Only makes sense with refcounting GC",
)
def test_ExceptionGroup_catch_doesnt_create_cyclic_garbage() -> None:
# https://github.com/python-trio/trio/pull/2063
gc.collect()
old_flags = gc.get_debug()
def make_multi() -> NoReturn:
raise ExceptionGroup("", [get_exc(raiser1), get_exc(raiser2)])
try:
gc.set_debug(gc.DEBUG_SAVEALL)
with pytest.raises(ExceptionGroup) as excinfo:
# covers ~~MultiErrorCatcher.__exit__ and~~ _concat_tb.copy_tb
# TODO: is the above comment true anymore? as this no longer uses MultiError.catch
raise make_multi()
for exc in excinfo.value.exceptions:
assert isinstance(exc, (ValueError, KeyError))
gc.collect()
assert not gc.garbage
finally:
gc.set_debug(old_flags)
gc.garbage.clear()
@@ -0,0 +1,666 @@
from __future__ import annotations
import asyncio
import contextlib
import contextvars
import queue
import signal
import socket
import sys
import threading
import time
import traceback
import warnings
from functools import partial
from math import inf
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
NoReturn,
TypeVar,
)
import pytest
from outcome import Outcome
import trio
import trio.testing
from trio.abc import Instrument
from ..._util import signal_raise
from .tutil import gc_collect_harder, restore_unraisablehook
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from trio._channel import MemorySendChannel
T = TypeVar("T")
InHost: TypeAlias = Callable[[object], None]
# The simplest possible "host" loop.
# Nice features:
# - we can run code "outside" of trio using the schedule function passed to
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
def trivial_guest_run(
trio_fn: Callable[..., Awaitable[T]],
*,
in_host_after_start: Callable[[], None] | None = None,
**start_guest_run_kwargs: Any,
) -> T:
todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue()
host_thread = threading.current_thread()
def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail,
"run_sync_soon_threadsafe called from host thread",
)
todo.put(("run", crash))
todo.put(("run", fn))
def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail,
"run_sync_soon_not_threadsafe called from worker thread",
)
todo.put(("run", crash))
todo.put(("run", fn))
def done_callback(outcome: Outcome[T]) -> None:
nonlocal todo
todo.put(("unwrap", outcome))
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_not_threadsafe,
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
done_callback=done_callback,
**start_guest_run_kwargs,
)
if in_host_after_start is not None:
in_host_after_start()
try:
while True:
op, obj = todo.get()
if op == "run":
assert not isinstance(obj, Outcome)
obj()
elif op == "unwrap":
assert isinstance(obj, Outcome)
return obj.unwrap()
else: # pragma: no cover
raise NotImplementedError(f"{op!r} not handled")
finally:
# Make sure that exceptions raised here don't capture these, so that
# if an exception does cause us to abandon a run then the Trio state
# has a chance to be GC'ed and warn about it.
del todo, run_sync_soon_threadsafe, done_callback
def test_guest_trivial() -> None:
async def trio_return(in_host: InHost) -> str:
await trio.lowlevel.checkpoint()
return "ok"
assert trivial_guest_run(trio_return) == "ok"
async def trio_fail(in_host: InHost) -> NoReturn:
raise KeyError("whoopsiedaisy")
with pytest.raises(KeyError, match="whoopsiedaisy"):
trivial_guest_run(trio_fail)
def test_guest_can_do_io() -> None:
async def trio_main(in_host: InHost) -> None:
record = []
a, b = trio.socket.socketpair()
with a, b:
async with trio.open_nursery() as nursery:
async def do_receive() -> None:
record.append(await a.recv(1))
nursery.start_soon(do_receive)
await trio.testing.wait_all_tasks_blocked()
await b.send(b"x")
assert record == [b"x"]
trivial_guest_run(trio_main)
def test_guest_is_initialized_when_start_returns() -> None:
trio_token = None
record = []
async def trio_main(in_host: InHost) -> str:
record.append("main task ran")
await trio.lowlevel.checkpoint()
assert trio.lowlevel.current_trio_token() is trio_token
return "ok"
def after_start() -> None:
# We should get control back before the main task executes any code
assert record == []
nonlocal trio_token
trio_token = trio.lowlevel.current_trio_token()
trio_token.run_sync_soon(record.append, "run_sync_soon cb ran")
@trio.lowlevel.spawn_system_task
async def early_task() -> None:
record.append("system task ran")
await trio.lowlevel.checkpoint()
res = trivial_guest_run(trio_main, in_host_after_start=after_start)
assert res == "ok"
assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"}
class BadClock:
def start_clock(self) -> NoReturn:
raise ValueError("whoops")
def after_start_never_runs() -> None: # pragma: no cover
pytest.fail("shouldn't get here")
# Errors during initialization (which can only be TrioInternalErrors)
# are raised out of start_guest_run, not out of the done_callback
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(
trio_main,
clock=BadClock(),
in_host_after_start=after_start_never_runs,
)
def test_host_can_directly_wake_trio_task() -> None:
async def trio_main(in_host: InHost) -> str:
ev = trio.Event()
in_host(ev.set)
await ev.wait()
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_host_altering_deadlines_wakes_trio_up() -> None:
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
cscope.deadline = new_deadline
async def trio_main(in_host: InHost) -> str:
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
assert cscope.cancelled_caught
with trio.CancelScope() as cscope:
# also do a change that doesn't affect the next deadline, just to
# exercise that path
in_host(lambda: set_deadline(cscope, 1e6))
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep(999)
assert cscope.cancelled_caught
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
async def trio_main(in_host: InHost) -> str:
async def synchronize() -> None:
"""Wait for all in_host() calls issued so far to complete."""
evt = trio.Event()
in_host(evt.set)
await evt.wait()
# Host and guest have separate sniffio_library contexts
in_host(partial(setattr, sniffio_library, "name", "nullio"))
await synchronize()
assert current_async_library() == "trio"
record = []
in_host(lambda: record.append(current_async_library()))
await synchronize()
assert record == ["nullio"]
assert current_async_library() == "trio"
return "ok"
try:
assert trivial_guest_run(trio_main) == "ok"
finally:
sniffio_library.name = None
def test_warn_set_wakeup_fd_overwrite() -> None:
assert signal.set_wakeup_fd(-1) == -1
async def trio_main(in_host: InHost) -> str:
return "ok"
a, b = socket.socketpair()
with a, b:
a.setblocking(False)
# Warn if there's already a wakeup fd
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert trivial_guest_run(trio_main) == "ok"
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
# Don't warn if there isn't already a wakeup fd
with warnings.catch_warnings():
warnings.simplefilter("error")
assert trivial_guest_run(trio_main) == "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True)
== "ok"
)
# If there's already a wakeup fd, but we've been told to trust it,
# then it's left alone and there's no warning
signal.set_wakeup_fd(a.fileno())
try:
async def trio_check_wakeup_fd_unaltered(in_host: InHost) -> str:
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
return "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(
trio_check_wakeup_fd_unaltered,
host_uses_signal_set_wakeup_fd=True,
)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None:
# This is designed to hit the branch in unrolled_run where:
# idle_primed=True
# runner.runq is empty
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
async def trio_main(in_host: InHost) -> str:
async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None:
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
# only by cancellation.
await trio.testing.wait_all_tasks_blocked(cushion=9999)
raise AssertionError( # pragma: no cover
"wait_all_tasks_blocked should *not* return normally, "
"only by cancellation.",
)
assert watb_cscope.cancelled_caught
async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None:
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
# Altering the deadline from the host, to something in the
# future, will cause the run loop to wake up, but then
# discover that there is nothing to do and go back to sleep.
# This should *not* trigger wait_all_tasks_blocked.
#
# So the 'before_io_wait' here will wait until we're blocking
# with the wait_all_tasks_blocked primed, and then schedule a
# deadline change. The critical test is that this should *not*
# wake up 'sit_in_wait_all_tasks_blocked'.
#
# The after we've had a chance to wake up
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
class InstrumentHelper(Instrument):
def __init__(self) -> None:
self.primed = False
def before_io_wait(self, timeout: float) -> None:
print(f"before_io_wait({timeout})")
if timeout == 9999: # pragma: no branch
assert not self.primed
in_host(lambda: set_deadline(cscope, 1e9))
self.primed = True
def after_io_wait(self, timeout: float) -> None:
if self.primed: # pragma: no branch
print("instrument triggered")
in_host(lambda: cscope.cancel())
trio.lowlevel.remove_instrument(self)
trio.lowlevel.add_instrument(InstrumentHelper())
await trio.sleep_forever()
assert cscope.cancelled_caught
watb_cscope.cancel()
async with trio.open_nursery() as nursery:
watb_cscope = trio.CancelScope()
nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope)
await trio.testing.wait_all_tasks_blocked()
nursery.start_soon(get_woken_by_host_deadline, watb_cscope)
return "ok"
assert trivial_guest_run(trio_main) == "ok"
@restore_unraisablehook()
def test_guest_warns_if_abandoned() -> None:
# This warning is emitted from the garbage collector. So we have to make
# sure that our abandoned run is garbage. The easiest way to do this is to
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
def do_abandoned_guest_run() -> None:
async def abandoned_main(in_host: InHost) -> None:
in_host(lambda: 1 / 0)
while True:
await trio.lowlevel.checkpoint()
with pytest.raises(ZeroDivisionError):
trivial_guest_run(abandoned_main)
with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"):
do_abandoned_guest_run()
gc_collect_harder()
# If you have problems some day figuring out what's holding onto a
# reference to the unrolled_run generator and making this test fail,
# then this might be useful to help track it down. (It assumes you
# also hack start_guest_run so that it does 'global W; W =
# weakref(unrolled_run_gen)'.)
#
# import gc
# print(trio._core._run.W)
# targets = [trio._core._run.W()]
# for i in range(15):
# new_targets = []
# for target in targets:
# new_targets += gc.get_referrers(target)
# new_targets.remove(targets)
# print("#####################")
# print(f"depth {i}: {len(new_targets)}")
# print(new_targets)
# targets = new_targets
with pytest.raises(RuntimeError):
trio.current_time()
def aiotrio_run(
trio_fn: Callable[..., Awaitable[T]],
*,
pass_not_threadsafe: bool = True,
**start_guest_run_kwargs: Any,
) -> T:
loop = asyncio.new_event_loop()
async def aio_main() -> T:
trio_done_fut = loop.create_future()
def trio_done_callback(main_outcome: Outcome[object]) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
if pass_not_threadsafe:
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
**start_guest_run_kwargs,
)
return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
try:
return loop.run_until_complete(aio_main())
finally:
loop.close()
def test_guest_mode_on_asyncio() -> None:
async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
from_trio: asyncio.Queue[int] = asyncio.Queue()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
# Make sure we have at least one tick where we don't need to go into
# the thread
await trio.lowlevel.checkpoint()
from_trio.put_nowait(0)
async for n in from_aio:
print(f"trio got: {n}")
from_trio.put_nowait(n + 1)
if n >= 10:
aio_task.cancel()
return "trio-main-done"
raise AssertionError("should never be reached") # pragma: no cover
async def aio_pingpong(
from_trio: asyncio.Queue[int],
to_trio: MemorySendChannel[int],
) -> None:
print("aio_pingpong!")
try:
while True:
n = await from_trio.get()
print(f"aio got: {n}")
to_trio.send_nowait(n + 1)
except asyncio.CancelledError:
raise
except: # pragma: no cover
traceback.print_exc()
raise
assert (
aiotrio_run(
trio_main,
# Not all versions of asyncio we test on can actually be trusted,
# but this test doesn't care about signal handling, and it's
# easier to just avoid the warnings.
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
assert (
aiotrio_run(
trio_main,
# Also check that passing only call_soon_threadsafe works, via the
# fallback path where we use it for everything.
pass_not_threadsafe=False,
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
def test_guest_mode_internal_errors(
monkeypatch: pytest.MonkeyPatch,
recwarn: pytest.WarningsRecorder,
) -> None:
with monkeypatch.context() as m:
async def crash_in_run_loop(in_host: InHost) -> None:
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_run_loop)
with monkeypatch.context() as m:
async def crash_in_io(in_host: InHost) -> None:
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.lowlevel.checkpoint()
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_io)
with monkeypatch.context() as m:
async def crash_in_worker_thread_io(in_host: InHost) -> None:
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
def bad_get_events(*args: Any) -> object:
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
return old_get_events(*args)
m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_worker_thread_io)
gc_collect_harder()
def test_guest_mode_ki() -> None:
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
async def trio_main(in_host: InHost) -> None:
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
# Host SIGINT should get injected into Trio
in_host(partial(signal_raise, signal.SIGINT))
await trio.sleep(10)
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main)
assert excinfo.value.__context__ is None
# Signal handler should be restored properly on exit
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Also check chaining in the case where KI is injected after main exits
final_exc = KeyError("whoa")
async def trio_main_raising(in_host: InHost) -> NoReturn:
in_host(partial(signal_raise, signal.SIGINT))
raise final_exc
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main_raising)
assert excinfo.value.__context__ is final_exc
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
def test_guest_mode_autojump_clock_threshold_changing() -> None:
# This is super obscure and probably no-one will ever notice, but
# technically mutating the MockClock.autojump_threshold from the host
# should wake up the guest, so let's test it.
clock = trio.testing.MockClock()
DURATION = 120
async def trio_main(in_host: InHost) -> None:
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
assert trio.current_time() == DURATION
start = time.monotonic()
trivial_guest_run(trio_main, clock=clock)
end = time.monotonic()
# Should be basically instantaneous, but we'll leave a generous buffer to
# account for any CI weirdness
assert end - start < DURATION / 2
@restore_unraisablehook()
def test_guest_mode_asyncgens() -> None:
import sniffio
record = set()
async def agen(label: str) -> AsyncGenerator[int, None]:
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
with contextlib.suppress(trio.Cancelled):
await sys.modules[library].sleep(0)
record.add((label, library))
async def iterate_in_aio() -> None:
await agen("asyncio").asend(None)
async def trio_main() -> None:
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()
await agen("trio").asend(None)
gc_collect_harder()
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
assert record == {("asyncio", "asyncio"), ("trio", "trio")}
@@ -0,0 +1,266 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Container, Iterable, NoReturn
import attrs
import pytest
from ... import _abc, _core
from .tutil import check_sequence_matches
if TYPE_CHECKING:
from ...lowlevel import Task
@attrs.define(eq=False, slots=False)
class TaskRecorder(_abc.Instrument):
record: list[tuple[str, Task | None]] = attrs.Factory(list)
def before_run(self) -> None:
self.record.append(("before_run", None))
def task_scheduled(self, task: Task) -> None:
self.record.append(("schedule", task))
def before_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("before", task))
def after_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("after", task))
def after_run(self) -> None:
self.record.append(("after_run", None))
def filter_tasks(self, tasks: Container[Task]) -> Iterable[tuple[str, Task | None]]:
for item in self.record:
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
yield item
if item[0] in ("before_run", "after_run"):
yield item
def test_instruments(recwarn: object) -> None:
r1 = TaskRecorder()
r2 = TaskRecorder()
r3 = TaskRecorder()
task = None
# We use a child task for this, because the main task does some extra
# bookkeeping stuff that can leak into the instrument results, and we
# don't want to deal with it.
async def task_fn() -> None:
nonlocal task
task = _core.current_task()
for _ in range(4):
await _core.checkpoint()
# replace r2 with r3, to test that we can manipulate them as we go
_core.remove_instrument(r2)
with pytest.raises(KeyError):
_core.remove_instrument(r2)
# add is idempotent
_core.add_instrument(r3)
_core.add_instrument(r3)
for _ in range(1):
await _core.checkpoint()
async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(task_fn)
_core.run(main, instruments=[r1, r2])
# It sleeps 5 times, so it runs 6 times. Note that checkpoint()
# reschedules the task immediately upon yielding, before the
# after_task_step event fires.
expected = (
[("before_run", None), ("schedule", task)]
+ [("before", task), ("schedule", task), ("after", task)] * 5
+ [("before", task), ("after", task), ("after_run", None)]
)
assert r1.record == r2.record + r3.record
assert task is not None
assert list(r1.filter_tasks([task])) == expected
def test_instruments_interleave() -> None:
tasks = {}
async def two_step1() -> None:
tasks["t1"] = _core.current_task()
await _core.checkpoint()
async def two_step2() -> None:
tasks["t2"] = _core.current_task()
await _core.checkpoint()
async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(two_step1)
nursery.start_soon(two_step2)
r = TaskRecorder()
_core.run(main, instruments=[r])
expected = [
("before_run", None),
("schedule", tasks["t1"]),
("schedule", tasks["t2"]),
{
("before", tasks["t1"]),
("schedule", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("schedule", tasks["t2"]),
("after", tasks["t2"]),
},
{
("before", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("after", tasks["t2"]),
},
("after_run", None),
]
print(list(r.filter_tasks(tasks.values())))
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
def test_null_instrument() -> None:
# undefined instrument methods are skipped
class NullInstrument(_abc.Instrument):
def something_unrelated(self) -> None:
pass # pragma: no cover
async def main() -> None:
await _core.checkpoint()
_core.run(main, instruments=[NullInstrument()])
def test_instrument_before_after_run() -> None:
record = []
class BeforeAfterRun(_abc.Instrument):
def before_run(self) -> None:
record.append("before_run")
def after_run(self) -> None:
record.append("after_run")
async def main() -> None:
pass
_core.run(main, instruments=[BeforeAfterRun()])
assert record == ["before_run", "after_run"]
def test_instrument_task_spawn_exit() -> None:
record = []
class SpawnExitRecorder(_abc.Instrument):
def task_spawned(self, task: Task) -> None:
record.append(("spawned", task))
def task_exited(self, task: Task) -> None:
record.append(("exited", task))
async def main() -> Task:
return _core.current_task()
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
assert ("spawned", main_task) in record
assert ("exited", main_task) in record
# This test also tests having a crash before the initial task is even spawned,
# which is very difficult to handle.
def test_instruments_crash(caplog: pytest.LogCaptureFixture) -> None:
record = []
class BrokenInstrument(_abc.Instrument):
def task_scheduled(self, task: Task) -> NoReturn:
record.append("scheduled")
raise ValueError("oops")
def close(self) -> None:
# Shouldn't be called -- tests that the instrument disabling logic
# works right.
record.append("closed") # pragma: no cover
async def main() -> Task:
record.append("main ran")
return _core.current_task()
r = TaskRecorder()
main_task = _core.run(main, instruments=[r, BrokenInstrument()])
assert record == ["scheduled", "main ran"]
# the TaskRecorder kept going throughout, even though the BrokenInstrument
# was disabled
assert ("after", main_task) in r.record
assert ("after_run", None) in r.record
# And we got a log message
assert caplog.records[0].exc_info is not None
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "Instrument has been disabled" in caplog.records[0].message
def test_instruments_monkeypatch() -> None:
class NullInstrument(_abc.Instrument):
pass
instrument = NullInstrument()
async def main() -> None:
record: list[Task] = []
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
instrument.before_task_step = ( # type: ignore[method-assign]
record.append # type: ignore[assignment] # append is pos-only
)
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
# But if we remove and re-add the instrument, the new hooks are
# picked up
_core.remove_instrument(instrument)
_core.add_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.remove_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.run(main, instruments=[instrument])
def test_instrument_that_raises_on_getattr() -> None:
class EvilInstrument(_abc.Instrument):
def task_exited(self, task: Task) -> NoReturn:
raise AssertionError("this should never happen") # pragma: no cover
@property
def after_run(self) -> NoReturn:
raise ValueError("oops")
async def main() -> None:
with pytest.raises(ValueError, match="^oops$"):
_core.add_instrument(EvilInstrument())
# Make sure the instrument is fully removed from the per-method lists
runner = _core.current_task()._runner
assert "after_run" not in runner.instruments
assert "task_exited" not in runner.instruments
_core.run(main)
@@ -0,0 +1,480 @@
from __future__ import annotations
import random
import socket as stdlib_socket
from contextlib import suppress
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
import pytest
import trio
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
# Cross-platform tests for IO handling
if TYPE_CHECKING:
from collections.abc import Generator
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
def fill_socket(sock: stdlib_socket.socket) -> None:
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass
def drain_socket(sock: stdlib_socket.socket) -> None:
try:
while True:
sock.recv(65536)
except BlockingIOError:
pass
WaitSocket = Callable[[stdlib_socket.socket], Awaitable[object]]
SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket]
RetT = TypeVar("RetT")
@pytest.fixture
def socketpair() -> Generator[SocketPair, None, None]:
pair = stdlib_socket.socketpair()
for sock in pair:
sock.setblocking(False)
yield pair
for sock in pair:
sock.close()
def also_using_fileno(
fn: Callable[[stdlib_socket.socket | int], RetT],
) -> list[Callable[[stdlib_socket.socket], RetT]]:
def fileno_wrapper(fileobj: stdlib_socket.socket) -> RetT:
return fn(fileobj.fileno())
name = f"<{fn.__name__} on fileno>"
fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name
return [fn, fileno_wrapper]
# Decorators that feed in different settings for wait_readable / wait_writable
# / notify_closing.
# Note that if you use all three decorators on the same test, it will run all
# N**3 *combinations*
read_socket_test = pytest.mark.parametrize(
"wait_readable",
also_using_fileno(trio.lowlevel.wait_readable),
ids=lambda fn: fn.__name__,
)
write_socket_test = pytest.mark.parametrize(
"wait_writable",
also_using_fileno(trio.lowlevel.wait_writable),
ids=lambda fn: fn.__name__,
)
notify_closing_test = pytest.mark.parametrize(
"notify_closing",
also_using_fileno(trio.lowlevel.notify_closing),
ids=lambda fn: fn.__name__,
)
# XX These tests are all a bit dicey because they can't distinguish between
# wait_on_{read,writ}able blocking the way it should, versus blocking
# momentarily and then immediately resuming.
@read_socket_test
@write_socket_test
async def test_wait_basic(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
a, b = socketpair
# They start out writable()
with assert_checkpoints():
await wait_writable(a)
# But readable() blocks until data arrives
record = []
async def block_on_read() -> None:
try:
with assert_checkpoints():
await wait_readable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("readable")
assert a.recv(10) == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
fill_socket(a)
# Now writable will block, but readable won't
with assert_checkpoints():
await wait_readable(b)
record = []
async def block_on_write() -> None:
try:
with assert_checkpoints():
await wait_writable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("writable")
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
assert record == []
drain_socket(b)
# check cancellation
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
fill_socket(a)
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
@read_socket_test
async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> None:
a, b = socketpair
# You can't have two tasks trying to read from a socket at the same time
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_readable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_readable(a)
nursery.cancel_scope.cancel()
@write_socket_test
async def test_double_write(socketpair: SocketPair, wait_writable: WaitSocket) -> None:
a, b = socketpair
# You can't have two tasks trying to write to a socket at the same time
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_writable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_writable(a)
nursery.cancel_scope.cancel()
@read_socket_test
@write_socket_test
@notify_closing_test
async def test_interrupted_by_close(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
notify_closing: Callable[[stdlib_socket.socket], object],
) -> None:
a, b = socketpair
async def reader() -> None:
with pytest.raises(_core.ClosedResourceError):
await wait_readable(a)
async def writer() -> None:
with pytest.raises(_core.ClosedResourceError):
await wait_writable(a)
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a)
@read_socket_test
@write_socket_test
async def test_socket_simultaneous_read_write(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
record: list[str] = []
async def r_task(sock: stdlib_socket.socket) -> None:
await wait_readable(sock)
record.append("r_task")
async def w_task(sock: stdlib_socket.socket) -> None:
await wait_writable(sock)
record.append("w_task")
a, b = socketpair
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(r_task, a)
nursery.start_soon(w_task, a)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
await wait_all_tasks_blocked()
assert record == ["r_task"]
drain_socket(b)
await wait_all_tasks_blocked()
assert record == ["r_task", "w_task"]
@read_socket_test
@write_socket_test
async def test_socket_actual_streaming(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
a, b = socketpair
# Use a small send buffer on one of the sockets to increase the chance of
# getting partial writes
a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)
N = 1000000 # 1 megabyte
MAX_CHUNK = 65536
results: dict[str, int] = {}
async def sender(sock: stdlib_socket.socket, seed: int, key: str) -> None:
r = random.Random(seed)
sent = 0
while sent < N:
print("sent", sent)
chunk = bytearray(r.randrange(MAX_CHUNK))
while chunk:
with assert_checkpoints():
await wait_writable(sock)
this_chunk_size = sock.send(chunk)
sent += this_chunk_size
del chunk[:this_chunk_size]
sock.shutdown(stdlib_socket.SHUT_WR)
results[key] = sent
async def receiver(sock: stdlib_socket.socket, key: str) -> None:
received = 0
while True:
print("received", received)
with assert_checkpoints():
await wait_readable(sock)
this_chunk_size = len(sock.recv(MAX_CHUNK))
if not this_chunk_size:
break
received += this_chunk_size
results[key] = received
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, a, 0, "send_a")
nursery.start_soon(sender, b, 1, "send_b")
nursery.start_soon(receiver, a, "recv_a")
nursery.start_soon(receiver, b, "recv_b")
assert results["send_a"] == results["recv_b"]
assert results["send_b"] == results["recv_a"]
async def test_notify_closing_on_invalid_object() -> None:
# It should either be a no-op (generally on Unix, where we don't know
# which fds are valid), or an OSError (on Windows, where we currently only
# support sockets, so we have to do some validation to figure out whether
# it's a socket or a regular handle).
got_oserror = False
got_no_error = False
try:
trio.lowlevel.notify_closing(-1)
except OSError:
got_oserror = True
else:
got_no_error = True
assert got_oserror or got_no_error
async def test_wait_on_invalid_object() -> None:
# We definitely want to raise an error everywhere if you pass in an
# invalid fd to wait_*
for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]:
with stdlib_socket.socket() as s:
fileno = s.fileno()
# We just closed the socket and don't do anything else in between, so
# we can be confident that the fileno hasn't be reassigned.
with pytest.raises(
OSError,
match=r"^\[\w+ \d+] (Bad file descriptor|An operation was attempted on something that is not a socket)$",
):
await wait(fileno)
async def test_io_manager_statistics() -> None:
def check(*, expected_readers: int, expected_writers: int) -> None:
statistics = _core.current_statistics()
print(statistics)
iostats = statistics.io_statistics
if iostats.backend == "epoll" or iostats.backend == "windows":
assert iostats.tasks_waiting_read == expected_readers
assert iostats.tasks_waiting_write == expected_writers
else:
assert iostats.backend == "kqueue"
assert iostats.tasks_waiting == expected_readers + expected_writers
a1, b1 = stdlib_socket.socketpair()
a2, b2 = stdlib_socket.socketpair()
a3, b3 = stdlib_socket.socketpair()
for sock in [a1, b1, a2, b2, a3, b3]:
sock.setblocking(False)
with a1, b1, a2, b2, a3, b3:
# let the call_soon_task settle down
await wait_all_tasks_blocked()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
# We want:
# - one socket with a writer blocked
# - two sockets with a reader blocked
# - a socket with both blocked
fill_socket(a1)
fill_socket(a3)
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_writable, a1)
nursery.start_soon(_core.wait_readable, a2)
nursery.start_soon(_core.wait_readable, b2)
nursery.start_soon(_core.wait_writable, a3)
nursery.start_soon(_core.wait_readable, a3)
await wait_all_tasks_blocked()
# +1 for call_soon_task
check(expected_readers=3 + 1, expected_writers=2)
nursery.cancel_scope.cancel()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
async def test_can_survive_unnotified_close() -> None:
# An "unnotified" close is when the user closes an fd/socket/handle
# directly, without calling notify_closing first. This should never happen
# -- users should call notify_closing before closing things. But, just in
# case they don't, we would still like to avoid exploding.
#
# Acceptable behaviors:
# - wait_* never return, but can be cancelled cleanly
# - wait_* exit cleanly
# - wait_* raise an OSError
#
# Not acceptable:
# - getting stuck in an uncancellable state
# - TrioInternalError blowing up the whole run
#
# This test exercises some tricky "unnotified close" scenarios, to make
# sure we get the "acceptable" behaviors.
async def allow_OSError(
async_func: Callable[ArgsT, Awaitable[object]],
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> None:
with suppress(OSError):
await async_func(*args, **kwargs)
with stdlib_socket.socket() as s:
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# We hit different paths on Windows depending on whether we close the last
# handle to the object (which produces a LOCAL_CLOSE notification and
# wakes up wait_readable), or only close one of the handles (which leaves
# wait_readable pending until cancelled).
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# A more elaborate case, with two tasks waiting. On windows and epoll,
# the two tasks get muxed together onto a single underlying wait
# operation. So when they're cancelled, there's a brief moment where one
# of the tasks is cancelled but the other isn't, so we try to re-issue the
# underlying wait operation. But here, the handle we were going to use to
# do that has been pulled out from under our feet... so test that we can
# survive this.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2:
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
await wait_all_tasks_blocked()
a.close()
nursery.cancel_scope.cancel()
# A similar case, but now the single-task-wakeup happens due to I/O
# arriving, not a cancellation, so the operation gets re-issued from
# handle_io context rather than abort context.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2:
print(f"a={a.fileno()}, b={b.fileno()}, a2={a2.fileno()}")
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
e = trio.Event()
# We want to wait for the kernel to process the wakeup on 'a', if any.
# But depending on the platform, we might not get a wakeup on 'a'. So
# we put one task to sleep waiting on 'a', and we put a second task to
# sleep waiting on 'a2', with the idea that the 'a2' notification will
# definitely arrive, and when it does then we can assume that whatever
# notification was going to arrive for 'a' has also arrived.
async def wait_readable_a2_then_set() -> None:
await trio.lowlevel.wait_readable(a2)
e.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
nursery.start_soon(wait_readable_a2_then_set)
await wait_all_tasks_blocked()
a.close()
b.send(b"x")
# Make sure that the wakeup has been received and everything has
# settled before cancelling the wait_writable.
await e.wait()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@@ -0,0 +1,517 @@
from __future__ import annotations
import contextlib
import inspect
import signal
import threading
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator
import outcome
import pytest
from trio.testing import RaisesGroup
try:
from async_generator import async_generator, yield_
except ImportError: # pragma: no cover
async_generator = yield_ = None
from ... import _core
from ..._abc import Instrument
from ..._timeouts import sleep
from ..._util import signal_raise
from ...testing import wait_all_tasks_blocked
if TYPE_CHECKING:
from ..._core import Abort, RaiseCancelT
def ki_self() -> None:
signal_raise(signal.SIGINT)
def test_ki_self() -> None:
with pytest.raises(KeyboardInterrupt):
ki_self()
async def test_ki_enabled() -> None:
# Regular tasks aren't KI-protected
assert not _core.currently_ki_protected()
# Low-level call-soon callbacks are KI-protected
token = _core.current_trio_token()
record = []
def check() -> None:
record.append(_core.currently_ki_protected())
token.run_sync_soon(check)
await wait_all_tasks_blocked()
assert record == [True]
@_core.enable_ki_protection
def protected() -> None:
assert _core.currently_ki_protected()
unprotected()
@_core.disable_ki_protection
def unprotected() -> None:
assert not _core.currently_ki_protected()
protected()
@_core.enable_ki_protection
async def aprotected() -> None:
assert _core.currently_ki_protected()
await aunprotected()
@_core.disable_ki_protection
async def aunprotected() -> None:
assert not _core.currently_ki_protected()
await aprotected()
# make sure that the decorator here overrides the automatic manipulation
# that start_soon() does:
async with _core.open_nursery() as nursery:
nursery.start_soon(aprotected)
nursery.start_soon(aunprotected)
@_core.enable_ki_protection
def gen_protected() -> Iterator[None]:
assert _core.currently_ki_protected()
yield
for _ in gen_protected():
pass
@_core.disable_ki_protection
def gen_unprotected() -> Iterator[None]:
assert not _core.currently_ki_protected()
yield
for _ in gen_unprotected():
pass
# This used to be broken due to
#
# https://bugs.python.org/issue29590
#
# Specifically, after a coroutine is resumed with .throw(), then the stack
# makes it look like the immediate caller is the function that called
# .throw(), not the actual caller. So child() here would have a caller deep in
# the guts of the run loop, and always be protected, even when it shouldn't
# have been. (Solution: we don't use .throw() anymore.)
async def test_ki_enabled_after_yield_briefly() -> None:
@_core.enable_ki_protection
async def protected() -> None:
await child(True)
@_core.disable_ki_protection
async def unprotected() -> None:
await child(False)
async def child(expected: bool) -> None:
import traceback
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await _core.checkpoint()
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await protected()
await unprotected()
# This also used to be broken due to
# https://bugs.python.org/issue29590
async def test_generator_based_context_manager_throw() -> None:
@contextlib.contextmanager
@_core.enable_ki_protection
def protected_manager() -> Iterator[None]:
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
with protected_manager():
assert not _core.currently_ki_protected()
with pytest.raises(KeyError):
# This is the one that used to fail
with protected_manager():
raise KeyError
# the async_generator package isn't typed, hence all the type: ignores
@pytest.mark.skipif(async_generator is None, reason="async_generator not installed")
async def test_async_generator_agen_protection() -> None:
@_core.enable_ki_protection
@async_generator # type: ignore[misc] # untyped generator
async def agen_protected1() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
@async_generator # type: ignore[misc] # untyped generator
async def agen_unprotected1() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Swap the order of the decorators:
@async_generator # type: ignore[misc] # untyped generator
@_core.enable_ki_protection
async def agen_protected2() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@async_generator # type: ignore[misc] # untyped generator
@_core.disable_ki_protection
async def agen_unprotected2() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
await _check_agen(agen_protected1)
await _check_agen(agen_protected2)
await _check_agen(agen_unprotected1)
await _check_agen(agen_unprotected2)
async def test_native_agen_protection() -> None:
# Native async generators
@_core.enable_ki_protection
async def agen_protected() -> AsyncIterator[None]:
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
async def agen_unprotected() -> AsyncIterator[None]:
assert not _core.currently_ki_protected()
try:
yield
finally:
assert not _core.currently_ki_protected()
await _check_agen(agen_protected)
await _check_agen(agen_unprotected)
async def _check_agen(agen_fn: Callable[[], AsyncIterator[None]]) -> None:
async for _ in agen_fn():
assert not _core.currently_ki_protected()
# asynccontextmanager insists that the function passed must itself be an
# async gen function, not a wrapper around one
if inspect.isasyncgenfunction(agen_fn):
async with contextlib.asynccontextmanager(agen_fn)():
assert not _core.currently_ki_protected()
# Another case that's tricky due to:
# https://bugs.python.org/issue29590
with pytest.raises(KeyError):
async with contextlib.asynccontextmanager(agen_fn)():
raise KeyError
# Test the case where there's no magic local anywhere in the call stack
def test_ki_disabled_out_of_context() -> None:
assert _core.currently_ki_protected()
def test_ki_disabled_in_del() -> None:
def nestedfunction() -> bool:
return _core.currently_ki_protected()
def __del__() -> None:
assert _core.currently_ki_protected()
assert nestedfunction()
@_core.disable_ki_protection
def outerfunction() -> None:
assert not _core.currently_ki_protected()
assert not nestedfunction()
__del__()
__del__()
outerfunction()
assert nestedfunction()
def test_ki_protection_works() -> None:
async def sleeper(name: str, record: set[str]) -> None:
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
async def raiser(name: str, record: set[str]) -> None:
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
print("killing, protection =", _core.currently_ki_protected())
ki_self()
except KeyboardInterrupt:
print("raised!")
# Make sure we aren't getting cancelled as well as siginted
await _core.checkpoint()
record.add(name + " raise ok")
raise
else:
print("didn't raise!")
# If we didn't raise (b/c protected), then we *should* get
# cancelled at the next opportunity
try:
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
except _core.Cancelled:
record.add(name + " cancel ok")
# simulated control-C during raiser, which is *unprotected*
print("check 1")
record_set: set[str] = set()
async def check_unprotected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
nursery.start_soon(raiser, "r1", record_set)
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
with RaisesGroup(KeyboardInterrupt):
_core.run(check_unprotected_kill)
assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"}
# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
print("check 2")
record_set = set()
async def check_protected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set)
# __aexit__ blocks, and then receives the KI
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
with RaisesGroup(KeyboardInterrupt):
_core.run(check_protected_kill)
assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"}
# kill at last moment still raises (run_sync_soon until it raises an
# error, then kill)
print("check 3")
async def check_kill_during_shutdown() -> None:
token = _core.current_trio_token()
def kill_during_shutdown() -> None:
assert _core.currently_ki_protected()
try:
token.run_sync_soon(kill_during_shutdown)
except _core.RunFinishedError:
# it's too late for regular handling! handle this!
print("kill! kill!")
ki_self()
token.run_sync_soon(kill_during_shutdown)
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
with pytest.raises(KeyboardInterrupt):
_core.run(check_kill_during_shutdown)
# KI arrives very early, before main is even spawned
print("check 4")
class InstrumentOfDeath(Instrument):
def before_run(self) -> None:
ki_self()
async def main_1() -> None:
await _core.checkpoint()
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
with pytest.raises(KeyboardInterrupt):
_core.run(main_1, instruments=[InstrumentOfDeath()])
# checkpoint_if_cancelled notices pending KI
print("check 5")
@_core.enable_ki_protection
async def main_2() -> None:
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint_if_cancelled()
_core.run(main_2)
# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@_core.enable_ki_protection
async def main_3() -> None:
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main_3)
# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@_core.enable_ki_protection
async def main_4() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(_: RaiseCancelT) -> Abort:
_core.reschedule(task, outcome.Value(1))
return _core.Abort.FAILED
assert await _core.wait_task_rescheduled(abort) == 1
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main_4)
# KI delivered via slow abort
print("check 8")
@_core.enable_ki_protection
async def main_5() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(raise_cancel: RaiseCancelT) -> Abort:
result = outcome.capture(raise_cancel)
_core.reschedule(task, result)
return _core.Abort.FAILED
with pytest.raises(KeyboardInterrupt):
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()
_core.run(main_5)
# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
# by the time the callback is actually run, main has exited and can't be
# aborted.
print("check 9")
@_core.enable_ki_protection
async def main_6() -> None:
ki_self()
with pytest.raises(KeyboardInterrupt):
_core.run(main_6)
print("check 10")
# KI in unprotected code, with
# restrict_keyboard_interrupt_to_checkpoints=True
record_list = []
async def main_7() -> None:
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
# ...but even after the KI, we keep running uninterrupted...
record_list.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
await sleep(10)
_core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
assert record_list == ["ok"]
record_list = []
# Exact same code raises KI early if we leave off the argument, doesn't
# even reach the record.append call:
with pytest.raises(KeyboardInterrupt):
_core.run(main_7)
assert record_list == []
# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")
@_core.enable_ki_protection
async def main_8() -> None:
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
_core.run(main_8)
def test_ki_is_good_neighbor() -> None:
# in the unlikely event someone overwrites our signal handler, we leave
# the overwritten one be
try:
orig = signal.getsignal(signal.SIGINT)
def my_handler(signum: object, frame: object) -> None: # pragma: no cover
pass
async def main() -> None:
signal.signal(signal.SIGINT, my_handler)
_core.run(main)
assert signal.getsignal(signal.SIGINT) is my_handler
finally:
signal.signal(signal.SIGINT, orig)
# Regression test for #461
# don't know if _active not being visible is a problem
def test_ki_with_broken_threads() -> None:
thread = threading.main_thread()
# scary!
original = threading._active[thread.ident] # type: ignore[attr-defined]
# put this in a try finally so we don't have a chance of cascading a
# breakage down to everything else
try:
del threading._active[thread.ident] # type: ignore[attr-defined]
@_core.enable_ki_protection
async def inner() -> None:
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
_core.run(inner)
finally:
threading._active[thread.ident] = original # type: ignore[attr-defined]
@@ -0,0 +1,118 @@
import pytest
from trio import run
from trio.lowlevel import RunVar, RunVarToken
from ... import _core
# scary runvar tests
def test_runvar_smoketest() -> None:
t1 = RunVar[str]("test1")
t2 = RunVar[str]("test2", default="catfish")
assert repr(t1) == "<RunVar name='test1'>"
async def first_check() -> None:
with pytest.raises(LookupError):
t1.get()
t1.set("swordfish")
assert t1.get() == "swordfish"
assert t2.get() == "catfish"
assert t2.get(default="eel") == "eel"
t2.set("goldfish")
assert t2.get() == "goldfish"
assert t2.get(default="tuna") == "goldfish"
async def second_check() -> None:
with pytest.raises(LookupError):
t1.get()
assert t2.get() == "catfish"
run(first_check)
run(second_check)
def test_runvar_resetting() -> None:
t1 = RunVar[str]("test1")
t2 = RunVar[str]("test2", default="dogfish")
t3 = RunVar[str]("test3")
async def reset_check() -> None:
token = t1.set("moonfish")
assert t1.get() == "moonfish"
t1.reset(token)
with pytest.raises(TypeError):
t1.reset(None) # type: ignore[arg-type]
with pytest.raises(LookupError):
t1.get()
token2 = t2.set("catdogfish")
assert t2.get() == "catdogfish"
t2.reset(token2)
assert t2.get() == "dogfish"
with pytest.raises(ValueError, match="^token has already been used$"):
t2.reset(token2)
token3 = t3.set("basculin")
assert t3.get() == "basculin"
with pytest.raises(ValueError, match="^token is not for us$"):
t1.reset(token3)
run(reset_check)
def test_runvar_sync() -> None:
t1 = RunVar[str]("test1")
async def sync_check() -> None:
async def task1() -> None:
t1.set("plaice")
assert t1.get() == "plaice"
async def task2(tok: RunVarToken[str]) -> None:
t1.reset(tok)
with pytest.raises(LookupError):
t1.get()
t1.set("haddock")
async with _core.open_nursery() as n:
token = t1.set("cod")
assert t1.get() == "cod"
n.start_soon(task1)
await _core.wait_all_tasks_blocked()
assert t1.get() == "plaice"
n.start_soon(task2, token)
await _core.wait_all_tasks_blocked()
assert t1.get() == "haddock"
run(sync_check)
def test_accessing_runvar_outside_run_call_fails() -> None:
t1 = RunVar[str]("test1")
with pytest.raises(RuntimeError):
t1.set("asdf")
with pytest.raises(RuntimeError):
t1.get()
async def get_token() -> RunVarToken[str]:
return t1.set("ok")
token = run(get_token)
with pytest.raises(RuntimeError):
t1.reset(token)
@@ -0,0 +1,175 @@
import time
from math import inf
import pytest
from trio import sleep
from ... import _core
from .. import wait_all_tasks_blocked
from .._mock_clock import MockClock
from .tutil import slow
def test_mock_clock() -> None:
REAL_NOW = 123.0
c = MockClock()
c._real_clock = lambda: REAL_NOW
repr(c) # smoke test
assert c.rate == 0
assert c.current_time() == 0
c.jump(1.2)
assert c.current_time() == 1.2
with pytest.raises(ValueError, match="^time can't go backwards$"):
c.jump(-1)
assert c.current_time() == 1.2
assert c.deadline_to_sleep_time(1.1) == 0
assert c.deadline_to_sleep_time(1.2) == 0
assert c.deadline_to_sleep_time(1.3) > 999999
with pytest.raises(ValueError, match="^rate must be >= 0$"):
c.rate = -1
assert c.rate == 0
c.rate = 2
assert c.current_time() == 1.2
REAL_NOW += 1
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 0.5
c.rate = 0.5
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 2.0
c.jump(0.8)
assert c.current_time() == 4.0
REAL_NOW += 1
assert c.current_time() == 4.5
c2 = MockClock(rate=3)
assert c2.rate == 3
assert c2.current_time() < 10
async def test_mock_clock_autojump(mock_clock: MockClock) -> None:
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
assert mock_clock.autojump_threshold == 0
real_start = time.perf_counter()
virtual_start = _core.current_time()
for i in range(10):
print(f"sleeping {10 * i} seconds")
await sleep(10 * i)
print("woke up!")
assert virtual_start + 10 * i == _core.current_time()
virtual_start = _core.current_time()
real_duration = time.perf_counter() - real_start
print(f"Slept {10 * sum(range(10))} seconds in {real_duration} seconds")
assert real_duration < 1
mock_clock.autojump_threshold = 0.02
t = _core.current_time()
# this should wake up before the autojump threshold triggers, so time
# shouldn't change
await wait_all_tasks_blocked()
assert t == _core.current_time()
# this should too
await wait_all_tasks_blocked(0.01)
assert t == _core.current_time()
# set up a situation where the autojump task is blocked for a long long
# time, to make sure that cancel-and-adjust-threshold logic is working
mock_clock.autojump_threshold = 10000
await wait_all_tasks_blocked()
mock_clock.autojump_threshold = 0
# if the above line didn't take affect immediately, then this would be
# bad:
await sleep(100000)
async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None:
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
# messing with the autojump threshold of a clock that isn't actually
# installed in the run loop shouldn't do anything.
mock_clock2.autojump_threshold = 0.01
# if the autojump_threshold of 0.01 were in effect, then the next line
# would block forever, as the autojump task kept waking up to try to
# jump the clock.
await wait_all_tasks_blocked(0.015)
# but the 0.02 limit does apply
await sleep(100000)
def test_mock_clock_autojump_preset() -> None:
# Check that we can set the autojump_threshold before the clock is
# actually in use, and it gets picked up
mock_clock = MockClock(autojump_threshold=0.1)
mock_clock.autojump_threshold = 0.01
real_start = time.perf_counter()
_core.run(sleep, 10000, clock=mock_clock)
assert time.perf_counter() - real_start < 1
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(
mock_clock: MockClock,
) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
mock_clock.autojump_threshold = 0
record = []
async def sleeper() -> None:
await sleep(100)
record.append("yawn")
async def waiter() -> None:
await wait_all_tasks_blocked()
record.append("waiter woke")
await sleep(1000)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter woke", "yawn", "waiter done"]
@slow
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(
mock_clock: MockClock,
) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
mock_clock.autojump_threshold = 0
record = []
async def sleeper() -> None:
await sleep(100)
record.append("yawn")
async def waiter() -> None:
await wait_all_tasks_blocked(1)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter done", "yawn"]
@@ -0,0 +1,384 @@
from __future__ import annotations
import re
from typing import TypeVar
import pytest
import trio
from trio.lowlevel import (
add_parking_lot_breaker,
current_task,
remove_parking_lot_breaker,
)
from trio.testing import Matcher, RaisesGroup
from ... import _core
from ...testing import wait_all_tasks_blocked
from .._parking_lot import ParkingLot
from .tutil import check_sequence_matches
T = TypeVar("T")
async def test_parking_lot_basic() -> None:
record = []
async def waiter(i: int, lot: ParkingLot) -> None:
record.append(f"sleep {i}")
await lot.park()
record.append(f"wake {i}")
async with _core.open_nursery() as nursery:
lot = ParkingLot()
assert not lot
assert len(lot) == 0
assert lot.statistics().tasks_waiting == 0
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
assert bool(lot)
assert len(lot) == 3
assert lot.statistics().tasks_waiting == 3
lot.unpark_all()
assert lot.statistics().tasks_waiting == 0
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record,
[{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}],
)
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
for _ in range(3):
lot.unpark()
await wait_all_tasks_blocked()
# 1-by-1 wakeups are strict FIFO
assert record == [
"sleep 0",
"sleep 1",
"sleep 2",
"wake 0",
"wake 1",
"wake 2",
]
# It's legal (but a no-op) to try and unpark while there's nothing parked
lot.unpark()
lot.unpark(count=1)
lot.unpark(count=100)
# Check unpark with count
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
lot.unpark(count=2)
await wait_all_tasks_blocked()
check_sequence_matches(
record,
["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}],
)
lot.unpark_all()
with pytest.raises(
ValueError,
match=r"^Cannot pop a non-integer number of tasks\.$",
):
lot.unpark(count=1.5)
async def cancellable_waiter(
name: T,
lot: ParkingLot,
scopes: dict[T, _core.CancelScope],
record: list[str],
) -> None:
with _core.CancelScope() as scope:
scopes[name] = scope
record.append(f"sleep {name}")
try:
await lot.park()
except _core.Cancelled:
record.append(f"cancelled {name}")
else:
record.append(f"wake {name}")
async def test_parking_lot_cancel() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
async with _core.open_nursery() as nursery:
lot = ParkingLot()
nursery.start_soon(cancellable_waiter, 1, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(record) == 4
lot.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record,
["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}],
)
async def test_parking_lot_repark() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
with pytest.raises(TypeError):
lot1.repark([]) # type: ignore[arg-type]
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
lot1.repark(lot2)
assert len(lot1) == 2
assert len(lot2) == 1
lot2.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 4
assert record == ["sleep 1", "sleep 2", "sleep 3", "wake 1"]
lot1.repark_all(lot2)
assert len(lot1) == 0
assert len(lot2) == 2
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(lot2) == 1
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
]
lot2.unpark_all()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
"wake 3",
]
async def test_parking_lot_repark_with_count() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
assert len(lot2) == 0
lot1.repark(lot2, count=2)
assert len(lot1) == 1
assert len(lot2) == 2
while lot2:
lot2.unpark()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"wake 2",
]
lot1.unpark_all()
async def dummy_task(
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
) -> None:
task_status.started(_core.current_task())
await trio.sleep_forever()
async def test_parking_lot_breaker_basic() -> None:
"""Test basic functionality for breaking lots."""
lot = ParkingLot()
task = current_task()
# defaults to current task
lot.break_lot()
assert lot.broken_by == [task]
# breaking the lot again with the same task appends another copy in `broken_by`
lot.break_lot()
assert lot.broken_by == [task, task]
# trying to park in broken lot errors
broken_by_str = re.escape(str([task, task]))
with pytest.raises(
_core.BrokenResourceError,
match=f"^Attempted to park in parking lot broken by {broken_by_str}$",
):
await lot.park()
async def test_parking_lot_break_parking_tasks() -> None:
"""Checks that tasks currently waiting to park raise an error when the breaker exits."""
async def bad_parker(lot: ParkingLot, scope: _core.CancelScope) -> None:
add_parking_lot_breaker(current_task(), lot)
with scope:
await trio.sleep_forever()
lot = ParkingLot()
cs = _core.CancelScope()
# check that parked task errors
with RaisesGroup(
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
):
async with _core.open_nursery() as nursery:
nursery.start_soon(bad_parker, lot, cs)
await wait_all_tasks_blocked()
nursery.start_soon(lot.park)
await wait_all_tasks_blocked()
cs.cancel()
async def test_parking_lot_breaker_registration() -> None:
lot = ParkingLot()
task = current_task()
with pytest.raises(
RuntimeError,
match="Attempted to remove task as breaker for a lot it is not registered for",
):
remove_parking_lot_breaker(task, lot)
# check that a task can be registered as breaker for the same lot multiple times
add_parking_lot_breaker(task, lot)
add_parking_lot_breaker(task, lot)
remove_parking_lot_breaker(task, lot)
remove_parking_lot_breaker(task, lot)
with pytest.raises(
RuntimeError,
match="Attempted to remove task as breaker for a lot it is not registered for",
):
remove_parking_lot_breaker(task, lot)
# registering a task as breaker on an already broken lot is fine
lot.break_lot()
child_task = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
add_parking_lot_breaker(child_task, lot)
nursery.cancel_scope.cancel()
assert lot.broken_by == [task, child_task]
# manually breaking a lot with an already exited task is fine
lot = ParkingLot()
lot.break_lot(child_task)
assert lot.broken_by == [child_task]
async def test_parking_lot_breaker_rebreak() -> None:
lot = ParkingLot()
task = current_task()
lot.break_lot()
# breaking an already broken lot with a different task is allowed
# The nursery is only to create a task we can pass to lot.break_lot
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
lot.break_lot(child_task)
nursery.cancel_scope.cancel()
assert lot.broken_by == [task, child_task]
async def test_parking_lot_multiple_breakers_exit() -> None:
# register multiple tasks as lot breakers, then have them all exit
lot = ParkingLot()
async with trio.open_nursery() as nursery:
child_task1 = await nursery.start(dummy_task)
child_task2 = await nursery.start(dummy_task)
child_task3 = await nursery.start(dummy_task)
add_parking_lot_breaker(child_task1, lot)
add_parking_lot_breaker(child_task2, lot)
add_parking_lot_breaker(child_task3, lot)
nursery.cancel_scope.cancel()
# I think the order is guaranteed currently, but doesn't hurt to be safe.
assert set(lot.broken_by) == {child_task1, child_task2, child_task3}
async def test_parking_lot_breaker_register_exited_task() -> None:
lot = ParkingLot()
child_task = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
nursery.cancel_scope.cancel()
# trying to register an exited task as lot breaker errors
with pytest.raises(
trio.BrokenResourceError,
match="^Attempted to add already exited task as lot breaker.$",
):
add_parking_lot_breaker(child_task, lot)
async def test_parking_lot_break_itself() -> None:
"""Break a parking lot, where the breakee is parked.
Doing this is weird, but should probably be supported.
"""
async def return_me_and_park(
lot: ParkingLot,
*,
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
) -> None:
task_status.started(_core.current_task())
await lot.park()
lot = ParkingLot()
with RaisesGroup(
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
):
async with _core.open_nursery() as nursery:
child_task = await nursery.start(return_me_and_park, lot)
lot.break_lot(child_task)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,195 @@
from __future__ import annotations
import threading
import time
from contextlib import contextmanager
from queue import Queue
from typing import TYPE_CHECKING, Iterator, NoReturn
import pytest
from .. import _thread_cache
from .._thread_cache import ThreadCache, start_thread_soon
from .tutil import gc_collect_harder, slow
if TYPE_CHECKING:
from outcome import Outcome
def test_thread_cache_basics() -> None:
q: Queue[Outcome[object]] = Queue()
def fn() -> NoReturn:
raise RuntimeError("hi")
def deliver(outcome: Outcome[object]) -> None:
q.put(outcome)
start_thread_soon(fn, deliver)
outcome = q.get()
with pytest.raises(RuntimeError, match="hi"):
outcome.unwrap()
def test_thread_cache_deref() -> None:
res = [False]
class del_me:
def __call__(self) -> int:
return 42
def __del__(self) -> None:
res[0] = True
q: Queue[Outcome[int]] = Queue()
def deliver(outcome: Outcome[int]) -> None:
q.put(outcome)
start_thread_soon(del_me(), deliver)
outcome = q.get()
assert outcome.unwrap() == 42
gc_collect_harder()
assert res[0]
@slow
def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# We know that no-one else is using the thread cache, so if we keep
# submitting new jobs the instant the previous one is finished, we should
# keep getting the same thread over and over. This tests both that the
# thread cache is LIFO, and that threads can be assigned new work *before*
# deliver exits.
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
q: Queue[Outcome[object]] = Queue()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
for _ in range(COUNT):
q.get().unwrap()
seen_threads = set()
done = threading.Event()
def deliver(n: int, _: object) -> None:
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
done.set()
else:
start_thread_soon(lambda: None, lambda _: deliver(n - 1, _))
start_thread_soon(lambda: None, lambda _: deliver(5, _))
done.wait()
assert len(seen_threads) == 1
@slow
def test_idle_threads_exit(monkeypatch: pytest.MonkeyPatch) -> None:
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
q: Queue[threading.Thread] = Queue()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
# should have exited
time.sleep(1)
assert not seen_thread.is_alive()
@contextmanager
def _join_started_threads() -> Iterator[None]:
before = frozenset(threading.enumerate())
try:
yield
finally:
for thread in threading.enumerate():
if thread not in before:
thread.join(timeout=1.0)
assert not thread.is_alive()
def test_race_between_idle_exit_and_job_assignment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
# sequence:
#
# 1. start_thread_soon grabs the worker thread, assigns it a job, and
# releases its lock.
# 2. The worker thread wakes up (because the lock has been released), but
# the JankyLock lies to it and tells it that the lock timed out. So the
# worker thread tries to exit.
# 3. The worker thread checks for the race between exiting and being
# assigned a job, and discovers that it *is* in the process of being
# assigned a job, so it loops around and tries to acquire the lock
# again.
# 4. Eventually the JankyLock admits that the lock is available, and
# everything proceeds as normal.
class JankyLock:
def __init__(self) -> None:
self._lock = threading.Lock()
self._counter = 3
def acquire(self, timeout: int = -1) -> bool:
got_it = self._lock.acquire(timeout=timeout)
if timeout == -1:
return True
elif got_it:
if self._counter > 0:
self._counter -= 1
self._lock.release()
return False
return True
else:
return False
def release(self) -> None:
self._lock.release()
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
with _join_started_threads():
tc = ThreadCache()
done = threading.Event()
tc.start_thread_soon(lambda: None, lambda _: done.set())
done.wait()
# Let's kill the thread we started, so it doesn't hang around until the
# test suite finishes. Doesn't really do any harm, but it can be confusing
# to see it in debug output.
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
tc.start_thread_soon(lambda: None, lambda _: None)
def test_raise_in_deliver(capfd: pytest.CaptureFixture[str]) -> None:
seen_threads = set()
def track_threads() -> None:
seen_threads.add(threading.current_thread())
def deliver(_: object) -> NoReturn:
done.set()
raise RuntimeError("don't do this")
done = threading.Event()
start_thread_soon(track_threads, deliver)
done.wait()
done = threading.Event()
start_thread_soon(track_threads, lambda _: done.set())
done.wait()
assert len(seen_threads) == 1
err = capfd.readouterr().err
assert "don't do this" in err
assert "delivering result" in err
@@ -0,0 +1,13 @@
import pytest
from .tutil import check_sequence_matches
def test_check_sequence_matches() -> None:
check_sequence_matches([1, 2, 3], [1, 2, 3])
with pytest.raises(AssertionError):
check_sequence_matches([1, 3, 2], [1, 2, 3])
check_sequence_matches([1, 2, 3, 4], [1, {2, 3}, 4])
check_sequence_matches([1, 3, 2, 4], [1, {2, 3}, 4])
with pytest.raises(AssertionError):
check_sequence_matches([1, 2, 4, 3], [1, {2, 3}, 4])
@@ -0,0 +1,154 @@
from __future__ import annotations
import itertools
import pytest
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
pytestmark = pytest.mark.filterwarnings(
"ignore:.*UnboundedQueue:trio.TrioDeprecationWarning",
)
async def test_UnboundedQueue_basic() -> None:
q: _core.UnboundedQueue[str | int | None] = _core.UnboundedQueue()
q.put_nowait("hi")
assert await q.get_batch() == ["hi"]
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
q.put_nowait(1)
q.put_nowait(2)
q.put_nowait(3)
assert q.get_batch_nowait() == [1, 2, 3]
assert q.empty()
assert q.qsize() == 0
q.put_nowait(None)
assert not q.empty()
assert q.qsize() == 1
stats = q.statistics()
assert stats.qsize == 1
assert stats.tasks_waiting == 0
# smoke test
repr(q)
async def test_UnboundedQueue_blocking() -> None:
record = []
q = _core.UnboundedQueue[int]()
async def get_batch_consumer() -> None:
while True:
batch = await q.get_batch()
assert batch
record.append(batch)
async def aiter_consumer() -> None:
async for batch in q:
assert batch
record.append(batch)
for consumer in (get_batch_consumer, aiter_consumer):
record.clear()
async with _core.open_nursery() as nursery:
nursery.start_soon(consumer)
await _core.wait_all_tasks_blocked()
stats = q.statistics()
assert stats.qsize == 0
assert stats.tasks_waiting == 1
q.put_nowait(10)
q.put_nowait(11)
await _core.wait_all_tasks_blocked()
q.put_nowait(12)
await _core.wait_all_tasks_blocked()
assert record == [[10, 11], [12]]
nursery.cancel_scope.cancel()
async def test_UnboundedQueue_fairness() -> None:
q = _core.UnboundedQueue[int]()
# If there's no-one else around, we can put stuff in and take it out
# again, no problem
q.put_nowait(1)
q.put_nowait(2)
assert q.get_batch_nowait() == [1, 2]
result = None
async def get_batch(q: _core.UnboundedQueue[int]) -> None:
nonlocal result
result = await q.get_batch()
# But if someone else is waiting to read, then they get dibs
async with _core.open_nursery() as nursery:
nursery.start_soon(get_batch, q)
await _core.wait_all_tasks_blocked()
q.put_nowait(3)
q.put_nowait(4)
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
assert result == [3, 4]
# If two tasks are trying to read, they alternate
record = []
async def reader(name: str) -> None:
while True:
record.append((name, await q.get_batch()))
async with _core.open_nursery() as nursery:
nursery.start_soon(reader, "a")
await _core.wait_all_tasks_blocked()
nursery.start_soon(reader, "b")
await _core.wait_all_tasks_blocked()
for i in range(20):
q.put_nowait(i)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
async def test_UnboundedQueue_trivial_yields() -> None:
q = _core.UnboundedQueue[None]()
q.put_nowait(None)
with assert_checkpoints():
await q.get_batch()
q.put_nowait(None)
with assert_checkpoints():
async for _ in q: # pragma: no branch
break
async def test_UnboundedQueue_no_spurious_wakeups() -> None:
# If we have two tasks waiting, and put two items into the queue... then
# only one task wakes up
record = []
async def getter(q: _core.UnboundedQueue[int], i: int) -> None:
got = await q.get_batch()
record.append((i, got))
async with _core.open_nursery() as nursery:
q = _core.UnboundedQueue[int]()
nursery.start_soon(getter, q, 1)
await wait_all_tasks_blocked()
nursery.start_soon(getter, q, 2)
await wait_all_tasks_blocked()
for i in range(10):
q.put_nowait(i)
await wait_all_tasks_blocked()
assert record == [(1, list(range(10)))]
nursery.cancel_scope.cancel()
@@ -0,0 +1,299 @@
from __future__ import annotations
import os
import sys
import tempfile
from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import create_autospec
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
assert (
sys.platform == "win32" or not TYPE_CHECKING
) # Skip type checking when not on Windows
from ... import _core, sleep
from ...testing import wait_all_tasks_blocked
from .tutil import gc_collect_harder, restore_unraisablehook, slow
if TYPE_CHECKING:
from collections.abc import Generator
from io import BufferedWriter
if on_windows:
from .._windows_cffi import (
INVALID_HANDLE_VALUE,
FileFlags,
Handle,
ffi,
kernel32,
raise_winerror,
)
def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None:
mock = create_autospec(ffi.getwinerror)
monkeypatch.setattr(ffi, "getwinerror", mock)
# Returning none = no error, should not happen.
mock.return_value = None
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
raise_winerror()
mock.assert_called_once_with()
mock.reset_mock()
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
raise_winerror(38)
mock.assert_called_once_with(38)
mock.reset_mock()
mock.return_value = (12, "test error")
with pytest.raises(
OSError,
match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$",
) as exc:
raise_winerror(filename="file_1", filename2="file_2")
mock.assert_called_once_with()
mock.reset_mock()
assert exc.value.winerror == 12
assert exc.value.strerror == "test error"
assert exc.value.filename == "file_1"
assert exc.value.filename2 == "file_2"
# With an explicit number passed in, it overrides what getwinerror() returns.
with pytest.raises(
OSError,
match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$",
) as exc:
raise_winerror(18, filename="a/file", filename2="b/file")
mock.assert_called_once_with(18)
mock.reset_mock()
assert exc.value.winerror == 18
assert exc.value.strerror == "test error"
assert exc.value.filename == "a/file"
assert exc.value.filename2 == "b/file"
# The undocumented API that this is testing should be changed to stop using
# UnboundedQueue (or just removed until we have time to redo it), but until
# then we filter out the warning.
@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning")
async def test_completion_key_listen() -> None:
from .. import _io_windows
async def post(key: int) -> None:
iocp = Handle(ffi.cast("HANDLE", _core.current_iocp()))
for i in range(10):
print("post", i)
if i % 3 == 0:
await _core.checkpoint()
success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL)
assert success
with _core.monitor_completion_key() as (key, queue):
async with _core.open_nursery() as nursery:
nursery.start_soon(post, key)
i = 0
print("loop")
async for batch in queue: # pragma: no branch
print("got some", batch)
for info in batch:
assert isinstance(info, _io_windows.CompletionKeyEventInfo)
assert info.lpOverlapped == 0
assert info.dwNumberOfBytesTransferred == i
i += 1
if i == 10:
break
print("end loop")
async def test_readinto_overlapped() -> None:
data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024
buffer = bytearray(len(data))
with tempfile.TemporaryDirectory() as tdir:
tfile = os.path.join(tdir, "numbers.txt")
with open( # noqa: ASYNC230 # This is a test, synchronous is ok
tfile,
"wb",
) as fp:
fp.write(data)
fp.flush()
rawname = tfile.encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.GENERIC_READ,
FileFlags.FILE_SHARE_READ,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
try:
with memoryview(buffer) as buffer_view:
async def read_region(start: int, end: int) -> None:
await _core.readinto_overlapped(
handle,
buffer_view[start:end],
start,
)
_core.register_with_iocp(handle)
async with _core.open_nursery() as nursery:
for start in range(0, 4096, 512):
nursery.start_soon(read_region, start, start + 512)
assert buffer == data
with pytest.raises((BufferError, TypeError)):
await _core.readinto_overlapped(handle, b"immutable")
finally:
kernel32.CloseHandle(handle)
@contextmanager
def pipe_with_overlapped_read() -> Generator[tuple[BufferedWriter, int], None, None]:
import msvcrt
from asyncio.windows_utils import pipe
read_handle, write_handle = pipe(overlapped=(True, False))
try:
write_fd = msvcrt.open_osfhandle(write_handle, 0)
yield os.fdopen(write_fd, "wb", closefd=False), read_handle
finally:
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", read_handle)))
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", write_handle)))
@restore_unraisablehook()
def test_forgot_to_register_with_iocp() -> None:
with pipe_with_overlapped_read() as (write_fp, read_handle):
with write_fp:
write_fp.write(b"test\n")
left_run_yet = False
async def main() -> None:
target = bytearray(1)
try:
async with _core.open_nursery() as nursery:
nursery.start_soon(
_core.readinto_overlapped,
read_handle,
target,
name="xyz",
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
finally:
# Run loop is exited without unwinding running tasks, so
# we don't get here until the main() coroutine is GC'ed
assert left_run_yet
with pytest.raises(_core.TrioInternalError) as exc_info:
_core.run(main)
left_run_yet = True
assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value)
assert "forget to call register_with_iocp()?" in str(exc_info.value)
# Make sure the Nursery.__del__ assertion about dangling children
# gets put with the correct test
del exc_info
gc_collect_harder()
@slow
async def test_too_late_to_cancel() -> None:
import time
with pipe_with_overlapped_read() as (write_fp, read_handle):
_core.register_with_iocp(read_handle)
target = bytearray(6)
async with _core.open_nursery() as nursery:
# Start an async read in the background
nursery.start_soon(_core.readinto_overlapped, read_handle, target)
await wait_all_tasks_blocked()
# Synchronous write to the other end of the pipe
with write_fp:
write_fp.write(b"test1\ntest2\n")
# Note: not trio.sleep! We're making sure the OS level
# ReadFile completes, before Trio has a chance to execute
# another checkpoint and notice it completed.
time.sleep(1) # noqa: ASYNC251
nursery.cancel_scope.cancel()
assert target[:6] == b"test1\n"
# Do another I/O to make sure we've actually processed the
# fallback completion that was posted when CancelIoEx failed.
assert await _core.readinto_overlapped(read_handle, target) == 6
assert target[:6] == b"test2\n"
def test_lsp_that_hooks_select_gives_good_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from .. import _io_windows
from .._windows_cffi import CData, WSAIoctls, _handle
def patched_get_underlying(
sock: int | CData,
*,
which: int = WSAIoctls.SIO_BASE_HANDLE,
) -> CData:
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BSP_HANDLE_SELECT:
return _handle(sock + 1)
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ",
):
_core.run(sleep, 0)
def test_lsp_that_completely_hides_base_socket_gives_good_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns
# self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns
# self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to
# make sure we get an error rather than an infinite loop.
from .. import _io_windows
from .._windows_cffi import CData, WSAIoctls, _handle
def patched_get_underlying(
sock: int | CData,
*,
which: int = WSAIoctls.SIO_BASE_HANDLE,
) -> CData:
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BASE_HANDLE:
raise OSError("nope")
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff",
):
_core.run(sleep, 0)
@@ -0,0 +1,117 @@
# Utilities for testing
from __future__ import annotations
import asyncio
import gc
import os
import socket as stdlib_socket
import sys
import warnings
from contextlib import closing, contextmanager
from typing import TYPE_CHECKING, TypeVar
import pytest
# See trio/_tests/conftest.py for the other half of this
from trio._tests.pytest_plugin import RUN_SLOW
if TYPE_CHECKING:
from collections.abc import Generator, Iterable, Sequence
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
T = TypeVar("T")
try:
s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0)
except OSError: # pragma: no cover
# Some systems don't even support creating an IPv6 socket, let alone
# binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line)
# We don't have any of those in our CI, and there's nothing that gets
# tested _only_ if can_create_ipv6 = False, so we'll just no-cover this.
can_create_ipv6 = False
can_bind_ipv6 = False
else:
can_create_ipv6 = True
with s:
try:
s.bind(("::1", 0))
except OSError: # pragma: no cover # since support for 3.7 was removed
can_bind_ipv6 = False
else:
can_bind_ipv6 = True
creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6")
binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6")
def gc_collect_harder() -> None:
# In the test suite we sometimes want to call gc.collect() to make sure
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
# get collected before we continue, so their noise doesn't leak into
# unrelated tests.
#
# On PyPy, coroutine objects (for example) can survive at least 1 round of
# garbage collection, because executing their __del__ method to print the
# warning can cause them to be resurrected. So we call collect a few times
# to make sure.
for _ in range(5):
gc.collect()
# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings() -> Generator[None, None, None]:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited")
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()
def _noop(*args: object, **kwargs: object) -> None:
pass # pragma: no cover
@contextmanager
def restore_unraisablehook() -> Generator[None, None, None]:
sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook
try:
yield
finally:
sys.unraisablehook = prev
# Used to check sequences that might have some elements out of order.
# Example usage:
# The sequences [1, 2.1, 2.2, 3] and [1, 2.2, 2.1, 3] are both
# matched by the template [1, {2.1, 2.2}, 3]
def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None:
i = 0
for pattern in template:
if not isinstance(pattern, set):
pattern = {pattern}
got = set(seq[i : i + len(pattern)])
assert got == pattern
i += len(got)
# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350
skip_if_fbsd_pipes_broken = pytest.mark.skipif(
sys.platform != "win32" # prevent mypy from complaining about missing uname
and hasattr(os, "uname")
and os.uname().sysname == "FreeBSD"
and os.uname().release[:4] < "12.2",
reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350",
)
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()
@@ -0,0 +1,76 @@
"""Test variadic generic typing for Nursery.start[_soon]()."""
from typing import Awaitable, Callable
from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus
async def task_0() -> None: ...
async def task_1a(value: int) -> None: ...
async def task_1b(value: str) -> None: ...
async def task_2a(a: int, b: str) -> None: ...
async def task_2b(a: str, b: int) -> None: ...
async def task_2c(a: str, b: int, optional: bool = False) -> None: ...
async def task_requires_kw(a: int, *, b: bool) -> None: ...
async def task_startable_1(
a: str,
*,
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
) -> None: ...
async def task_startable_2(
a: str,
b: float,
*,
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
) -> None: ...
async def task_requires_start(*, task_status: TaskStatus[str]) -> None:
"""Check a function requiring start() to be used."""
async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None:
"""Check a function which doesn't use the *-syntax works."""
def check_start_soon(nursery: Nursery) -> None:
"""start_soon() functionality."""
nursery.start_soon(task_0)
nursery.start_soon(task_1a) # type: ignore
nursery.start_soon(task_2b) # type: ignore
nursery.start_soon(task_0, 45) # type: ignore
nursery.start_soon(task_1a, 32)
nursery.start_soon(task_1b, 32) # type: ignore
nursery.start_soon(task_1a, "abc") # type: ignore
nursery.start_soon(task_1b, "abc")
nursery.start_soon(task_2b, "abc") # type: ignore
nursery.start_soon(task_2a, 38, "46")
nursery.start_soon(task_2c, "abc", 12, True)
nursery.start_soon(task_2c, "abc", 12)
task_2c_cast: Callable[[str, int], Awaitable[object]] = (
task_2c # The assignment makes it work.
)
nursery.start_soon(task_2c_cast, "abc", 12)
nursery.start_soon(task_requires_kw, 12, True) # type: ignore
# Tasks following the start() API can be made to work.
nursery.start_soon(task_startable_1, "cdf")
@@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Sequence, overload
import trio
from typing_extensions import assert_type
async def sleep_sort(values: Sequence[float]) -> list[float]:
return [1]
async def has_optional(arg: int | None = None) -> int:
return 5
@overload
async def foo_overloaded(arg: int) -> str: ...
@overload
async def foo_overloaded(arg: str) -> int: ...
async def foo_overloaded(arg: int | str) -> int | str:
if isinstance(arg, str):
return 5
return "hello"
v = trio.run(
sleep_sort,
(1, 3, 5, 2, 4),
clock=trio.testing.MockClock(autojump_threshold=0),
)
assert_type(v, "list[float]")
trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type]
trio.run(sleep_sort) # type: ignore[arg-type]
r = trio.run(has_optional)
assert_type(r, int)
r = trio.run(has_optional, 5)
trio.run(has_optional, 7, 8) # type: ignore[arg-type]
trio.run(has_optional, "hello") # type: ignore[arg-type]
assert_type(trio.run(foo_overloaded, 5), str)
assert_type(trio.run(foo_overloaded, ""), int)
@@ -0,0 +1,295 @@
from __future__ import annotations
import ctypes
import ctypes.util
import sys
import traceback
from functools import partial
from itertools import count
from threading import Lock, Thread
from typing import Any, Callable, Generic, TypeVar
import outcome
RetT = TypeVar("RetT")
def _to_os_thread_name(name: str) -> bytes:
# ctypes handles the trailing \00
return name.encode("ascii", errors="replace")[:15]
# used to construct the method used to set os thread name, or None, depending on platform.
# called once on import
def get_os_thread_name_func() -> Callable[[int | None, str], None] | None:
def namefunc(
setname: Callable[[int, bytes], int],
ident: int | None,
name: str,
) -> None:
# Thread.ident is None "if it has not been started". Unclear if that can happen
# with current usage.
if ident is not None: # pragma: no cover
setname(ident, _to_os_thread_name(name))
# namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it
# so the caller don't need to care about platform.
def darwin_namefunc(
setname: Callable[[bytes], int],
ident: int | None,
name: str,
) -> None:
# I don't know if Mac can rename threads that hasn't been started, but default
# to no to be on the safe side.
if ident is not None: # pragma: no cover
setname(_to_os_thread_name(name))
# find the pthread library
# this will fail on windows and musl
libpthread_path = ctypes.util.find_library("pthread")
if not libpthread_path:
# musl includes pthread functions directly in libc.so
# (but note that find_library("c") does not work on musl,
# see: https://github.com/python/cpython/issues/65821)
# so try that library instead
# if it doesn't exist, CDLL() will fail below
libpthread_path = "libc.so"
# Sometimes windows can find the path, but gives a permission error when
# accessing it. Catching a wider exception in case of more esoteric errors.
# https://github.com/python-trio/trio/issues/2688
try:
libpthread = ctypes.CDLL(libpthread_path)
except Exception: # pragma: no cover
return None
# get the setname method from it
# afaik this should never fail
pthread_setname_np = getattr(libpthread, "pthread_setname_np", None)
if pthread_setname_np is None: # pragma: no cover
return None
# specify function prototype
pthread_setname_np.restype = ctypes.c_int
# on mac OSX pthread_setname_np does not take a thread id,
# it only lets threads name themselves, which is not a problem for us.
# Just need to make sure to call it correctly
if sys.platform == "darwin":
pthread_setname_np.argtypes = [ctypes.c_char_p]
return partial(darwin_namefunc, pthread_setname_np)
# otherwise assume linux parameter conventions. Should also work on *BSD
pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
return partial(namefunc, pthread_setname_np)
# construct os thread name method
set_os_thread_name = get_os_thread_name_func()
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
# spawns as many threads as needed to handle all the requests its given. Its
# only purpose is to cache worker threads so that they don't have to be
# started from scratch every time we want to delegate some work to a thread.
# It's expected that some higher-level code will track how many threads are in
# use to avoid overwhelming the system (e.g. the limiter= argument to
# trio.to_thread.run_sync).
#
# To maximize sharing, there's only one thread cache per process, even if you
# have multiple calls to trio.run.
#
# Guarantees:
#
# It's safe to call start_thread_soon simultaneously from
# multiple threads.
#
# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly
# over all threads. Instead we try to let some threads do most of the work
# while others sit idle as much as possible. Compared to FIFO, this has better
# memory cache behavior, and it makes it easier to detect when we have too
# many threads, so idle ones can exit.
#
# This code assumes that 'dict' has the following properties:
#
# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with
# respect to each other. This is guaranteed by the GIL.
#
# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem
# give you a LIFO queue). This relies on dicts being insertion-ordered, like
# they are in py36+.
# How long a thread will idle waiting for new work before gives up and exits.
# This value is pretty arbitrary; I don't think it matters too much.
IDLE_TIMEOUT = 10 # seconds
name_counter = count()
class WorkerThread(Generic[RetT]):
def __init__(self, thread_cache: ThreadCache) -> None:
self._job: (
tuple[
Callable[[], RetT],
Callable[[outcome.Outcome[RetT]], object],
str | None,
]
| None
) = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
#
# "Unlocked" means we have a pending job that's been assigned to us;
# "locked" means that we don't.
#
# Initially we have no job, so it starts out in locked state.
self._worker_lock = Lock()
self._worker_lock.acquire()
self._default_name = f"Trio thread {next(name_counter)}"
self._thread = Thread(target=self._work, name=self._default_name, daemon=True)
if set_os_thread_name:
set_os_thread_name(self._thread.ident, self._default_name)
self._thread.start()
def _handle_job(self) -> None:
# Handle job in a separate method to ensure user-created
# objects are cleaned up in a consistent manner.
assert self._job is not None
fn, deliver, name = self._job
self._job = None
# set name
if name is not None:
self._thread.name = name
if set_os_thread_name:
set_os_thread_name(self._thread.ident, name)
result = outcome.capture(fn)
# reset name if it was changed
if name is not None:
self._thread.name = self._default_name
if set_os_thread_name:
set_os_thread_name(self._thread.ident, self._default_name)
# Tell the cache that we're available to be assigned a new
# job. We do this *before* calling 'deliver', so that if
# 'deliver' triggers a new job, it can be assigned to us
# instead of spawning a new thread.
self._thread_cache._idle_workers[self] = None
try:
deliver(result)
except BaseException as e:
print("Exception while delivering result of thread", file=sys.stderr)
traceback.print_exception(type(e), e, e.__traceback__)
def _work(self) -> None:
while True:
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
# We got a job
self._handle_job()
else:
# Timeout acquiring lock, so we can probably exit. But,
# there's a race condition: we might be assigned a job *just*
# as we're about to exit. So we have to check.
try:
del self._thread_cache._idle_workers[self]
except KeyError:
# Someone else removed us from the idle worker queue, so
# they must be in the process of assigning us a job - loop
# around and wait for it.
continue
else:
# We successfully removed ourselves from the idle
# worker queue, so no more jobs are incoming; it's safe to
# exit.
return
class ThreadCache:
def __init__(self) -> None:
self._idle_workers: dict[WorkerThread[Any], None] = {}
def start_thread_soon(
self,
fn: Callable[[], RetT],
deliver: Callable[[outcome.Outcome[RetT]], object],
name: str | None = None,
) -> None:
worker: WorkerThread[RetT]
try:
worker, _ = self._idle_workers.popitem()
except KeyError:
worker = WorkerThread(self)
worker._job = (fn, deliver, name)
worker._worker_lock.release()
THREAD_CACHE = ThreadCache()
def start_thread_soon(
fn: Callable[[], RetT],
deliver: Callable[[outcome.Outcome[RetT]], object],
name: str | None = None,
) -> None:
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
result back to whoever is interested.
This is a low-level, no-frills interface, very similar to using
`threading.Thread` to spawn a thread directly. The main difference is
that this function tries to reuse threads when possible, so it can be
a bit faster than `threading.Thread`.
Worker threads have the `~threading.Thread.daemon` flag set, which means
that if your main thread exits, worker threads will automatically be
killed. If you want to make sure that your ``fn`` runs to completion, then
you should make sure that the main thread remains alive until ``deliver``
is called.
It is safe to call this function simultaneously from multiple threads.
Args:
fn (sync function): Performs arbitrary blocking work.
deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and
delivers it. *Must not block.*
Because worker threads are cached and reused for multiple calls, neither
function should mutate thread-level state, like `threading.local` objects
or if they do, they should be careful to revert their changes before
returning.
Note:
The split between ``fn`` and ``deliver`` serves two purposes. First,
it's convenient, since most callers need something like this anyway.
Second, it avoids a small race condition that could cause too many
threads to be spawned. Consider a program that wants to run several
jobs sequentially on a thread, so the main thread submits a job, waits
for it to finish, submits another job, etc. In theory, this program
should only need one worker thread. But what could happen is:
1. Worker thread: First job finishes, and calls ``deliver``.
2. Main thread: receives notification that the job finished, and calls
``start_thread_soon``.
3. Main thread: sees that no worker threads are marked idle, so spawns
a second worker thread.
4. Original worker thread: marks itself as idle.
To avoid this, threads mark themselves as idle *before* calling
``deliver``.
Is this potential extra thread a major problem? Maybe not, but it's
easy enough to avoid, and we figure that if the user is trying to
limit how many threads they're using then it's polite to respect that.
"""
THREAD_CACHE.start_thread_soon(fn, deliver, name)
@@ -0,0 +1,287 @@
"""These are the only functions that ever yield back to the task runner."""
from __future__ import annotations
import enum
import types
from typing import TYPE_CHECKING, Any, Callable, NoReturn
import attrs
import outcome
from . import _run
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from ._run import Task
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _async_yield(obj: Any) -> Any: # type: ignore[misc]
return (yield obj)
# This class object is used as a singleton.
# Not exported in the trio._core namespace, but imported directly by _run.
class CancelShieldedCheckpoint:
pass
async def cancel_shielded_checkpoint() -> None:
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can
make a full checkpoint.
Equivalent to (but potentially more efficient than)::
with trio.CancelScope(shield=True):
await trio.lowlevel.checkpoint()
"""
(await _async_yield(CancelShieldedCheckpoint)).unwrap()
# Return values for abort functions
class Abort(enum.Enum):
""":class:`enum.Enum` used as the return value from abort functions.
See :func:`wait_task_rescheduled` for details.
.. data:: SUCCEEDED
FAILED
"""
SUCCEEDED = 1
FAILED = 2
# Not exported in the trio._core namespace, but imported directly by _run.
@attrs.frozen(slots=False)
class WaitTaskRescheduled:
abort_func: Callable[[RaiseCancelT], Abort]
RaiseCancelT: TypeAlias = Callable[[], NoReturn]
# Should always return the type a Task "expects", unless you willfully reschedule it
# with a bad value.
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in Trio. Every time a
:class:`~trio.lowlevel.Task` blocks, it does so by calling this function
(usually indirectly via some higher-level API).
This is a tricky interface with no guard rails. If you can use
:class:`ParkingLot` or the built-in I/O wait functions instead, then you
should.
Generally the way it works is that before calling this function, you make
arrangements for "someone" to call :func:`reschedule` on the current task
at some later point.
Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an
"abort callback".
(Terminology: in Trio, "aborting" is the process of attempting to
interrupt a blocked task to deliver a cancellation.)
There are two possibilities for what happens next:
1. "Someone" calls :func:`reschedule` on the current task, and
:func:`wait_task_rescheduled` returns or raises whatever value or error
was passed to :func:`reschedule`.
2. The call's context transitions to a cancelled state (e.g. due to a
timeout expiring). When this happens, the ``abort_func`` is called. Its
interface looks like::
def abort_func(raise_cancel):
...
return trio.lowlevel.Abort.SUCCEEDED # or FAILED
It should attempt to clean up any state associated with this call, and
in particular, arrange that :func:`reschedule` will *not* be called
later. If (and only if!) it is successful, then it should return
:data:`Abort.SUCCEEDED`, in which case the task will automatically be
rescheduled with an appropriate :exc:`~trio.Cancelled` error.
Otherwise, it should return :data:`Abort.FAILED`. This means that the
task can't be cancelled at this time, and still has to make sure that
"someone" eventually calls :func:`reschedule`.
At that point there are again two possibilities. You can simply ignore
the cancellation altogether: wait for the operation to complete and
then reschedule and continue as normal. (For example, this is what
:func:`trio.to_thread.run_sync` does if cancellation is disabled.)
The other possibility is that the ``abort_func`` does succeed in
cancelling the operation, but for some reason isn't able to report that
right away. (Example: on Windows, it's possible to request that an
async ("overlapped") I/O operation be cancelled, but this request is
*also* asynchronous you don't find out until later whether the
operation was actually cancelled or not.) To report a delayed
cancellation, then you should reschedule the task yourself, and call
the ``raise_cancel`` callback passed to ``abort_func`` to raise a
:exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception
into this task. Either of the approaches sketched below can work::
# Option 1:
# Catch the exception from raise_cancel and inject it into the task.
# (This is what Trio does automatically for you if you return
# Abort.SUCCEEDED.)
trio.lowlevel.reschedule(task, outcome.capture(raise_cancel))
# Option 2:
# wait to be woken by "someone", and then decide whether to raise
# the error from inside the task.
outer_raise_cancel = None
def abort(inner_raise_cancel):
nonlocal outer_raise_cancel
outer_raise_cancel = inner_raise_cancel
TRY_TO_CANCEL_OPERATION()
return trio.lowlevel.Abort.FAILED
await wait_task_rescheduled(abort)
if OPERATION_WAS_SUCCESSFULLY_CANCELLED:
# raises the error
outer_raise_cancel()
In any case it's guaranteed that we only call the ``abort_func`` at most
once per call to :func:`wait_task_rescheduled`.
Sometimes, it's useful to be able to share some mutable sleep-related data
between the sleeping task, the abort function, and the waking task. You
can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to
store this data, and Trio won't touch it, except to make sure that it gets
cleared when the task is rescheduled.
.. warning::
If your ``abort_func`` raises an error, or returns any value other than
:data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash
violently. Be careful! Similarly, it is entirely possible to deadlock a
Trio program by failing to reschedule a blocked task, or cause havoc by
calling :func:`reschedule` too many times. Remember what we said up
above about how you should use a higher-level API if at all possible?
"""
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
# Not exported in the trio._core namespace, but imported directly by _run.
@attrs.frozen(slots=False)
class PermanentlyDetachCoroutineObject:
final_outcome: outcome.Outcome[Any]
async def permanently_detach_coroutine_object(
final_outcome: outcome.Outcome[Any],
) -> Any:
"""Permanently detach the current task from the Trio scheduler.
Normally, a Trio task doesn't exit until its coroutine object exits. When
you call this function, Trio acts like the coroutine object just exited
and the task terminates with the given outcome. This is useful if you want
to permanently switch the coroutine object over to a different coroutine
runner.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
You should make sure that the coroutine object has released any
Trio-specific resources it has acquired (e.g. nurseries).
Args:
final_outcome (outcome.Outcome): Trio acts as if the current task exited
with the given return value or exception.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
if _run.current_task().child_nurseries:
raise RuntimeError(
"can't permanently detach a coroutine object with open nurseries",
)
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
async def temporarily_detach_coroutine_object(
abort_func: Callable[[RaiseCancelT], Abort],
) -> Any:
"""Temporarily detach the current coroutine object from the Trio
scheduler.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
The Trio :class:`Task` will continue to exist, but will be suspended until
you use :func:`reattach_detached_coroutine_object` to resume it. In the
mean time, you can use another coroutine runner to schedule the coroutine
object. In fact, you have to the function doesn't return until the
coroutine is advanced from outside.
Note that you'll need to save the current :class:`Task` object to later
resume; you can retrieve it with :func:`current_task`. You can also use
this :class:`Task` object to retrieve the coroutine object see
:data:`Task.coro`.
Args:
abort_func: Same as for :func:`wait_task_rescheduled`, except that it
must return :data:`Abort.FAILED`. (If it returned
:data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the
detached task directly without going through
:func:`reattach_detached_coroutine_object`, which would be bad.)
Your ``abort_func`` should still arrange for whatever the coroutine
object is doing to be cancelled, and then reattach to Trio and call
the ``raise_cancel`` callback, if possible.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
return await _async_yield(WaitTaskRescheduled(abort_func))
async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None:
"""Reattach a coroutine object that was detached using
:func:`temporarily_detach_coroutine_object`.
When the calling coroutine enters this function it's running under the
foreign coroutine runner, and when the function returns it's running under
Trio.
This must be called from inside the coroutine being resumed, and yields
whatever value you pass in. (Presumably you'll pass a value that will
cause the current coroutine runner to stop scheduling this task.) Then the
coroutine is resumed by the Trio scheduler at the next opportunity.
Args:
task (Task): The Trio task object that the current coroutine was
detached from.
yield_value (object): The object to yield to the current coroutine
runner.
"""
# This is a kind of crude check in particular, it can fail if the
# passed-in task is where the coroutine *runner* is running. But this is
# an experts-only interface, and there's no easy way to do a more accurate
# check, so I guess that's OK.
if not task.coro.cr_running:
raise RuntimeError("given task does not match calling coroutine")
_run.reschedule(task, outcome.Value("reattaching"))
value = await _async_yield(yield_value)
assert value == outcome.Value("reattaching")
@@ -0,0 +1,163 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic, TypeVar
import attrs
from .. import _core
from .._deprecate import deprecated
from .._util import final
T = TypeVar("T")
if TYPE_CHECKING:
from typing_extensions import Self
@attrs.frozen
class UnboundedQueueStatistics:
"""An object containing debugging information.
Currently, the following fields are defined:
* ``qsize``: The number of items currently in the queue.
* ``tasks_waiting``: The number of tasks blocked on this queue's
:meth:`get_batch` method.
"""
qsize: int
tasks_waiting: int
@final
class UnboundedQueue(Generic[T]):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
This class is designed for use as a queue in cases where the producer for
some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait`
has to always succeed. In order to prevent the queue backlog from actually
growing without bound, the consumer API is modified to dequeue items in
"batches". If a consumer task processes each batch without yielding, then
this helps achieve (but does not guarantee) an effective bound on the
queue's memory use, at the cost of potentially increasing system latencies
in general. You should generally prefer to use a memory channel
instead if you can.
Currently each batch completely empties the queue, but `this may change in
the future <https://github.com/python-trio/trio/issues/51>`__.
A :class:`UnboundedQueue` object can be used as an asynchronous iterator,
where each iteration returns a new batch of items. I.e., these two loops
are equivalent::
async for batch in queue:
...
while True:
obj = await queue.get_batch()
...
"""
@deprecated(
"0.9.0",
issue=497,
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
use_triodeprecationwarning=True,
)
def __init__(self) -> None:
self._lot = _core.ParkingLot()
self._data: list[T] = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
def __repr__(self) -> str:
return f"<UnboundedQueue holding {len(self._data)} items>"
def qsize(self) -> int:
"""Returns the number of items currently in the queue."""
return len(self._data)
def empty(self) -> bool:
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
`issue #63 <https://github.com/python-trio/trio/issues/63>`__.
"""
return not self._data
@_core.enable_ki_protection
def put_nowait(self, obj: T) -> None:
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
a blocking ``put`` method, because it would never need to block.
Args:
obj (object): The object to enqueue.
"""
if not self._data:
assert not self._can_get
if self._lot:
self._lot.unpark(count=1)
else:
self._can_get = True
self._data.append(obj)
def _get_batch_protected(self) -> list[T]:
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
def get_batch_nowait(self) -> list[T]:
"""Attempt to get the next batch from the queue, without blocking.
Returns:
list: A list of dequeued items, in order. On a successful call this
list is always non-empty; if it would be empty we raise
:exc:`~trio.WouldBlock` instead.
Raises:
~trio.WouldBlock: if the queue is empty.
"""
if not self._can_get:
raise _core.WouldBlock
return self._get_batch_protected()
async def get_batch(self) -> list[T]:
"""Get the next batch from the queue, blocking as necessary.
Returns:
list: A list of dequeued items, in order. This list is always
non-empty.
"""
await _core.checkpoint_if_cancelled()
if not self._can_get:
await self._lot.park()
return self._get_batch_protected()
else:
try:
return self._get_batch_protected()
finally:
await _core.cancel_shielded_checkpoint()
def statistics(self) -> UnboundedQueueStatistics:
"""Return an :class:`UnboundedQueueStatistics` object containing debugging information."""
return UnboundedQueueStatistics(
qsize=len(self._data),
tasks_waiting=self._lot.statistics().tasks_waiting,
)
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> list[T]:
return await self.get_batch()
@@ -0,0 +1,75 @@
from __future__ import annotations
import contextlib
import signal
import socket
import warnings
from .. import _core
from .._util import is_main_thread
class WakeupSocketpair:
def __init__(self) -> None:
# explicitly typed to please `pyright --verifytypes` without `--ignoreexternal`
self.wakeup_sock: socket.socket
self.write_sock: socket.socket
self.wakeup_sock, self.write_sock = socket.socketpair()
self.wakeup_sock.setblocking(False)
self.write_sock.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
with contextlib.suppress(OSError):
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.old_wakeup_fd: int | None = None
def wakeup_thread_and_signal_safe(self) -> None:
with contextlib.suppress(BlockingIOError):
self.write_sock.send(b"\x00")
async def wait_woken(self) -> None:
await _core.wait_readable(self.wakeup_sock)
self.drain()
def drain(self) -> None:
try:
while True:
self.wakeup_sock.recv(2**16)
except BlockingIOError:
pass
def wakeup_on_signals(self) -> None:
assert self.old_wakeup_fd is None
if not is_main_thread():
return
fd = self.write_sock.fileno()
self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False)
if self.old_wakeup_fd != -1:
warnings.warn(
RuntimeWarning(
"It looks like Trio's signal handling code might have "
"collided with another library you're using. If you're "
"running Trio in guest mode, then this might mean you "
"should set host_uses_signal_set_wakeup_fd=True. "
"Otherwise, file a bug on Trio and we'll help you figure "
"out what's going on.",
),
stacklevel=1,
)
def close(self) -> None:
self.wakeup_sock.close()
self.write_sock.close()
if self.old_wakeup_fd is not None:
signal.set_wakeup_fd(self.old_wakeup_fd)
@@ -0,0 +1,520 @@
from __future__ import annotations
import enum
import re
from typing import TYPE_CHECKING, NewType, NoReturn, Protocol, cast
if TYPE_CHECKING:
from typing_extensions import TypeAlias
import cffi
################################################################
# Functions and types
################################################################
LIB = """
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx
typedef int BOOL;
typedef unsigned char BYTE;
typedef BYTE BOOLEAN;
typedef void* PVOID;
typedef PVOID HANDLE;
typedef unsigned long DWORD;
typedef unsigned long ULONG;
typedef unsigned int NTSTATUS;
typedef unsigned long u_long;
typedef ULONG *PULONG;
typedef const void *LPCVOID;
typedef void *LPVOID;
typedef const wchar_t *LPCWSTR;
typedef uintptr_t ULONG_PTR;
typedef uintptr_t UINT_PTR;
typedef UINT_PTR SOCKET;
typedef struct _OVERLAPPED {
ULONG_PTR Internal;
ULONG_PTR InternalHigh;
union {
struct {
DWORD Offset;
DWORD OffsetHigh;
} DUMMYSTRUCTNAME;
PVOID Pointer;
} DUMMYUNIONNAME;
HANDLE hEvent;
} OVERLAPPED, *LPOVERLAPPED;
typedef OVERLAPPED WSAOVERLAPPED;
typedef LPOVERLAPPED LPWSAOVERLAPPED;
typedef PVOID LPSECURITY_ATTRIBUTES;
typedef PVOID LPCSTR;
typedef struct _OVERLAPPED_ENTRY {
ULONG_PTR lpCompletionKey;
LPOVERLAPPED lpOverlapped;
ULONG_PTR Internal;
DWORD dwNumberOfBytesTransferred;
} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY;
// kernel32.dll
HANDLE WINAPI CreateIoCompletionPort(
_In_ HANDLE FileHandle,
_In_opt_ HANDLE ExistingCompletionPort,
_In_ ULONG_PTR CompletionKey,
_In_ DWORD NumberOfConcurrentThreads
);
BOOL SetFileCompletionNotificationModes(
HANDLE FileHandle,
UCHAR Flags
);
HANDLE CreateFileW(
LPCWSTR lpFileName,
DWORD dwDesiredAccess,
DWORD dwShareMode,
LPSECURITY_ATTRIBUTES lpSecurityAttributes,
DWORD dwCreationDisposition,
DWORD dwFlagsAndAttributes,
HANDLE hTemplateFile
);
BOOL WINAPI CloseHandle(
_In_ HANDLE hObject
);
BOOL WINAPI PostQueuedCompletionStatus(
_In_ HANDLE CompletionPort,
_In_ DWORD dwNumberOfBytesTransferred,
_In_ ULONG_PTR dwCompletionKey,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WINAPI GetQueuedCompletionStatusEx(
_In_ HANDLE CompletionPort,
_Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries,
_In_ ULONG ulCount,
_Out_ PULONG ulNumEntriesRemoved,
_In_ DWORD dwMilliseconds,
_In_ BOOL fAlertable
);
BOOL WINAPI CancelIoEx(
_In_ HANDLE hFile,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WriteFile(
HANDLE hFile,
LPCVOID lpBuffer,
DWORD nNumberOfBytesToWrite,
LPDWORD lpNumberOfBytesWritten,
LPOVERLAPPED lpOverlapped
);
BOOL ReadFile(
HANDLE hFile,
LPVOID lpBuffer,
DWORD nNumberOfBytesToRead,
LPDWORD lpNumberOfBytesRead,
LPOVERLAPPED lpOverlapped
);
BOOL WINAPI SetConsoleCtrlHandler(
_In_opt_ void* HandlerRoutine,
_In_ BOOL Add
);
HANDLE CreateEventA(
LPSECURITY_ATTRIBUTES lpEventAttributes,
BOOL bManualReset,
BOOL bInitialState,
LPCSTR lpName
);
BOOL SetEvent(
HANDLE hEvent
);
BOOL ResetEvent(
HANDLE hEvent
);
DWORD WaitForSingleObject(
HANDLE hHandle,
DWORD dwMilliseconds
);
DWORD WaitForMultipleObjects(
DWORD nCount,
HANDLE *lpHandles,
BOOL bWaitAll,
DWORD dwMilliseconds
);
ULONG RtlNtStatusToDosError(
NTSTATUS Status
);
int WSAIoctl(
SOCKET s,
DWORD dwIoControlCode,
LPVOID lpvInBuffer,
DWORD cbInBuffer,
LPVOID lpvOutBuffer,
DWORD cbOutBuffer,
LPDWORD lpcbBytesReturned,
LPWSAOVERLAPPED lpOverlapped,
// actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
void* lpCompletionRoutine
);
int WSAGetLastError();
BOOL DeviceIoControl(
HANDLE hDevice,
DWORD dwIoControlCode,
LPVOID lpInBuffer,
DWORD nInBufferSize,
LPVOID lpOutBuffer,
DWORD nOutBufferSize,
LPDWORD lpBytesReturned,
LPOVERLAPPED lpOverlapped
);
// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
typedef struct _AFD_POLL_HANDLE_INFO {
HANDLE Handle;
ULONG Events;
NTSTATUS Status;
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
// This is really defined as a messy union to allow stuff like
// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications.
// Under all that it's just an int64.
typedef int64_t LARGE_INTEGER;
typedef struct _AFD_POLL_INFO {
LARGE_INTEGER Timeout;
ULONG NumberOfHandles;
ULONG Exclusive;
AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
"""
# cribbed from pywincffi
# programmatically strips out those annotations MSDN likes, like _In_
REGEX_SAL_ANNOTATION = re.compile(
r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b",
)
LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB)
# Other fixups:
# - get rid of FAR, cffi doesn't like it
LIB = re.sub(r"\bFAR\b", " ", LIB)
# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern
# being _MSC_VER >= 800)
LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB)
ffi = cffi.api.FFI()
ffi.cdef(LIB)
CData: TypeAlias = cffi.api.FFI.CData
CType: TypeAlias = cffi.api.FFI.CType
AlwaysNull: TypeAlias = CType # We currently always pass ffi.NULL here.
Handle = NewType("Handle", CData)
HandleArray = NewType("HandleArray", CData)
class _Kernel32(Protocol):
"""Statically typed version of the kernel32.dll functions we use."""
def CreateIoCompletionPort(
self,
FileHandle: Handle,
ExistingCompletionPort: CData | AlwaysNull,
CompletionKey: int,
NumberOfConcurrentThreads: int,
/,
) -> Handle: ...
def CreateEventA(
self,
lpEventAttributes: AlwaysNull,
bManualReset: bool,
bInitialState: bool,
lpName: AlwaysNull,
/,
) -> Handle: ...
def SetFileCompletionNotificationModes(
self,
handle: Handle,
flags: CompletionModes,
/,
) -> int: ...
def PostQueuedCompletionStatus(
self,
CompletionPort: Handle,
dwNumberOfBytesTransferred: int,
dwCompletionKey: int,
lpOverlapped: CData | AlwaysNull,
/,
) -> bool: ...
def CancelIoEx(
self,
hFile: Handle,
lpOverlapped: CData | AlwaysNull,
/,
) -> bool: ...
def WriteFile(
self,
hFile: Handle,
# not sure about this type
lpBuffer: CData,
nNumberOfBytesToWrite: int,
lpNumberOfBytesWritten: AlwaysNull,
lpOverlapped: _Overlapped,
/,
) -> bool: ...
def ReadFile(
self,
hFile: Handle,
# not sure about this type
lpBuffer: CData,
nNumberOfBytesToRead: int,
lpNumberOfBytesRead: AlwaysNull,
lpOverlapped: _Overlapped,
/,
) -> bool: ...
def GetQueuedCompletionStatusEx(
self,
CompletionPort: Handle,
lpCompletionPortEntries: CData,
ulCount: int,
ulNumEntriesRemoved: CData,
dwMilliseconds: int,
fAlertable: bool | int,
/,
) -> CData: ...
def CreateFileW(
self,
lpFileName: CData,
dwDesiredAccess: FileFlags,
dwShareMode: FileFlags,
lpSecurityAttributes: AlwaysNull,
dwCreationDisposition: FileFlags,
dwFlagsAndAttributes: FileFlags,
hTemplateFile: AlwaysNull,
/,
) -> Handle: ...
def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: ...
def WaitForMultipleObjects(
self,
nCount: int,
lpHandles: HandleArray,
bWaitAll: bool,
dwMilliseconds: int,
/,
) -> ErrorCodes: ...
def SetEvent(self, handle: Handle, /) -> None: ...
def CloseHandle(self, handle: Handle, /) -> bool: ...
def DeviceIoControl(
self,
hDevice: Handle,
dwIoControlCode: int,
# this is wrong (it's not always null)
lpInBuffer: AlwaysNull,
nInBufferSize: int,
# this is also wrong
lpOutBuffer: AlwaysNull,
nOutBufferSize: int,
lpBytesReturned: AlwaysNull,
lpOverlapped: CData,
/,
) -> bool: ...
class _Nt(Protocol):
"""Statically typed version of the dtdll.dll functions we use."""
def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: ...
class _Ws2(Protocol):
"""Statically typed version of the ws2_32.dll functions we use."""
def WSAGetLastError(self) -> int: ...
def WSAIoctl(
self,
socket: CData,
dwIoControlCode: WSAIoctls,
lpvInBuffer: AlwaysNull,
cbInBuffer: int,
lpvOutBuffer: CData,
cbOutBuffer: int,
lpcbBytesReturned: CData, # int*
lpOverlapped: AlwaysNull,
# actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
lpCompletionRoutine: AlwaysNull,
/,
) -> int: ...
class _DummyStruct(Protocol):
Offset: int
OffsetHigh: int
class _DummyUnion(Protocol):
DUMMYSTRUCTNAME: _DummyStruct
Pointer: object
class _Overlapped(Protocol):
Internal: int
InternalHigh: int
DUMMYUNIONNAME: _DummyUnion
hEvent: Handle
kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll"))
ntdll = cast(_Nt, ffi.dlopen("ntdll.dll"))
ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll"))
################################################################
# Magic numbers
################################################################
# Here's a great resource for looking these up:
# https://www.magnumdb.com
# (Tip: check the box to see "Hex value")
INVALID_HANDLE_VALUE = Handle(ffi.cast("HANDLE", -1))
class ErrorCodes(enum.IntEnum):
STATUS_TIMEOUT = 0x102
WAIT_TIMEOUT = 0x102
WAIT_ABANDONED = 0x80
WAIT_OBJECT_0 = 0x00 # object is signaled
WAIT_FAILED = 0xFFFFFFFF
ERROR_IO_PENDING = 997
ERROR_OPERATION_ABORTED = 995
ERROR_ABANDONED_WAIT_0 = 735
ERROR_INVALID_HANDLE = 6
ERROR_INVALID_PARMETER = 87
ERROR_NOT_FOUND = 1168
ERROR_NOT_SOCKET = 10038
class FileFlags(enum.IntFlag):
GENERIC_READ = 0x80000000
SYNCHRONIZE = 0x00100000
FILE_FLAG_OVERLAPPED = 0x40000000
FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2
FILE_SHARE_DELETE = 4
CREATE_NEW = 1
CREATE_ALWAYS = 2
OPEN_EXISTING = 3
OPEN_ALWAYS = 4
TRUNCATE_EXISTING = 5
class AFDPollFlags(enum.IntFlag):
# These are drawn from a combination of:
# https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
# https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h
AFD_POLL_RECEIVE = 0x0001
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data
AFD_POLL_SEND = 0x0004
AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN)
AFD_POLL_ABORT = 0x0010 # received RST
AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed
AFD_POLL_CONNECT = 0x0040 # socket is successfully connected
AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket
AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully
# See WSAEventSelect docs for more details on these four:
AFD_POLL_QOS = 0x0200
AFD_POLL_GROUP_QOS = 0x0400
AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800
AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000
class WSAIoctls(enum.IntEnum):
SIO_BASE_HANDLE = 0x48000022
SIO_BSP_HANDLE_SELECT = 0x4800001C
SIO_BSP_HANDLE_POLL = 0x4800001D
class CompletionModes(enum.IntFlag):
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1
FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2
class IoControlCodes(enum.IntEnum):
IOCTL_AFD_POLL = 0x00012024
################################################################
# Generic helpers
################################################################
def _handle(obj: int | CData) -> Handle:
# For now, represent handles as either cffi HANDLEs or as ints. If you
# try to pass in a file descriptor instead, it's not going to work
# out. (For that msvcrt.get_osfhandle does the trick, but I don't know if
# we'll actually need that for anything...) For sockets this doesn't
# matter, Python never allocates an fd. So let's wait until we actually
# encounter the problem before worrying about it.
if isinstance(obj, int):
return Handle(ffi.cast("HANDLE", obj))
return Handle(obj)
def handle_array(count: int) -> HandleArray:
"""Make an array of handles."""
return HandleArray(ffi.new(f"HANDLE[{count}]"))
def raise_winerror(
winerror: int | None = None,
*,
filename: str | None = None,
filename2: str | None = None,
) -> NoReturn:
# assert sys.platform == "win32" # TODO: make this work in MyPy
# ... in the meanwhile, ffi.getwinerror() is undefined on non-Windows, necessitating the type
# ignores.
if winerror is None:
err = ffi.getwinerror() # type: ignore[attr-defined,unused-ignore]
if err is None:
raise RuntimeError("No error set?")
winerror, msg = err
else:
err = ffi.getwinerror(winerror) # type: ignore[attr-defined,unused-ignore]
if err is None:
raise RuntimeError("No error set?")
_, msg = err
# https://docs.python.org/3/library/exceptions.html#OSError
raise OSError(0, msg, filename, winerror, filename2)
@@ -0,0 +1,177 @@
from __future__ import annotations
import sys
import warnings
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, ClassVar, TypeVar
import attrs
if TYPE_CHECKING:
from collections.abc import Callable
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
RetT = TypeVar("RetT")
# We want our warnings to be visible by default (at least for now), but we
# also want it to be possible to override that using the -W switch. AFAICT
# this means we cannot inherit from DeprecationWarning, because the only way
# to make it visible by default then would be to add our own filter at import
# time, but that would override -W switches...
class TrioDeprecationWarning(FutureWarning):
"""Warning emitted if you use deprecated Trio functionality.
As a young project, Trio is currently quite aggressive about deprecating
and/or removing functionality that we realize was a bad idea. If you use
Trio, you should subscribe to `issue #1
<https://github.com/python-trio/trio/issues/1>`__ to get information about
upcoming deprecations and other backwards compatibility breaking changes.
Despite the name, this class currently inherits from
:class:`FutureWarning`, not :class:`DeprecationWarning`, because while
we're in young-and-aggressive mode we want these warnings to be visible by
default. You can hide them by installing a filter or with the ``-W``
switch: see the :mod:`warnings` documentation for details.
"""
def _url_for_issue(issue: int) -> str:
return f"https://github.com/python-trio/trio/issues/{issue}"
def _stringify(thing: object) -> str:
if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"):
return f"{thing.__module__}.{thing.__qualname__}"
return str(thing)
def warn_deprecated(
thing: object,
version: str,
*,
issue: int | None,
instead: object,
stacklevel: int = 2,
use_triodeprecationwarning: bool = False,
) -> None:
stacklevel += 1
msg = f"{_stringify(thing)} is deprecated since Trio {version}"
if instead is None:
msg += " with no replacement"
else:
msg += f"; use {_stringify(instead)} instead"
if issue is not None:
msg += f" ({_url_for_issue(issue)})"
if use_triodeprecationwarning:
warning_class: type[Warning] = TrioDeprecationWarning
else:
warning_class = DeprecationWarning
warnings.warn(warning_class(msg), stacklevel=stacklevel)
# @deprecated("0.2.0", issue=..., instead=...)
# def ...
def deprecated(
version: str,
*,
thing: object = None,
issue: int | None,
instead: object,
use_triodeprecationwarning: bool = False,
) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]:
def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]:
nonlocal thing
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
warn_deprecated(
thing,
version,
instead=instead,
issue=issue,
use_triodeprecationwarning=use_triodeprecationwarning,
)
return fn(*args, **kwargs)
# If our __module__ or __qualname__ get modified, we want to pick up
# on that, so we read them off the wrapper object instead of the (now
# hidden) fn object
if thing is None:
thing = wrapper
if wrapper.__doc__ is not None:
doc = wrapper.__doc__
doc = doc.rstrip()
doc += "\n\n"
doc += f".. deprecated:: {version}\n"
if instead is not None:
doc += f" Use {_stringify(instead)} instead.\n"
if issue is not None:
doc += f" For details, see `issue #{issue} <{_url_for_issue(issue)}>`__.\n"
doc += "\n"
wrapper.__doc__ = doc
return wrapper
return do_wrap
def deprecated_alias(
old_qualname: str,
new_fn: Callable[ArgsT, RetT],
version: str,
*,
issue: int | None,
) -> Callable[ArgsT, RetT]:
@deprecated(version, issue=issue, instead=new_fn)
@wraps(new_fn, assigned=("__module__", "__annotations__"))
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
"""Deprecated alias."""
return new_fn(*args, **kwargs)
wrapper.__qualname__ = old_qualname
wrapper.__name__ = old_qualname.rpartition(".")[-1]
return wrapper
@attrs.frozen(slots=False)
class DeprecatedAttribute:
_not_set: ClassVar[object] = object()
value: object
version: str
issue: int | None
instead: object = _not_set
class _ModuleWithDeprecations(ModuleType):
__deprecated_attributes__: dict[str, DeprecatedAttribute]
def __getattr__(self, name: str) -> object:
if name in self.__deprecated_attributes__:
info = self.__deprecated_attributes__[name]
instead = info.instead
if instead is DeprecatedAttribute._not_set:
instead = info.value
thing = f"{self.__name__}.{name}"
warn_deprecated(thing, info.version, issue=info.issue, instead=instead)
return info.value
msg = "module '{}' has no attribute '{}'"
raise AttributeError(msg.format(self.__name__, name))
def enable_attribute_deprecations(module_name: str) -> None:
module = sys.modules[module_name]
module.__class__ = _ModuleWithDeprecations
assert isinstance(module, _ModuleWithDeprecations)
# Make sure that this is always defined so that
# _ModuleWithDeprecations.__getattr__ can access it without jumping
# through hoops or risking infinite recursion.
module.__deprecated_attributes__ = {}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,509 @@
from __future__ import annotations
import io
from functools import partial
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
BinaryIO,
Callable,
Generic,
Iterable,
TypeVar,
Union,
overload,
)
import trio
from ._util import async_wraps
from .abc import AsyncResource
if TYPE_CHECKING:
from _typeshed import (
OpenBinaryMode,
OpenBinaryModeReading,
OpenBinaryModeUpdating,
OpenBinaryModeWriting,
OpenTextMode,
StrOrBytesPath,
)
from typing_extensions import Literal
# This list is also in the docs, make sure to keep them in sync
_FILE_SYNC_ATTRS: set[str] = {
"closed",
"encoding",
"errors",
"fileno",
"isatty",
"newlines",
"readable",
"seekable",
"writable",
# not defined in *IOBase:
"buffer",
"raw",
"line_buffering",
"closefd",
"name",
"mode",
"getvalue",
"getbuffer",
}
# This list is also in the docs, make sure to keep them in sync
_FILE_ASYNC_METHODS: set[str] = {
"flush",
"read",
"read1",
"readall",
"readinto",
"readline",
"readlines",
"seek",
"tell",
"truncate",
"write",
"writelines",
# not defined in *IOBase:
"readinto1",
"peek",
}
FileT = TypeVar("FileT")
FileT_co = TypeVar("FileT_co", covariant=True)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)
AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True)
# This is a little complicated. IO objects have a lot of methods, and which are available on
# different types varies wildly. We want to match the interface of whatever file we're wrapping.
# This pile of protocols each has one sync method/property, meaning they're going to be compatible
# with a file class that supports that method/property. The ones parameterized with AnyStr take
# either str or bytes depending.
# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're
# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be
# conditional - it's only valid to call them if the object you're accessing them on is compatible
# with that type hint. By using the protocols, the type checker will be checking to see if the
# wrapped type has that method, and only allow the methods that do to be called. We can then alter
# the signature however it needs to match runtime behaviour.
# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types
if TYPE_CHECKING:
from typing_extensions import Buffer, Protocol
# fmt: off
class _HasClosed(Protocol):
@property
def closed(self) -> bool: ...
class _HasEncoding(Protocol):
@property
def encoding(self) -> str: ...
class _HasErrors(Protocol):
@property
def errors(self) -> str | None: ...
class _HasFileNo(Protocol):
def fileno(self) -> int: ...
class _HasIsATTY(Protocol):
def isatty(self) -> bool: ...
class _HasNewlines(Protocol[T_co]):
# Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any.
@property
def newlines(self) -> T_co: ...
class _HasReadable(Protocol):
def readable(self) -> bool: ...
class _HasSeekable(Protocol):
def seekable(self) -> bool: ...
class _HasWritable(Protocol):
def writable(self) -> bool: ...
class _HasBuffer(Protocol):
@property
def buffer(self) -> BinaryIO: ...
class _HasRaw(Protocol):
@property
def raw(self) -> io.RawIOBase: ...
class _HasLineBuffering(Protocol):
@property
def line_buffering(self) -> bool: ...
class _HasCloseFD(Protocol):
@property
def closefd(self) -> bool: ...
class _HasName(Protocol):
@property
def name(self) -> str: ...
class _HasMode(Protocol):
@property
def mode(self) -> str: ...
class _CanGetValue(Protocol[AnyStr_co]):
def getvalue(self) -> AnyStr_co: ...
class _CanGetBuffer(Protocol):
def getbuffer(self) -> memoryview: ...
class _CanFlush(Protocol):
def flush(self) -> None: ...
class _CanRead(Protocol[AnyStr_co]):
def read(self, size: int | None = ..., /) -> AnyStr_co: ...
class _CanRead1(Protocol):
def read1(self, size: int | None = ..., /) -> bytes: ...
class _CanReadAll(Protocol[AnyStr_co]):
def readall(self) -> AnyStr_co: ...
class _CanReadInto(Protocol):
def readinto(self, buf: Buffer, /) -> int | None: ...
class _CanReadInto1(Protocol):
def readinto1(self, buffer: Buffer, /) -> int: ...
class _CanReadLine(Protocol[AnyStr_co]):
def readline(self, size: int = ..., /) -> AnyStr_co: ...
class _CanReadLines(Protocol[AnyStr]):
def readlines(self, hint: int = ..., /) -> list[AnyStr]: ...
class _CanSeek(Protocol):
def seek(self, target: int, whence: int = 0, /) -> int: ...
class _CanTell(Protocol):
def tell(self) -> int: ...
class _CanTruncate(Protocol):
def truncate(self, size: int | None = ..., /) -> int: ...
class _CanWrite(Protocol[T_contra]):
def write(self, data: T_contra, /) -> int: ...
class _CanWriteLines(Protocol[T_contra]):
# The lines parameter varies for bytes/str, so use a typevar to make the async match.
def writelines(self, lines: Iterable[T_contra], /) -> None: ...
class _CanPeek(Protocol[AnyStr_co]):
def peek(self, size: int = 0, /) -> AnyStr_co: ...
class _CanDetach(Protocol[T_co]):
# The T typevar will be the unbuffered/binary file this file wraps.
def detach(self) -> T_co: ...
class _CanClose(Protocol):
def close(self) -> None: ...
# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a
# subtype of the protocols.
class AsyncIOWrapper(AsyncResource, Generic[FileT_co]):
"""A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
file object` interface. Wrapped methods that could block are executed in
:meth:`trio.to_thread.run_sync`.
All properties and methods defined in :mod:`~io` are exposed by this
wrapper, if they exist in the wrapped file object.
"""
def __init__(self, file: FileT_co) -> None:
self._wrapped = file
@property
def wrapped(self) -> FileT_co:
"""object: A reference to the wrapped file object"""
return self._wrapped
if not TYPE_CHECKING:
def __getattr__(self, name: str) -> object:
if name in _FILE_SYNC_ATTRS:
return getattr(self._wrapped, name)
if name in _FILE_ASYNC_METHODS:
meth = getattr(self._wrapped, name)
@async_wraps(self.__class__, self._wrapped.__class__, name)
async def wrapper(*args, **kwargs):
func = partial(meth, *args, **kwargs)
return await trio.to_thread.run_sync(func)
# cache the generated method
setattr(self, name, wrapper)
return wrapper
raise AttributeError(name)
def __dir__(self) -> Iterable[str]:
attrs = set(super().__dir__())
attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
return attrs
def __aiter__(self) -> AsyncIOWrapper[FileT_co]:
return self
async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr:
line = await self.readline()
if line:
return line
else:
raise StopAsyncIteration
async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]:
"""Like :meth:`io.BufferedIOBase.detach`, but async.
This also re-wraps the result in a new :term:`asynchronous file object`
wrapper.
"""
raw = await trio.to_thread.run_sync(self._wrapped.detach)
return wrap_file(raw)
async def aclose(self: AsyncIOWrapper[_CanClose]) -> None:
"""Like :meth:`io.IOBase.close`, but async.
This is also shielded from cancellation; if a cancellation scope is
cancelled, the wrapped file object will still be safely closed.
"""
# ensure the underling file is closed during cancellation
with trio.CancelScope(shield=True):
await trio.to_thread.run_sync(self._wrapped.close)
await trio.lowlevel.checkpoint_if_cancelled()
if TYPE_CHECKING:
# fmt: off
# Based on typing.IO and io stubs.
@property
def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ...
@property
def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ...
@property
def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ...
@property
def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ...
@property
def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ...
@property
def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ...
@property
def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ...
@property
def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ...
@property
def name(self: AsyncIOWrapper[_HasName]) -> str: ...
@property
def mode(self: AsyncIOWrapper[_HasMode]) -> str: ...
def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ...
def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ...
def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ...
def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ...
def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ...
def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ...
def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ...
async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ...
async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ...
async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ...
async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ...
async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ...
async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ...
async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]]) -> list[AnyStr]: ...
async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ...
async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ...
async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ...
async def write(self: AsyncIOWrapper[_CanWrite[T]], data: T, /) -> int: ...
async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ...
async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ...
async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ...
# Type hints are copied from builtin open.
_OpenFile = Union["StrOrBytesPath", int]
_Opener = Callable[[str, int], int]
@overload
async def open_file(
file: _OpenFile,
mode: OpenTextMode = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[io.TextIOWrapper]: ...
@overload
async def open_file(
file: _OpenFile,
mode: OpenBinaryMode,
buffering: Literal[0],
encoding: None = None,
errors: None = None,
newline: None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[io.FileIO]: ...
@overload
async def open_file(
file: _OpenFile,
mode: OpenBinaryModeUpdating,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedRandom]: ...
@overload
async def open_file(
file: _OpenFile,
mode: OpenBinaryModeWriting,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedWriter]: ...
@overload
async def open_file(
file: _OpenFile,
mode: OpenBinaryModeReading,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedReader]: ...
@overload
async def open_file(
file: _OpenFile,
mode: OpenBinaryMode,
buffering: int,
encoding: None = None,
errors: None = None,
newline: None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[BinaryIO]: ...
@overload
async def open_file( # type: ignore[misc] # Any usage matches builtins.open().
file: _OpenFile,
mode: str,
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[IO[Any]]: ...
async def open_file(
file: _OpenFile,
mode: str = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
closefd: bool = True,
opener: _Opener | None = None,
) -> AsyncIOWrapper[Any]:
"""Asynchronous version of :func:`open`.
Returns:
An :term:`asynchronous file object`
Example::
async with await trio.open_file(filename) as f:
async for line in f:
pass
assert f.closed
See also:
:func:`trio.Path.open`
"""
_file = wrap_file(
await trio.to_thread.run_sync(
io.open,
file,
mode,
buffering,
encoding,
errors,
newline,
closefd,
opener,
),
)
return _file
def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]:
"""This wraps any file object in a wrapper that provides an asynchronous
file object interface.
Args:
file: a :term:`file object`
Returns:
An :term:`asynchronous file object` that wraps ``file``
Example::
async_file = trio.wrap_file(StringIO('asdf'))
assert await async_file.read() == 'asdf'
"""
def has(attr: str) -> bool:
return hasattr(file, attr) and callable(getattr(file, attr))
if not (has("close") and (has("read") or has("write"))):
raise TypeError(
f"{file} does not implement required duck-file methods: "
"close and (read or write)",
)
return AsyncIOWrapper(file)
@@ -0,0 +1,129 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic, TypeVar
import attrs
import trio
from trio._util import final
from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream
if TYPE_CHECKING:
from typing_extensions import TypeGuard
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
async def aclose_forcefully(resource: AsyncResource) -> None:
"""Close an async resource or async generator immediately, without
blocking to do any graceful cleanup.
:class:`~trio.abc.AsyncResource` objects guarantee that if their
:meth:`~trio.abc.AsyncResource.aclose` method is cancelled, then they will
still close the resource (albeit in a potentially ungraceful
fashion). :func:`aclose_forcefully` is a convenience function that
exploits this behavior to let you force a resource to be closed without
blocking: it works by calling ``await resource.aclose()`` and then
cancelling it immediately.
Most users won't need this, but it may be useful on cleanup paths where
you can't afford to block, or if you want to close a resource and don't
care about handling it gracefully. For example, if
:class:`~trio.SSLStream` encounters an error and cannot perform its
own graceful close, then there's no point in waiting to gracefully shut
down the underlying transport either, so it calls ``await
aclose_forcefully(self.transport_stream)``.
Note that this function is async, and that it acts as a checkpoint, but
unlike most async functions it cannot block indefinitely (at least,
assuming the underlying resource object is correctly implemented).
"""
with trio.CancelScope() as cs:
cs.cancel()
await resource.aclose()
def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]:
"""Check if the stream has a send_eof() method."""
return hasattr(stream, "send_eof")
@final
@attrs.define(eq=False, slots=False)
class StapledStream(
HalfCloseableStream,
Generic[SendStreamT, ReceiveStreamT],
):
"""This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
together two unidirectional streams to make single bidirectional stream.
Args:
send_stream (~trio.abc.SendStream): The stream to use for sending.
receive_stream (~trio.abc.ReceiveStream): The stream to use for
receiving.
Example:
A silly way to make a stream that echoes back whatever you write to
it::
left, right = trio.testing.memory_stream_pair()
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
await echo_stream.send_all(b"x")
assert await echo_stream.receive_some() == b"x"
:class:`StapledStream` objects implement the methods in the
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
additional public attributes:
.. attribute:: send_stream
The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and
:meth:`wait_send_all_might_not_block` are delegated to this object.
.. attribute:: receive_stream
The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some`
is delegated to this object.
"""
send_stream: SendStreamT
receive_stream: ReceiveStreamT
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Calls ``self.send_stream.send_all``."""
return await self.send_stream.send_all(data)
async def wait_send_all_might_not_block(self) -> None:
"""Calls ``self.send_stream.wait_send_all_might_not_block``."""
return await self.send_stream.wait_send_all_might_not_block()
async def send_eof(self) -> None:
"""Shuts down the send side of the stream.
If :meth:`self.send_stream.send_eof() <trio.abc.HalfCloseableStream.send_eof>` exists,
then this calls it. Otherwise, this calls
:meth:`self.send_stream.aclose() <trio.abc.AsyncResource.aclose>`.
"""
stream = self.send_stream
if _is_halfclosable(stream):
return await stream.send_eof()
else:
return await stream.aclose()
# we intentionally accept more types from the caller than we support returning
async def receive_some(self, max_bytes: int | None = None) -> bytes:
"""Calls ``self.receive_stream.receive_some``."""
return await self.receive_stream.receive_some(max_bytes)
async def aclose(self) -> None:
"""Calls ``aclose`` on both underlying streams."""
try:
await self.send_stream.aclose()
finally:
await self.receive_stream.aclose()
@@ -0,0 +1,251 @@
from __future__ import annotations
import errno
import sys
from typing import TYPE_CHECKING
import trio
from trio import TaskStatus
from . import socket as tsocket
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
# Default backlog size:
#
# Having the backlog too low can cause practical problems (a perfectly healthy
# service that starts failing to accept connections if they arrive in a
# burst).
#
# Having it too high doesn't really cause any problems. Like any buffer, you
# want backlog queue to be zero usually, and it won't save you if you're
# getting connection attempts faster than you can call accept() on an ongoing
# basis. But unlike other buffers, this one doesn't really provide any
# backpressure. If a connection gets stuck waiting in the backlog queue, then
# from the peer's point of view the connection succeeded but then their
# send/recv will stall until we get to it, possibly for a long time. OTOH if
# there isn't room in the backlog queue, then their connect stalls, possibly
# for a long time, which is pretty much the same thing.
#
# A large backlog can also use a bit more kernel memory, but this seems fairly
# negligible these days.
#
# So this suggests we should make the backlog as large as possible. This also
# matches what Golang does. However, they do it in a weird way, where they
# have a bunch of code to sniff out the configured upper limit for backlog on
# different operating systems. But on every system, passing in a too-large
# backlog just causes it to be silently truncated to the configured maximum,
# so this is unnecessary -- we can just pass in "infinity" and get the maximum
# that way. (Verified on Windows, Linux, macOS using
# notes-to-self/measure-listen-backlog.py)
def _compute_backlog(backlog: int | None) -> int:
# Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are
# missing overflow protection, so we apply our own overflow protection.
# https://github.com/golang/go/issues/5030
if not isinstance(backlog, int) and backlog is not None:
raise TypeError(f"backlog must be an int or None, not {backlog!r}")
if backlog is None:
return 0xFFFF
return min(backlog, 0xFFFF)
async def open_tcp_listeners(
port: int,
*,
host: str | bytes | None = None,
backlog: int | None = None,
) -> list[trio.SocketListener]:
"""Create :class:`SocketListener` objects to listen for TCP connections.
Args:
port (int): The port to listen on.
If you use 0 as your port, then the kernel will automatically pick
an arbitrary open port. But be careful: if you use this feature when
binding to multiple IP addresses, then each IP address will get its
own random port, and the returned listeners will probably be
listening on different ports. In particular, this will happen if you
use ``host=None`` which is the default because in this case
:func:`open_tcp_listeners` will bind to both the IPv4 wildcard
address (``0.0.0.0``) and also the IPv6 wildcard address (``::``).
host (str, bytes, or None): The local interface to bind to. This is
passed to :func:`~socket.getaddrinfo` with the ``AI_PASSIVE`` flag
set.
If you want to bind to the wildcard address on both IPv4 and IPv6,
in order to accept connections on all available interfaces, then
pass ``None``. This is the default.
If you have a specific interface you want to bind to, pass its IP
address or hostname here. If a hostname resolves to multiple IP
addresses, this function will open one listener on each of them.
If you want to use only IPv4, or only IPv6, but want to accept on
all interfaces, pass the family-specific wildcard address:
``"0.0.0.0"`` for IPv4-only and ``"::"`` for IPv6-only.
backlog (int or None): The listen backlog to use. If you leave this as
``None`` then Trio will pick a good default. (Currently: whatever
your system has configured as the maximum backlog.)
Returns:
list of :class:`SocketListener`
Raises:
:class:`TypeError` if invalid arguments.
"""
# getaddrinfo sometimes allows port=None, sometimes not (depending on
# whether host=None). And on some systems it treats "" as 0, others it
# doesn't:
# http://klickverbot.at/blog/2012/01/getaddrinfo-edge-case-behavior-on-windows-linux-and-osx/
if not isinstance(port, int):
raise TypeError(f"port must be an int not {port!r}")
computed_backlog = _compute_backlog(backlog)
addresses = await tsocket.getaddrinfo(
host,
port,
type=tsocket.SOCK_STREAM,
flags=tsocket.AI_PASSIVE,
)
listeners = []
unsupported_address_families = []
try:
for family, type_, proto, _, sockaddr in addresses:
try:
sock = tsocket.socket(family, type_, proto)
except OSError as ex:
if ex.errno == errno.EAFNOSUPPORT:
# If a system only supports IPv4, or only IPv6, it
# is still likely that getaddrinfo will return
# both an IPv4 and an IPv6 address. As long as at
# least one of the returned addresses can be
# turned into a socket, we won't complain about a
# failure to create the other.
unsupported_address_families.append(ex)
continue
else:
raise
try:
# See https://github.com/python-trio/trio/issues/39
if sys.platform != "win32":
sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1)
if family == tsocket.AF_INET6:
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1)
await sock.bind(sockaddr)
sock.listen(computed_backlog)
listeners.append(trio.SocketListener(sock))
except:
sock.close()
raise
except:
for listener in listeners:
listener.socket.close()
raise
if unsupported_address_families and not listeners:
msg = (
"This system doesn't support any of the kinds of "
"socket that that address could use"
)
raise OSError(errno.EAFNOSUPPORT, msg) from ExceptionGroup(
msg,
unsupported_address_families,
)
return listeners
async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[object]],
port: int,
*,
host: str | bytes | None = None,
backlog: int | None = None,
handler_nursery: trio.Nursery | None = None,
task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED,
) -> None:
"""Listen for incoming TCP connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around :func:`open_tcp_listeners` and
:func:`serve_listeners` see them for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
So, for example, if you want to start a server in your test suite and then
connect to it to check that it's working properly, you can use something
like::
from trio import SocketListener, SocketStream
from trio.testing import open_stream_to_socket_listener
async with trio.open_nursery() as nursery:
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0])
# Then send and receive data on 'client_stream', for example:
await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n")
This avoids several common pitfalls:
1. It lets the kernel pick a random open port, so your test suite doesn't
depend on any particular port being open.
2. It waits for the server to be accepting connections on that port before
``start`` returns, so there's no race condition where the incoming
connection arrives before the server is ready.
3. It uses the Listener object to find out which port was picked, so it
can connect to the right place.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
port: The port to listen on. Use 0 to let the kernel pick an open port.
Passed to :func:`open_tcp_listeners`.
host (str, bytes, or None): The host interface to listen on; use
``None`` to bind to the wildcard address. Passed to
:func:`open_tcp_listeners`.
backlog: The listen backlog, or None to have a good default picked.
Passed to :func:`open_tcp_listeners`.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
"""
listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
await trio.serve_listeners(
handler,
listeners,
handler_nursery=handler_nursery,
task_status=task_status,
)
@@ -0,0 +1,411 @@
from __future__ import annotations
import sys
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Any
import trio
from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket
if TYPE_CHECKING:
from collections.abc import Generator
from socket import AddressFamily, SocketKind
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
# Implementation of RFC 6555 "Happy eyeballs"
# https://tools.ietf.org/html/rfc6555
#
# Basically, the problem here is that if we want to connect to some host, and
# DNS returns multiple IP addresses, then we don't know which of them will
# actually work -- it can happen that some of them are reachable, and some of
# them are not. One particularly common situation where this happens is on a
# host that thinks it has ipv6 connectivity, but really doesn't. But in
# principle this could happen for any kind of multi-home situation (e.g. the
# route to one mirror is down but another is up).
#
# The naive algorithm (e.g. the stdlib's socket.create_connection) would be to
# pick one of the IP addresses and try to connect; if that fails, try the
# next; etc. The problem with this is that TCP is stubborn, and if the first
# address is a blackhole then it might take a very long time (tens of seconds)
# before that connection attempt fails.
#
# That's where RFC 6555 comes in. It tells us that what we do is:
# - get the list of IPs from getaddrinfo, trusting the order it gives us (with
# one exception noted in section 5.4)
# - start a connection attempt to the first IP
# - when this fails OR if it's still going after DELAY seconds, then start a
# connection attempt to the second IP
# - when this fails OR if it's still going after another DELAY seconds, then
# start a connection attempt to the third IP
# - ... repeat until we run out of IPs.
#
# Our implementation is similarly straightforward: we spawn a chain of tasks,
# where each one (a) waits until the previous connection has failed or DELAY
# seconds have passed, (b) spawns the next task, (c) attempts to connect. As
# soon as any task crashes or succeeds, we cancel all the tasks and return.
#
# Note: this currently doesn't attempt to cache any results, so if you make
# multiple connections to the same host it'll re-run the happy-eyeballs
# algorithm each time. RFC 6555 is pretty confusing about whether this is
# allowed. Section 4 describes an algorithm that attempts ipv4 and ipv6
# simultaneously, and then says "The client MUST cache information regarding
# the outcome of each connection attempt, and it uses that information to
# avoid thrashing the network with subsequent attempts." Then section 4.2 says
# "implementations MUST prefer the first IP address family returned by the
# host's address preference policy, unless implementing a stateful
# algorithm". Here "stateful" means "one that caches information about
# previous attempts". So my reading of this is that IF you're starting ipv4
# and ipv6 at the same time then you MUST cache the result for ~ten minutes,
# but IF you're "preferring" one protocol by trying it first (like we are),
# then you don't need to cache.
#
# Caching is quite tricky: to get it right you need to do things like detect
# when the network interfaces are reconfigured, and if you get it wrong then
# connection attempts basically just don't work. So we don't even try.
# "Firefox and Chrome use 300 ms"
# https://tools.ietf.org/html/rfc6555#section-6
# Though
# https://www.researchgate.net/profile/Vaibhav_Bajpai3/publication/304568993_Measuring_the_Effects_of_Happy_Eyeballs/links/5773848e08ae6f328f6c284c/Measuring-the-Effects-of-Happy-Eyeballs.pdf
# claims that Firefox actually uses 0 ms, unless an about:config option is
# toggled and then it uses 250 ms.
DEFAULT_DELAY = 0.250
# How should we call getaddrinfo? In particular, should we use AI_ADDRCONFIG?
#
# The idea of AI_ADDRCONFIG is that it only returns addresses that might
# work. E.g., if getaddrinfo knows that you don't have any IPv6 connectivity,
# then it doesn't return any IPv6 addresses. And this is kinda nice, because
# it means maybe you can skip sending AAAA requests entirely. But in practice,
# it doesn't really work right.
#
# - on Linux/glibc, empirically, the default is to return all addresses, and
# with AI_ADDRCONFIG then it only returns IPv6 addresses if there is at least
# one non-loopback IPv6 address configured... but this can be a link-local
# address, so in practice I guess this is basically always configured if IPv6
# is enabled at all. OTOH if you pass in "::1" as the target address with
# AI_ADDRCONFIG and there's no *external* IPv6 address configured, you get an
# error. So AI_ADDRCONFIG mostly doesn't do anything, even when you would want
# it to, and when it does do something it might break things that would have
# worked.
#
# - on Windows 10, empirically, if no IPv6 address is configured then by
# default they are also suppressed from getaddrinfo (flags=0 and
# flags=AI_ADDRCONFIG seem to do the same thing). If you pass AI_ALL, then you
# get the full list.
# ...except for localhost! getaddrinfo("localhost", "80") gives me ::1, even
# though there's no ipv6 and other queries only return ipv4.
# If you pass in and IPv6 IP address as the target address, then that's always
# returned OK, even with AI_ADDRCONFIG set and no IPv6 configured.
#
# But I guess other versions of windows messed this up, judging from these bug
# reports:
# https://bugs.chromium.org/p/chromium/issues/detail?id=5234
# https://bugs.chromium.org/p/chromium/issues/detail?id=32522#c50
#
# So basically the options are either to use AI_ADDRCONFIG and then add some
# complicated special cases to work around its brokenness, or else don't use
# AI_ADDRCONFIG and accept that sometimes on legacy/misconfigured networks
# we'll waste 300 ms trying to connect to a blackholed destination.
#
# Twisted and Tornado always uses default flags. I think we'll do the same.
@contextmanager
def close_all() -> Generator[set[SocketType], None, None]:
sockets_to_close: set[SocketType] = set()
try:
yield sockets_to_close
finally:
errs = []
for sock in sockets_to_close:
try:
sock.close()
except BaseException as exc:
errs.append(exc)
if len(errs) == 1:
raise errs[0]
elif errs:
raise BaseExceptionGroup("", errs)
def reorder_for_rfc_6555_section_5_4(
targets: list[
tuple[
AddressFamily,
SocketKind,
int,
str,
Any,
]
],
) -> None:
# RFC 6555 section 5.4 says that if getaddrinfo returns multiple address
# families (e.g. IPv4 and IPv6), then you should make sure that your first
# and second attempts use different families:
#
# https://tools.ietf.org/html/rfc6555#section-5.4
#
# This function post-processes the results from getaddrinfo, in-place, to
# satisfy this requirement.
for i in range(1, len(targets)):
if targets[i][0] != targets[0][0]:
# Found the first entry with a different address family; move it
# so that it becomes the second item on the list.
if i != 1:
targets.insert(1, targets.pop(i))
break
def format_host_port(host: str | bytes, port: int | str) -> str:
host = host.decode("ascii") if isinstance(host, bytes) else host
if ":" in host:
return f"[{host}]:{port}"
else:
return f"{host}:{port}"
# Twisted's HostnameEndpoint has a good set of configurables:
# https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.HostnameEndpoint.html
#
# - per-connection timeout
# this doesn't seem useful -- we let you set a timeout on the whole thing
# using Trio's normal mechanisms, and that seems like enough
# - delay between attempts
# - bind address (but not port!)
# they *don't* support multiple address bindings, like giving the ipv4 and
# ipv6 addresses of the host.
# I think maybe our semantics should be: we accept a list of bind addresses,
# and we bind to the first one that is compatible with the
# connection attempt we want to make, and if none are compatible then we
# don't try to connect to that target.
#
# XX TODO: implement bind address support
#
# Actually, the best option is probably to be explicit: {AF_INET: "...",
# AF_INET6: "..."}
# this might be simpler after
async def open_tcp_stream(
host: str | bytes,
port: int,
*,
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
local_address: str | None = None,
) -> trio.SocketStream:
"""Connect to the given host and port over TCP.
If the given ``host`` has multiple IP addresses associated with it, then
we have a problem: which one do we use?
One approach would be to attempt to connect to the first one, and then if
that fails, attempt to connect to the second one ... until we've tried all
of them. But the problem with this is that if the first IP address is
unreachable (for example, because it's an IPv6 address and our network
discards IPv6 packets), then we might end up waiting tens of seconds for
the first connection attempt to timeout before we try the second address.
Another approach would be to attempt to connect to all of the addresses at
the same time, in parallel, and then use whichever connection succeeds
first, abandoning the others. This would be fast, but create a lot of
unnecessary load on the network and the remote server.
This function strikes a balance between these two extremes: it works its
way through the available addresses one at a time, like the first
approach; but, if ``happy_eyeballs_delay`` seconds have passed and it's
still waiting for an attempt to succeed or fail, then it gets impatient
and starts the next connection attempt in parallel. As soon as any one
connection attempt succeeds, all the other attempts are cancelled. This
avoids unnecessary load because most connections will succeed after just
one or two attempts, but if one of the addresses is unreachable then it
doesn't slow us down too much.
This is known as a "happy eyeballs" algorithm, and our particular variant
is modelled after how Chrome connects to webservers; see `RFC 6555
<https://tools.ietf.org/html/rfc6555>`__ for more details.
Args:
host (str or bytes): The host to connect to. Can be an IPv4 address,
IPv6 address, or a hostname.
port (int): The port to connect to.
happy_eyeballs_delay (float or None): How many seconds to wait for each
connection attempt to succeed or fail before getting impatient and
starting another one in parallel. Set to `None` if you want
to limit to only one connection attempt at a time (like
:func:`socket.create_connection`). Default: 0.25 (250 ms).
local_address (None or str): The local IP address or hostname to use as
the source for outgoing connections. If ``None``, we let the OS pick
the source IP.
This is useful in some exotic networking configurations where your
host has multiple IP addresses, and you want to force the use of a
specific one.
Note that if you pass an IPv4 ``local_address``, then you won't be
able to connect to IPv6 hosts, and vice-versa. If you want to take
advantage of this to force the use of IPv4 or IPv6 without
specifying an exact source address, you can use the IPv4 wildcard
address ``local_address="0.0.0.0"``, or the IPv6 wildcard address
``local_address="::"``.
Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given server.
Raises:
OSError: if the connection fails.
See also:
open_ssl_over_tcp_stream
"""
# To keep our public API surface smaller, rule out some cases that
# getaddrinfo will accept in some circumstances, but that act weird or
# have non-portable behavior or are just plain not useful.
if not isinstance(host, (str, bytes)):
raise ValueError(f"host must be str or bytes, not {host!r}")
if not isinstance(port, int):
raise TypeError(f"port must be int, not {port!r}")
if happy_eyeballs_delay is None:
happy_eyeballs_delay = DEFAULT_DELAY
targets = await getaddrinfo(host, port, type=SOCK_STREAM)
# I don't think this can actually happen -- if there are no results,
# getaddrinfo should have raised OSError instead of returning an empty
# list. But let's be paranoid and handle it anyway:
if not targets:
msg = f"no results found for hostname lookup: {format_host_port(host, port)}"
raise OSError(msg)
reorder_for_rfc_6555_section_5_4(targets)
# This list records all the connection failures that we ignored.
oserrors: list[OSError] = []
# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
winning_socket: SocketType | None = None
# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
# concurrent attempts
# - failure: record exception in oserrors, set attempt_failed allowing
# the next connection attempt to start early
# code needs to ensure sockets can be closed appropriately in the
# face of crash or cancellation
async def attempt_connect(
socket_args: tuple[AddressFamily, SocketKind, int],
sockaddr: Any,
attempt_failed: trio.Event,
) -> None:
nonlocal winning_socket
try:
sock = socket(*socket_args)
open_sockets.add(sock)
if local_address is not None:
# TCP connections are identified by a 4-tuple:
#
# (local IP, local port, remote IP, remote port)
#
# So if a single local IP wants to make multiple connections
# to the same (remote IP, remote port) pair, then those
# connections have to use different local ports, or else TCP
# won't be able to tell them apart. OTOH, if you have multiple
# connections to different remote IP/ports, then those
# connections can share a local port.
#
# Normally, when you call bind(), the kernel will immediately
# assign a specific local port to your socket. At this point
# the kernel doesn't know which (remote IP, remote port)
# you're going to use, so it has to pick a local port that
# *no* other connection is using. That's the only way to
# guarantee that this local port will be usable later when we
# call connect(). (Alternatively, you can set SO_REUSEADDR to
# allow multiple nascent connections to share the same port,
# but then connect() might fail with EADDRNOTAVAIL if we get
# unlucky and our TCP 4-tuple ends up colliding with another
# unrelated connection.)
#
# So calling bind() before connect() works, but it disables
# sharing of local ports. This is inefficient: it makes you
# more likely to run out of local ports.
#
# But on some versions of Linux, we can re-enable sharing of
# local ports by setting a special flag. This flag tells
# bind() to only bind the IP, and not the port. That way,
# connect() is allowed to pick the the port, and it can do a
# better job of it because it knows the remote IP/port.
with suppress(OSError, AttributeError):
sock.setsockopt(
trio.socket.IPPROTO_IP,
trio.socket.IP_BIND_ADDRESS_NO_PORT,
1,
)
try:
await sock.bind((local_address, 0))
except OSError:
raise OSError(
f"local_address={local_address!r} is incompatible "
f"with remote address {sockaddr!r}",
) from None
await sock.connect(sockaddr)
# Success! Save the winning socket and cancel all outstanding
# connection attempts.
winning_socket = sock
nursery.cancel_scope.cancel()
except OSError as exc:
# This connection attempt failed, but the next one might
# succeed. Save the error for later so we can report it if
# everything fails, and tell the next attempt that it should go
# ahead (if it hasn't already).
oserrors.append(exc)
attempt_failed.set()
with close_all() as open_sockets:
# nursery spawns a task for each connection attempt, will be
# cancelled by the task that gets a successful connection
async with trio.open_nursery() as nursery:
for address_family, socket_type, proto, _, addr in targets:
# create an event to indicate connection failure,
# allowing the next target to be tried early
attempt_failed = trio.Event()
# workaround to check types until typing of nursery.start_soon improved
if TYPE_CHECKING:
await attempt_connect(
(address_family, socket_type, proto),
addr,
attempt_failed,
)
nursery.start_soon(
attempt_connect,
(address_family, socket_type, proto),
addr,
attempt_failed,
)
# give this attempt at most this time before moving on
with trio.move_on_after(happy_eyeballs_delay):
await attempt_failed.wait()
# nothing succeeded
if winning_socket is None:
assert len(oserrors) == len(targets)
msg = f"all attempts to connect to {format_host_port(host, port)} failed"
raise OSError(msg) from ExceptionGroup(msg, oserrors)
else:
stream = trio.SocketStream(winning_socket)
open_sockets.remove(winning_socket)
return stream
@@ -0,0 +1,65 @@
from __future__ import annotations
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Protocol, TypeVar
import trio
from trio.socket import SOCK_STREAM, socket
if TYPE_CHECKING:
from collections.abc import Generator
class Closable(Protocol):
def close(self) -> None: ...
CloseT = TypeVar("CloseT", bound=Closable)
try:
from trio.socket import AF_UNIX
has_unix = True
except ImportError:
has_unix = False
@contextmanager
def close_on_error(obj: CloseT) -> Generator[CloseT, None, None]:
try:
yield obj
except:
obj.close()
raise
async def open_unix_socket(
filename: str | bytes | os.PathLike[str] | os.PathLike[bytes],
) -> trio.SocketStream:
"""Opens a connection to the specified
`Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.
You must have read/write permission on the specified file to connect.
Args:
filename (str or bytes): The filename to open the connection to.
Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given file.
Raises:
OSError: If the socket file could not be connected to.
RuntimeError: If AF_UNIX sockets are not supported.
"""
if not has_unix:
raise RuntimeError("Unix sockets are not supported on this platform")
# much more simplified logic vs tcp sockets - one socket type and only one
# possible location to connect to
sock = socket(AF_UNIX, SOCK_STREAM)
with close_on_error(sock):
await sock.connect(os.fspath(filename))
return trio.SocketStream(sock)
@@ -0,0 +1,147 @@
from __future__ import annotations
import errno
import logging
import os
from typing import Any, Awaitable, Callable, NoReturn, TypeVar
import trio
# Errors that accept(2) can return, and which indicate that the system is
# overloaded
ACCEPT_CAPACITY_ERRNOS = {
errno.EMFILE,
errno.ENFILE,
errno.ENOMEM,
errno.ENOBUFS,
}
# How long to sleep when we get one of those errors
SLEEP_TIME = 0.100
# The logger we use to complain when this happens
LOGGER = logging.getLogger("trio.serve_listeners")
StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource)
ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any])
Handler = Callable[[StreamT], Awaitable[object]]
async def _run_handler(stream: StreamT, handler: Handler[StreamT]) -> None:
try:
await handler(stream)
finally:
await trio.aclose_forcefully(stream)
async def _serve_one_listener(
listener: trio.abc.Listener[StreamT],
handler_nursery: trio.Nursery,
handler: Handler[StreamT],
) -> NoReturn:
async with listener:
while True:
try:
stream = await listener.accept()
except OSError as exc:
if exc.errno in ACCEPT_CAPACITY_ERRNOS:
LOGGER.error(
"accept returned %s (%s); retrying in %s seconds",
errno.errorcode[exc.errno],
os.strerror(exc.errno),
SLEEP_TIME,
exc_info=True,
)
await trio.sleep(SLEEP_TIME)
else:
raise
else:
handler_nursery.start_soon(_run_handler, stream, handler)
# This cannot be typed correctly, we need generic typevar bounds / HKT to indicate the
# relationship between StreamT & ListenerT.
# https://github.com/python/typing/issues/1226
# https://github.com/python/typing/issues/548
async def serve_listeners(
handler: Handler[StreamT],
listeners: list[ListenerT],
*,
handler_nursery: trio.Nursery | None = None,
task_status: trio.TaskStatus[list[ListenerT]] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
r"""Listen for incoming connections on ``listeners``, and for each one
start a task running ``handler(stream)``.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
Args:
handler: An async callable, that will be invoked like
``handler_nursery.start_soon(handler, stream)`` for each incoming
connection.
listeners: A list of :class:`~trio.abc.Listener` objects.
:func:`serve_listeners` takes responsibility for closing them.
handler_nursery: The nursery used to start handlers, or any object with
a ``start_soon`` method. If ``None`` (the default), then
:func:`serve_listeners` will create a new nursery internally and use
that.
task_status: This function can be used with ``nursery.start``, which
will return ``listeners``.
Returns:
This function never returns unless cancelled.
Resource handling:
If ``handler`` neglects to close the ``stream``, then it will be closed
using :func:`trio.aclose_forcefully`.
Error handling:
Most errors coming from :meth:`~trio.abc.Listener.accept` are allowed to
propagate out (crashing the server in the process). However, some errors
those which indicate that the server is temporarily overloaded are
handled specially. These are :class:`OSError`\s with one of the following
errnos:
* ``EMFILE``: process is out of file descriptors
* ``ENFILE``: system is out of file descriptors
* ``ENOBUFS``, ``ENOMEM``: the kernel hit some sort of memory limitation
when trying to create a socket object
When :func:`serve_listeners` gets one of these errors, then it:
* Logs the error to the standard library logger ``trio.serve_listeners``
(level = ERROR, with exception information included). By default this
causes it to be printed to stderr.
* Waits 100 ms before calling ``accept`` again, in hopes that the
system will recover.
"""
async with trio.open_nursery() as nursery:
if handler_nursery is None:
handler_nursery = nursery
for listener in listeners:
nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler)
# The listeners are already queueing connections when we're called,
# but we wait until the end to call started() just in case we get an
# error or whatever.
task_status.started(listeners)
raise AssertionError(
"_serve_one_listener should never complete",
) # pragma: no cover
@@ -0,0 +1,414 @@
# "High-level" networking interface
from __future__ import annotations
import errno
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, overload
import trio
from . import socket as tsocket
from ._util import ConflictDetector, final
from .abc import HalfCloseableStream, Listener
if TYPE_CHECKING:
from collections.abc import Generator
from typing_extensions import Buffer
from ._socket import SocketType
# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
# if we observe single reads filling up the whole buffer, at least within some
# limits.)
DEFAULT_RECEIVE_SIZE = 65536
_closed_stream_errnos = {
# Unix
errno.EBADF,
# Windows
errno.ENOTSOCK,
}
@contextmanager
def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]:
try:
yield
except OSError as exc:
if exc.errno in _closed_stream_errnos:
raise trio.ClosedResourceError("this socket was already closed") from None
else:
raise trio.BrokenResourceError(f"socket connection broken: {exc}") from exc
@final
class SocketStream(HalfCloseableStream):
"""An implementation of the :class:`trio.abc.HalfCloseableStream`
interface based on a raw network socket.
Args:
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be connected.
By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``,
and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with
a reasonable buffer size (currently 16 KiB) see `issue #72
<https://github.com/python-trio/trio/issues/72>`__ for discussion. You can
of course override these defaults by calling :meth:`setsockopt`.
Once a :class:`SocketStream` object is constructed, it implements the full
:class:`trio.abc.HalfCloseableStream` interface. In addition, it provides
a few extra features:
.. attribute:: socket
The Trio socket object that this stream wraps.
"""
def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
raise ValueError("SocketStream requires a SOCK_STREAM socket")
self.socket = socket
self._send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SocketStream",
)
# Socket defaults:
# Not supported on e.g. unix domain sockets
with suppress(OSError):
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
if hasattr(tsocket, "TCP_NOTSENT_LOWAT"):
# 16 KiB is pretty arbitrary and could probably do with some
# tuning. (Apple is also setting this by default in CFNetwork
# apparently -- I'm curious what value they're using, though I
# couldn't find it online trivially. CFNetwork-129.20 source
# has no mentions of TCP_NOTSENT_LOWAT. This presentation says
# "typically 8 kilobytes":
# http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1
# ). The theory is that you want it to be bandwidth *
# rescheduling interval.
with suppress(OSError):
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14)
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
if self.socket.did_shutdown_SHUT_WR:
raise trio.ClosedResourceError("can't send data after sending EOF")
with self._send_conflict_detector:
with _translate_socket_errors_to_stream_errors():
with memoryview(data) as data:
if not data:
if self.socket.fileno() == -1:
raise trio.ClosedResourceError("socket was already closed")
await trio.lowlevel.checkpoint()
return
total_sent = 0
while total_sent < len(data):
with data[total_sent:] as remaining:
sent = await self.socket.send(remaining)
total_sent += sent
async def wait_send_all_might_not_block(self) -> None:
with self._send_conflict_detector:
if self.socket.fileno() == -1:
raise trio.ClosedResourceError
with _translate_socket_errors_to_stream_errors():
await self.socket.wait_writable()
async def send_eof(self) -> None:
with self._send_conflict_detector:
await trio.lowlevel.checkpoint()
# On macOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket.did_shutdown_SHUT_WR:
return
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
async def receive_some(self, max_bytes: int | None = None) -> bytes:
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
with _translate_socket_errors_to_stream_errors():
return await self.socket.recv(max_bytes)
async def aclose(self) -> None:
self.socket.close()
await trio.lowlevel.checkpoint()
# __aenter__, __aexit__ inherited from HalfCloseableStream are OK
@overload
def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: ...
@overload
def setsockopt(self, level: int, option: int, value: None, length: int) -> None: ...
def setsockopt(
self,
level: int,
option: int,
value: int | Buffer | None,
length: int | None = None,
) -> None:
"""Set an option on the underlying socket.
See :meth:`socket.socket.setsockopt` for details.
"""
if length is None:
if value is None:
raise TypeError(
"invalid value for argument 'value', must not be None when specifying length",
)
return self.socket.setsockopt(level, option, value)
if value is not None:
raise TypeError(
f"invalid value for argument 'value': {value!r}, must be None when specifying optlen",
)
return self.socket.setsockopt(level, option, value, length)
@overload
def getsockopt(self, level: int, option: int) -> int: ...
@overload
def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: ...
def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes:
"""Check the current value of an option on the underlying socket.
See :meth:`socket.socket.getsockopt` for details.
"""
# This is to work around
# https://bitbucket.org/pypy/pypy/issues/2561
# We should be able to drop it when the next PyPy3 beta is released.
if buffersize == 0:
return self.socket.getsockopt(level, option)
else:
return self.socket.getsockopt(level, option, buffersize)
################################################################
# SocketListener
################################################################
# Accept error handling
# =====================
#
# Literature review
# -----------------
#
# Here's a list of all the possible errors that accept() can return, according
# to the POSIX spec or the Linux, FreeBSD, macOS, and Windows docs:
#
# Can't happen with a Trio socket:
# - EAGAIN/(WSA)EWOULDBLOCK
# - EINTR
# - WSANOTINITIALISED
# - WSAEINPROGRESS: a blocking call is already in progress
# - WSAEINTR: someone called WSACancelBlockingCall, but we don't make blocking
# calls in the first place
#
# Something is wrong with our call:
# - EBADF: not a file descriptor
# - (WSA)EINVAL: socket isn't listening, or (Linux, BSD) bad flags
# - (WSA)ENOTSOCK: not a socket
# - (WSA)EOPNOTSUPP: this kind of socket doesn't support accept
# - (Linux, FreeBSD, Windows) EFAULT: the sockaddr pointer points to readonly
# memory
#
# Something is wrong with the environment:
# - (WSA)EMFILE: this process hit its fd limit
# - ENFILE: the system hit its fd limit
# - (WSA)ENOBUFS, ENOMEM: unspecified memory problems
#
# Something is wrong with the connection we were going to accept. There's a
# ton of variability between systems here:
# - ECONNABORTED: documented everywhere, but apparently only the BSDs do this
# (signals a connection was closed/reset before being accepted)
# - EPROTO: unspecified protocol error
# - (Linux) EPERM: firewall rule prevented connection
# - (Linux) ENETDOWN, EPROTO, ENOPROTOOPT, EHOSTDOWN, ENONET, EHOSTUNREACH,
# EOPNOTSUPP, ENETUNREACH, ENOSR, ESOCKTNOSUPPORT, EPROTONOSUPPORT,
# ETIMEDOUT, ... or any other error that the socket could give, because
# apparently if an error happens on a connection before it's accept()ed,
# Linux will report that error from accept().
# - (Windows) WSAECONNRESET, WSAENETDOWN
#
#
# Code review
# -----------
#
# What do other libraries do?
#
# Twisted on Unix or when using nonblocking I/O on Windows:
# - ignores EPERM, with comment about Linux firewalls
# - logs and ignores EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED
# Comment notes that ECONNABORTED is a BSDism and that Linux returns the
# socket before having it fail, and macOS just silently discards it.
# - other errors are raised, which is logged + kills the socket
# ref: src/twisted/internet/tcp.py, Port.doRead
#
# Twisted using IOCP on Windows:
# - logs and ignores all errors
# ref: src/twisted/internet/iocpreactor/tcp.py, Port.handleAccept
#
# Tornado:
# - ignore ECONNABORTED (comments notes that it was observed on FreeBSD)
# - everything else raised, but all this does (by default) is cause it to be
# logged and then ignored
# (ref: tornado/netutil.py, tornado/ioloop.py)
#
# libuv on Unix:
# - ignores ECONNABORTED
# - does a "trick" for EMFILE or ENFILE
# - all other errors passed to the connection_cb to be handled
# (ref: src/unix/stream.c:uv__server_io, uv__emfile_trick)
#
# libuv on Windows:
# src/win/tcp.c:uv_tcp_queue_accept
# this calls AcceptEx, and then arranges to call:
# src/win/tcp.c:uv_process_tcp_accept_req
# this gets the result from AcceptEx. If the original AcceptEx call failed,
# then "we stop accepting connections and report this error to the
# connection callback". I think this is for things like ENOTSOCK. If
# AcceptEx successfully queues an overlapped operation, and then that
# reports an error, it's just discarded.
#
# asyncio, selector mode:
# - ignores EWOULDBLOCK, EINTR, ECONNABORTED
# - on EMFILE, ENFILE, ENOBUFS, ENOMEM, logs an error and then disables the
# listening loop for 1 second
# - everything else raises, but then the event loop just logs and ignores it
# (selector_events.py: BaseSelectorEventLoop._accept_connection)
#
#
# What should we do?
# ------------------
#
# When accept() returns an error, we can either ignore it or raise it.
#
# We have a long list of errors that should be ignored, and a long list of
# errors that should be raised. The big question is what to do with an error
# that isn't on either list. On Linux apparently you can get nearly arbitrary
# errors from accept() and they should be ignored, because it just indicates a
# socket that crashed before it began, and there isn't really anything to be
# done about this, plus on other platforms you may not get any indication at
# all, so programs have to tolerate not getting any indication too. OTOH if we
# get an unexpected error then it could indicate something arbitrarily bad --
# after all, it's unexpected.
#
# Given that we know that other libraries seem to be getting along fine with a
# fairly minimal list of errors to ignore, I think we'll be OK if we write
# down that list and then raise on everything else.
#
# The other question is what to do about the capacity problem errors: EMFILE,
# ENFILE, ENOBUFS, ENOMEM. Just flat out ignoring these is clearly not optimal
# -- at the very least you want to log them, and probably you want to take
# some remedial action. And if we ignore them then it prevents higher levels
# from doing anything clever with them. So we raise them.
_ignorable_accept_errno_names = [
# Linux can do this when the a connection is denied by the firewall
"EPERM",
# BSDs with an early close/reset
"ECONNABORTED",
# All the other miscellany noted above -- may not happen in practice, but
# whatever.
"EPROTO",
"ENETDOWN",
"ENOPROTOOPT",
"EHOSTDOWN",
"ENONET",
"EHOSTUNREACH",
"EOPNOTSUPP",
"ENETUNREACH",
"ENOSR",
"ESOCKTNOSUPPORT",
"EPROTONOSUPPORT",
"ETIMEDOUT",
"ECONNRESET",
]
# Not all errnos are defined on all platforms
_ignorable_accept_errnos: set[int] = set()
for name in _ignorable_accept_errno_names:
with suppress(AttributeError):
_ignorable_accept_errnos.add(getattr(errno, name))
@final
class SocketListener(Listener[SocketStream]):
"""A :class:`~trio.abc.Listener` that uses a listening socket to accept
incoming connections as :class:`SocketStream` objects.
Args:
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be listening.
Note that the :class:`SocketListener` "takes ownership" of the given
socket; closing the :class:`SocketListener` will also close the socket.
.. attribute:: socket
The Trio socket object that this stream wraps.
"""
def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
raise ValueError("SocketListener requires a SOCK_STREAM socket")
try:
listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN)
except OSError:
# SO_ACCEPTCONN fails on macOS; we just have to trust the user.
pass
else:
if not listening:
raise ValueError("SocketListener requires a listening socket")
self.socket = socket
async def accept(self) -> SocketStream:
"""Accept an incoming connection.
Returns:
:class:`SocketStream`
Raises:
OSError: if the underlying call to ``accept`` raises an unexpected
error.
ClosedResourceError: if you already closed the socket.
This method handles routine errors like ``ECONNABORTED``, but passes
other errors on to its caller. In particular, it does *not* make any
special effort to handle resource exhaustion errors like ``EMFILE``,
``ENFILE``, ``ENOBUFS``, ``ENOMEM``.
"""
while True:
try:
sock, _ = await self.socket.accept()
except OSError as exc:
if exc.errno in _closed_stream_errnos:
raise trio.ClosedResourceError from None
if exc.errno not in _ignorable_accept_errnos:
raise
else:
return SocketStream(sock)
async def aclose(self) -> None:
"""Close this listener and its underlying socket."""
self.socket.close()
await trio.lowlevel.checkpoint()
@@ -0,0 +1,180 @@
from __future__ import annotations
import ssl
from typing import TYPE_CHECKING, NoReturn, TypeVar
import trio
from ._highlevel_open_tcp_stream import DEFAULT_DELAY
T = TypeVar("T")
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from ._highlevel_socket import SocketStream
# It might have been nice to take a ssl_protocols= argument here to set up
# NPN/ALPN, but to do this we have to mutate the context object, which is OK
# if it's one we created, but not OK if it's one that was passed in... and
# the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use
# a specially configured SSLContext anyway! I also thought maybe we could copy
# the given SSLContext and then mutate the copy, but it's no good as SSLContext
# objects can't be copied: https://bugs.python.org/issue33023.
# So... let's punt on that for now. Hopefully we'll be getting a new Python
# TLS API soon and can revisit this then.
async def open_ssl_over_tcp_stream(
host: str | bytes,
port: int,
*,
https_compatible: bool = False,
ssl_context: ssl.SSLContext | None = None,
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
) -> trio.SSLStream[SocketStream]:
"""Make a TLS-encrypted Connection to the given host and port over TCP.
This is a convenience wrapper that calls :func:`open_tcp_stream` and
wraps the result in an :class:`~trio.SSLStream`.
This function does not perform the TLS handshake; you can do it
manually by calling :meth:`~trio.SSLStream.do_handshake`, or else
it will be performed automatically the first time you send or receive
data.
Args:
host (bytes or str): The host to connect to. We require the server
to have a TLS certificate valid for this hostname.
port (int): The port to connect to.
https_compatible (bool): Set this to True if you're connecting to a web
server. See :class:`~trio.SSLStream` for details. Default:
False.
ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to
use. If None (the default), :func:`ssl.create_default_context`
will be called to create a context.
happy_eyeballs_delay (float): See :func:`open_tcp_stream`.
Returns:
trio.SSLStream: the encrypted connection to the server.
"""
tcp_stream = await trio.open_tcp_stream(
host,
port,
happy_eyeballs_delay=happy_eyeballs_delay,
)
if ssl_context is None:
ssl_context = ssl.create_default_context()
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
return trio.SSLStream(
tcp_stream,
ssl_context,
server_hostname=host,
https_compatible=https_compatible,
)
async def open_ssl_over_tcp_listeners(
port: int,
ssl_context: ssl.SSLContext,
*,
host: str | bytes | None = None,
https_compatible: bool = False,
backlog: int | None = None,
) -> list[trio.SSLListener[SocketStream]]:
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
Args:
port (int): The port to listen on. See :func:`open_tcp_listeners`.
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
connections.
host (str, bytes, or None): The address to bind to; use ``None`` to bind
to the wildcard address. See :func:`open_tcp_listeners`.
https_compatible (bool): See :class:`~trio.SSLStream` for details.
backlog (int or None): See :func:`open_tcp_listeners` for details.
"""
tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
ssl_listeners = [
trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible)
for tcp_listener in tcp_listeners
]
return ssl_listeners
async def serve_ssl_over_tcp(
handler: Callable[[trio.SSLStream[SocketStream]], Awaitable[object]],
port: int,
ssl_context: ssl.SSLContext,
*,
host: str | bytes | None = None,
https_compatible: bool = False,
backlog: int | None = None,
handler_nursery: trio.Nursery | None = None,
task_status: trio.TaskStatus[
list[trio.SSLListener[SocketStream]]
] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
"""Listen for incoming TCP connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around
:func:`open_ssl_over_tcp_listeners` and :func:`serve_listeners` see them
for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
See the documentation for :func:`serve_tcp` for an example where this is
useful.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
port (int): The port to listen on. Use 0 to let the kernel pick
an open port. Ultimately passed to :func:`open_tcp_listeners`.
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
connections. Passed to :func:`open_ssl_over_tcp_listeners`.
host (str, bytes, or None): The address to bind to; use ``None`` to bind
to the wildcard address. Ultimately passed to
:func:`open_tcp_listeners`.
https_compatible (bool): Set this to True if you want to use
"HTTPS-style" TLS. See :class:`~trio.SSLStream` for details.
backlog (int or None): See :class:`~trio.SSLStream` for details.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
"""
listeners = await trio.open_ssl_over_tcp_listeners(
port,
ssl_context,
host=host,
https_compatible=https_compatible,
backlog=backlog,
)
await trio.serve_listeners(
handler,
listeners,
handler_nursery=handler_nursery,
task_status=task_status,
)
+264
View File
@@ -0,0 +1,264 @@
from __future__ import annotations
import os
import pathlib
import sys
from functools import partial, update_wrapper
from inspect import cleandoc
from typing import IO, TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, overload
from trio._file_io import AsyncIOWrapper, wrap_file
from trio._util import final
from trio.to_thread import run_sync
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterable
from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper
from _typeshed import (
OpenBinaryMode,
OpenBinaryModeReading,
OpenBinaryModeUpdating,
OpenBinaryModeWriting,
OpenTextMode,
)
from typing_extensions import Concatenate, Literal, ParamSpec, Self
P = ParamSpec("P")
PathT = TypeVar("PathT", bound="Path")
T = TypeVar("T")
def _wraps_async(
wrapped: Callable[..., Any],
) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]:
def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return await run_sync(partial(fn, *args, **kwargs))
update_wrapper(wrapper, wrapped)
if wrapped.__doc__:
wrapper.__doc__ = (
f"Like :meth:`~{wrapped.__module__}.{wrapped.__qualname__}`, but async.\n"
f"\n"
f"{cleandoc(wrapped.__doc__)}\n"
)
return wrapper
return decorator
def _wrap_method(
fn: Callable[Concatenate[pathlib.Path, P], T],
) -> Callable[Concatenate[Path, P], Awaitable[T]]:
@_wraps_async(fn)
def wrapper(self: Path, /, *args: P.args, **kwargs: P.kwargs) -> T:
return fn(self._wrapped_cls(self), *args, **kwargs)
return wrapper
def _wrap_method_path(
fn: Callable[Concatenate[pathlib.Path, P], pathlib.Path],
) -> Callable[Concatenate[PathT, P], Awaitable[PathT]]:
@_wraps_async(fn)
def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> PathT:
return self.__class__(fn(self._wrapped_cls(self), *args, **kwargs))
return wrapper
def _wrap_method_path_iterable(
fn: Callable[Concatenate[pathlib.Path, P], Iterable[pathlib.Path]],
) -> Callable[Concatenate[PathT, P], Awaitable[Iterable[PathT]]]:
@_wraps_async(fn)
def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> Iterable[PathT]:
return map(self.__class__, [*fn(self._wrapped_cls(self), *args, **kwargs)])
if wrapper.__doc__:
wrapper.__doc__ += (
f"\n"
f"This is an async method that returns a synchronous iterator, so you\n"
f"use it like:\n"
f"\n"
f".. code:: python\n"
f"\n"
f" for subpath in await mypath.{fn.__name__}():\n"
f" ...\n"
f"\n"
f".. note::\n"
f"\n"
f" The iterator is loaded into memory immediately during the initial\n"
f" call (see `issue #501\n"
f" <https://github.com/python-trio/trio/issues/501>`__ for discussion).\n"
)
return wrapper
class Path(pathlib.PurePath):
"""An async :class:`pathlib.Path` that executes blocking methods in :meth:`trio.to_thread.run_sync`.
Instantiating :class:`Path` returns a concrete platform-specific subclass, one of :class:`PosixPath` or
:class:`WindowsPath`.
"""
__slots__ = ()
_wrapped_cls: ClassVar[type[pathlib.Path]]
def __new__(cls, *args: str | os.PathLike[str]) -> Self:
if cls is Path:
cls = WindowsPath if os.name == "nt" else PosixPath # type: ignore[assignment]
return super().__new__(cls, *args)
@classmethod
@_wraps_async(pathlib.Path.cwd)
def cwd(cls) -> Self:
return cls(pathlib.Path.cwd())
@classmethod
@_wraps_async(pathlib.Path.home)
def home(cls) -> Self:
return cls(pathlib.Path.home())
@overload
async def open(
self,
mode: OpenTextMode = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> AsyncIOWrapper[TextIOWrapper]: ...
@overload
async def open(
self,
mode: OpenBinaryMode,
buffering: Literal[0],
encoding: None = None,
errors: None = None,
newline: None = None,
) -> AsyncIOWrapper[FileIO]: ...
@overload
async def open(
self,
mode: OpenBinaryModeUpdating,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
) -> AsyncIOWrapper[BufferedRandom]: ...
@overload
async def open(
self,
mode: OpenBinaryModeWriting,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
) -> AsyncIOWrapper[BufferedWriter]: ...
@overload
async def open(
self,
mode: OpenBinaryModeReading,
buffering: Literal[-1, 1] = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
) -> AsyncIOWrapper[BufferedReader]: ...
@overload
async def open(
self,
mode: OpenBinaryMode,
buffering: int = -1,
encoding: None = None,
errors: None = None,
newline: None = None,
) -> AsyncIOWrapper[BinaryIO]: ...
@overload
async def open( # type: ignore[misc] # Any usage matches builtins.open().
self,
mode: str,
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> AsyncIOWrapper[IO[Any]]: ...
@_wraps_async(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch.
def open(self, *args: Any, **kwargs: Any) -> AsyncIOWrapper[IO[Any]]:
return wrap_file(self._wrapped_cls(self).open(*args, **kwargs))
def __repr__(self) -> str:
return f"trio.Path({str(self)!r})"
stat = _wrap_method(pathlib.Path.stat)
chmod = _wrap_method(pathlib.Path.chmod)
exists = _wrap_method(pathlib.Path.exists)
glob = _wrap_method_path_iterable(pathlib.Path.glob)
rglob = _wrap_method_path_iterable(pathlib.Path.rglob)
is_dir = _wrap_method(pathlib.Path.is_dir)
is_file = _wrap_method(pathlib.Path.is_file)
is_symlink = _wrap_method(pathlib.Path.is_symlink)
is_socket = _wrap_method(pathlib.Path.is_socket)
is_fifo = _wrap_method(pathlib.Path.is_fifo)
is_block_device = _wrap_method(pathlib.Path.is_block_device)
is_char_device = _wrap_method(pathlib.Path.is_char_device)
if sys.version_info >= (3, 12):
is_junction = _wrap_method(pathlib.Path.is_junction)
iterdir = _wrap_method_path_iterable(pathlib.Path.iterdir)
lchmod = _wrap_method(pathlib.Path.lchmod)
lstat = _wrap_method(pathlib.Path.lstat)
mkdir = _wrap_method(pathlib.Path.mkdir)
if sys.platform != "win32":
owner = _wrap_method(pathlib.Path.owner)
group = _wrap_method(pathlib.Path.group)
if sys.platform != "win32" or sys.version_info >= (3, 12):
is_mount = _wrap_method(pathlib.Path.is_mount)
if sys.version_info >= (3, 9):
readlink = _wrap_method_path(pathlib.Path.readlink)
rename = _wrap_method_path(pathlib.Path.rename)
replace = _wrap_method_path(pathlib.Path.replace)
resolve = _wrap_method_path(pathlib.Path.resolve)
rmdir = _wrap_method(pathlib.Path.rmdir)
symlink_to = _wrap_method(pathlib.Path.symlink_to)
if sys.version_info >= (3, 10):
hardlink_to = _wrap_method(pathlib.Path.hardlink_to)
touch = _wrap_method(pathlib.Path.touch)
unlink = _wrap_method(pathlib.Path.unlink)
absolute = _wrap_method_path(pathlib.Path.absolute)
expanduser = _wrap_method_path(pathlib.Path.expanduser)
read_bytes = _wrap_method(pathlib.Path.read_bytes)
read_text = _wrap_method(pathlib.Path.read_text)
samefile = _wrap_method(pathlib.Path.samefile)
write_bytes = _wrap_method(pathlib.Path.write_bytes)
write_text = _wrap_method(pathlib.Path.write_text)
if sys.version_info < (3, 12):
link_to = _wrap_method(pathlib.Path.link_to)
if sys.version_info >= (3, 13):
full_match = _wrap_method(pathlib.Path.full_match)
@final
class PosixPath(Path, pathlib.PurePosixPath):
"""An async :class:`pathlib.PosixPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`."""
__slots__ = ()
_wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.PosixPath
@final
class WindowsPath(Path, pathlib.PureWindowsPath):
"""An async :class:`pathlib.WindowsPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`."""
__slots__ = ()
_wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.WindowsPath
@@ -0,0 +1,92 @@
from __future__ import annotations
import ast
import contextlib
import inspect
import sys
import types
import warnings
from code import InteractiveConsole
import outcome
import trio
import trio.lowlevel
from trio._util import final
@final
class TrioInteractiveConsole(InteractiveConsole):
# code.InteractiveInterpreter defines locals as Mapping[str, Any]
# but when we pass this to FunctionType it expects a dict. So
# we make the type more specific on our subclass
locals: dict[str, object]
def __init__(self, repl_locals: dict[str, object] | None = None):
super().__init__(locals=repl_locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
def runcode(self, code: types.CodeType) -> None:
func = types.FunctionType(code, self.locals)
if inspect.iscoroutinefunction(func):
result = trio.from_thread.run(outcome.acapture, func)
else:
result = trio.from_thread.run_sync(outcome.capture, func)
if isinstance(result, outcome.Error):
# If it is SystemExit, quit the repl. Otherwise, print the traceback.
# If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
# the user trying to quit the repl, but rather an error in the code. So, we
# don't try to inspect groups for SystemExit. Instead, we just print and
# return to the REPL.
if isinstance(result.error, SystemExit):
raise result.error
else:
# Inline our own version of self.showtraceback that can use
# outcome.Error.error directly to print clean tracebacks.
# This also means overriding self.showtraceback does nothing.
sys.last_type, sys.last_value = type(result.error), result.error
sys.last_traceback = result.error.__traceback__
# see https://docs.python.org/3/library/sys.html#sys.last_exc
if sys.version_info >= (3, 12):
sys.last_exc = result.error
# We always use sys.excepthook, unlike other implementations.
# This means that overriding self.write also does nothing to tbs.
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
async def run_repl(console: TrioInteractiveConsole) -> None:
banner = (
f"trio REPL {sys.version} on {sys.platform}\n"
f'Use "await" directly instead of "trio.run()".\n'
f'Type "help", "copyright", "credits" or "license" '
f"for more information.\n"
f'{getattr(sys, "ps1", ">>> ")}import trio'
)
try:
await trio.to_thread.run_sync(console.interact, banner)
finally:
warnings.filterwarnings(
"ignore",
message=r"^coroutine .* was never awaited$",
category=RuntimeWarning,
)
def main(original_locals: dict[str, object]) -> None:
with contextlib.suppress(ImportError):
import readline # noqa: F401
repl_locals: dict[str, object] = {"trio": trio}
for key in {
"__name__",
"__package__",
"__loader__",
"__spec__",
"__builtins__",
"__file__",
}:
repl_locals[key] = original_locals[key]
console = TrioInteractiveConsole(repl_locals)
trio.run(run_repl, console)
@@ -0,0 +1,185 @@
from __future__ import annotations
import signal
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING
import trio
from ._util import ConflictDetector, is_main_thread, signal_raise
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Generator, Iterable
from types import FrameType
from typing_extensions import Self
# Discussion of signal handling strategies:
#
# - On Windows signals barely exist. There are no options; signal handlers are
# the only available API.
#
# - On Linux signalfd is arguably the natural way. Semantics: signalfd acts as
# an *alternative* signal delivery mechanism. The way you use it is to mask
# out the relevant signals process-wide (so that they don't get delivered
# the normal way), and then when you read from signalfd that actually counts
# as delivering it (despite the mask). The problem with this is that we
# don't have any reliable way to mask out signals process-wide -- the only
# way to do that in Python is to call pthread_sigmask from the main thread
# *before starting any other threads*, and as a library we can't really
# impose that, and the failure mode is annoying (signals get delivered via
# signal handlers whether we want them to or not).
#
# - on macOS/*BSD, kqueue is the natural way. Semantics: kqueue acts as an
# *extra* signal delivery mechanism. Signals are delivered the normal
# way, *and* are delivered to kqueue. So you want to set them to SIG_IGN so
# that they don't end up pending forever (I guess?). I can't find any actual
# docs on how masking and EVFILT_SIGNAL interact. I did see someone note
# that if a signal is pending when the kqueue filter is added then you
# *don't* get notified of that, which makes sense. But still, we have to
# manipulate signal state (e.g. setting SIG_IGN) which as far as Python is
# concerned means we have to do this from the main thread.
#
# So in summary, there don't seem to be any compelling advantages to using the
# platform-native signal notification systems; they're kinda nice, but it's
# simpler to implement the naive signal-handler-based system once and be
# done. (The big advantage would be if there were a reliable way to monitor
# for SIGCHLD from outside the main thread and without interfering with other
# libraries that also want to monitor for SIGCHLD. But there isn't. I guess
# kqueue might give us that, but in kqueue we don't need it, because kqueue
# can directly monitor for child process state changes.)
@contextmanager
def _signal_handler(
signals: Iterable[int],
handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None,
) -> Generator[None, None, None]:
original_handlers = {}
try:
for signum in set(signals):
original_handlers[signum] = signal.signal(signum, handler)
yield
finally:
for signum, original_handler in original_handlers.items():
signal.signal(signum, original_handler)
class SignalReceiver:
def __init__(self) -> None:
# {signal num: None}
self._pending: OrderedDict[int, None] = OrderedDict()
self._lot = trio.lowlevel.ParkingLot()
self._conflict_detector = ConflictDetector(
"only one task can iterate on a signal receiver at a time",
)
self._closed = False
def _add(self, signum: int) -> None:
if self._closed:
signal_raise(signum)
else:
self._pending[signum] = None
self._lot.unpark()
def _redeliver_remaining(self) -> None:
# First make sure that any signals still in the delivery pipeline will
# get redelivered
self._closed = True
# And then redeliver any that are sitting in pending. This is done
# using a weird recursive construct to make sure we process everything
# even if some of the handlers raise exceptions.
def deliver_next() -> None:
if self._pending:
signum, _ = self._pending.popitem(last=False)
try:
signal_raise(signum)
finally:
deliver_next()
deliver_next()
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> int:
if self._closed:
raise RuntimeError("open_signal_receiver block already exited")
# In principle it would be possible to support multiple concurrent
# calls to __anext__, but doing it without race conditions is quite
# tricky, and there doesn't seem to be any point in trying.
with self._conflict_detector:
if not self._pending:
await self._lot.park()
else:
await trio.lowlevel.checkpoint()
signum, _ = self._pending.popitem(last=False)
return signum
def get_pending_signal_count(rec: AsyncIterator[int]) -> int:
"""Helper for tests, not public or otherwise used."""
# open_signal_receiver() always produces SignalReceiver, this should not fail.
assert isinstance(rec, SignalReceiver)
return len(rec._pending)
@contextmanager
def open_signal_receiver(
*signals: signal.Signals | int,
) -> Generator[AsyncIterator[int], None, None]:
"""A context manager for catching signals.
Entering this context manager starts listening for the given signals and
returns an async iterator; exiting the context manager stops listening.
The async iterator blocks until a signal arrives, and then yields it.
Note that if you leave the ``with`` block while the iterator has
unextracted signals still pending inside it, then they will be
re-delivered using Python's regular signal handling logic. This avoids a
race condition when signals arrives just before we exit the ``with``
block.
Args:
signals: the signals to listen for.
Raises:
TypeError: if no signals were provided.
RuntimeError: if you try to use this anywhere except Python's main
thread. (This is a Python limitation.)
Example:
A common convention for Unix daemons is that they should reload their
configuration when they receive a ``SIGHUP``. Here's a sketch of what
that might look like using :func:`open_signal_receiver`::
with trio.open_signal_receiver(signal.SIGHUP) as signal_aiter:
async for signum in signal_aiter:
assert signum == signal.SIGHUP
reload_configuration()
"""
if not signals:
raise TypeError("No signals were provided")
if not is_main_thread():
raise RuntimeError(
"Sorry, open_signal_receiver is only possible when running in "
"Python interpreter's main thread",
)
token = trio.lowlevel.current_trio_token()
queue = SignalReceiver()
def handler(signum: int, frame: FrameType | None) -> None:
token.run_sync_soon(queue._add, signum, idempotent=True)
try:
with _signal_handler(signals, handler):
yield queue
finally:
queue._redeliver_remaining()
File diff suppressed because it is too large Load Diff
+951
View File
@@ -0,0 +1,951 @@
from __future__ import annotations
import contextlib
import operator as _operator
import ssl as _stdlib_ssl
from enum import Enum as _Enum
from typing import TYPE_CHECKING, Any, ClassVar, Final as TFinal, Generic, TypeVar
import trio
from . import _sync
from ._highlevel_generic import aclose_forcefully
from ._util import ConflictDetector, final
from .abc import Listener, Stream
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
# General theory of operation:
#
# We implement an API that closely mirrors the stdlib ssl module's blocking
# API, and we do it using the stdlib ssl module's non-blocking in-memory API.
# The stdlib non-blocking in-memory API is barely documented, and acts as a
# thin wrapper around openssl, whose documentation also leaves something to be
# desired. So here's the main things you need to know to understand the code
# in this file:
#
# We use an ssl.SSLObject, which exposes the four main I/O operations:
#
# - do_handshake: performs the initial handshake. Must be called once at the
# beginning of each connection; is a no-op once it's completed once.
#
# - write: takes some unencrypted data and attempts to send it to the remote
# peer.
# - read: attempts to decrypt and return some data from the remote peer.
#
# - unwrap: this is weirdly named; maybe it helps to realize that the thing it
# wraps is called SSL_shutdown. It sends a cryptographically signed message
# saying "I'm closing this connection now", and then waits to receive the
# same from the remote peer (unless we already received one, in which case
# it returns immediately).
#
# All of these operations read and write from some in-memory buffers called
# "BIOs", which are an opaque OpenSSL-specific object that's basically
# semantically equivalent to a Python bytearray. When they want to send some
# bytes to the remote peer, they append them to the outgoing BIO, and when
# they want to receive some bytes from the remote peer, they try to pull them
# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO
# can always be extended to hold more data. "Receiving" acts sort of like a
# non-blocking socket: it might manage to get some data immediately, or it
# might fail and need to be tried again later. We can also directly add or
# remove data from the BIOs whenever we want.
#
# Now the problem is that while these I/O operations are opaque atomic
# operations from the point of view of us calling them, under the hood they
# might require some arbitrary sequence of sends and receives from the remote
# peer. This is particularly true for do_handshake, which generally requires a
# few round trips, but it's also true for write and read, due to an evil thing
# called "renegotiation".
#
# Renegotiation is the process by which one of the peers might arbitrarily
# decide to redo the handshake at any time. Did I mention it's evil? It's
# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use
# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the
# protocol entirely. It's impossible to trigger a renegotiation if using
# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1].
# Nonetheless, it does get used in real life, mostly in two cases:
#
# 1) Normally in TLS 1.2 and below, when the client side of a connection wants
# to present a certificate to prove their identity, that certificate gets sent
# in plaintext. This is bad, because it means that anyone eavesdropping can
# see who's connecting it's like sending your username in plain text. Not as
# bad as sending your password in plain text, but still, pretty bad. However,
# renegotiations *are* encrypted. So as a workaround, it's not uncommon for
# systems that want to use client certificates to first do an anonymous
# handshake, and then to turn around and do a second handshake (=
# renegotiation) and this time ask for a client cert. Or sometimes this is
# done on a case-by-case basis, e.g. a web server might accept a connection,
# read the request, and then once it sees the page you're asking for it might
# stop and ask you for a certificate.
#
# 2) In principle the same TLS connection can be used for an arbitrarily long
# time, and might transmit arbitrarily large amounts of data. But this creates
# a cryptographic problem: an attacker who has access to arbitrarily large
# amounts of data that's all encrypted using the same key may eventually be
# able to use this to figure out the key. Is this a real practical problem? I
# have no idea, I'm not a cryptographer. In any case, some people worry that
# it's a problem, so their TLS libraries are designed to automatically trigger
# a renegotiation every once in a while on some sort of timer.
#
# The end result is that you might be going along, minding your own business,
# and then *bam*! a wild renegotiation appears! And you just have to cope.
#
# The reason that coping with renegotiations is difficult is that some
# unassuming "read" or "write" call might find itself unable to progress until
# it does a handshake, which remember is a process with multiple round
# trips. So read might have to send data, and write might have to receive
# data, and this might happen multiple times. And some of those attempts might
# fail because there isn't any data yet, and need to be retried. Managing all
# this is pretty complicated.
#
# Here's how openssl (and thus the stdlib ssl module) handle this. All of the
# I/O operations above follow the same rules. When you call one of them:
#
# - it might write some data to the outgoing BIO
# - it might read some data from the incoming BIO
# - it might raise SSLWantReadError if it can't complete without reading more
# data from the incoming BIO. This is important: the "read" in ReadError
# refers to reading from the *underlying* stream.
# - (and in principle it might raise SSLWantWriteError too, but that never
# happens when using memory BIOs, so never mind)
#
# If it doesn't raise an error, then the operation completed successfully
# (though we still need to take any outgoing data out of the memory buffer and
# put it onto the wire). If it *does* raise an error, then we need to retry
# *exactly that method call* later in particular, if a 'write' failed, we
# need to try again later *with the same data*, because openssl might have
# already committed some of the initial parts of our data to its output even
# though it didn't tell us that, and has remembered that the next time we call
# write it needs to skip the first 1024 bytes or whatever it is. (Well,
# technically, we're actually allowed to call 'write' again with a data buffer
# which is the same as our old one PLUS some extra stuff added onto the end,
# but in Trio that never comes up so never mind.)
#
# There are some people online who claim that once you've gotten a Want*Error
# then the *very next call* you make to openssl *must* be the same as the
# previous one. I'm pretty sure those people are wrong. In particular, it's
# okay to call write, get a WantReadError, and then call read a few times;
# it's just that *the next time you call write*, it has to be with the same
# data.
#
# One final wrinkle: we want our SSLStream to support full-duplex operation,
# i.e. it should be possible for one task to be calling send_all while another
# task is calling receive_some. But renegotiation makes this a big hassle, because
# even if SSLStream's restricts themselves to one task calling send_all and one
# task calling receive_some, those two tasks might end up both wanting to call
# send_all, or both to call receive_some at the same time *on the underlying
# stream*. So we have to do some careful locking to hide this problem from our
# users.
#
# (Renegotiation is evil.)
#
# So our basic strategy is to define a single helper method called "_retry",
# which has generic logic for dealing with SSLWantReadError, pushing data from
# the outgoing BIO to the wire, reading data from the wire to the incoming
# BIO, retrying an I/O call until it works, and synchronizing with other tasks
# that might be calling _retry concurrently. Basically it takes an SSLObject
# non-blocking in-memory method and converts it into a Trio async blocking
# method. _retry is only about 30 lines of code, but all these cases
# multiplied by concurrent calls make it extremely tricky, so there are lots
# of comments down below on the details, and a really extensive test suite in
# test_ssl.py. And now you know *why* it's so tricky, and can probably
# understand how it works.
#
# [1] https://rt.openssl.org/Ticket/Display.html?id=3712
# XX how closely should we match the stdlib API?
# - maybe suppress_ragged_eofs=False is a better default?
# - maybe check crypto folks for advice?
# - this is also interesting: https://bugs.python.org/issue8108#msg102867
# Definitely keep an eye on Cory's TLS API ideas on security-sig etc.
# XX document behavior on cancellation/error (i.e.: all is lost abandon
# stream)
# docs will need to make very clear that this is different from all the other
# cancellations in core Trio
T = TypeVar("T")
################################################################
# SSLStream
################################################################
# Ideally, when the user calls SSLStream.receive_some() with no argument, then
# we should do exactly one call to self.transport_stream.receive_some(),
# decrypt everything we got, and return it. Unfortunately, the way openssl's
# API works, we have to pick how much data we want to allow when we call
# read(), and then it (potentially) triggers a call to
# transport_stream.receive_some(). So at the time we pick the amount of data
# to decrypt, we don't know how much data we've read. As a simple heuristic,
# we record the max amount of data returned by previous calls to
# transport_stream.receive_some(), and we use that for future calls to read().
# But what do we use for the very first call? That's what this constant sets.
#
# Note that the value passed to read() is a limit on the amount of
# *decrypted* data, but we can only see the size of the *encrypted* data
# returned by transport_stream.receive_some(). TLS adds a small amount of
# framing overhead, and TLS compression is rarely used these days because it's
# insecure. So the size of the encrypted data should be a slight over-estimate
# of the size of the decrypted data, which is exactly what we want.
#
# The specific value is not really based on anything; it might be worth tuning
# at some point. But, if you have an TCP connection with the typical 1500 byte
# MTU and an initial window of 10 (see RFC 6928), then the initial burst of
# data will be limited to ~15000 bytes (or a bit less due to IP-level framing
# overhead), so this is chosen to be larger than that.
STARTING_RECEIVE_SIZE: TFinal = 16384
def _is_eof(exc: BaseException | None) -> bool:
# There appears to be a bug on Python 3.10, where SSLErrors
# aren't properly translated into SSLEOFErrors.
# This stringly-typed error check is borrowed from the AnyIO
# project.
return isinstance(exc, _stdlib_ssl.SSLEOFError) or (
"UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ())
)
class NeedHandshakeError(Exception):
"""Some :class:`SSLStream` methods can't return any meaningful data until
after the handshake. If you call them before the handshake, they raise
this error.
"""
class _Once:
def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None:
self._afn = afn
self._args = args
self.started = False
self._done = _sync.Event()
async def ensure(self, *, checkpoint: bool) -> None:
if not self.started:
self.started = True
await self._afn(*self._args)
self._done.set()
elif not checkpoint and self._done.is_set():
return
else:
await self._done.wait()
@property
def done(self) -> bool:
return bool(self._done.is_set())
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
# invariant
T_Stream = TypeVar("T_Stream", bound=Stream)
@final
class SSLStream(Stream, Generic[T_Stream]):
r"""Encrypted communication using SSL/TLS.
:class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and
allows you to perform encrypted communication over it using the usual
:class:`~trio.abc.Stream` interface. You pass regular data to
:meth:`send_all`, then it encrypts it and sends the encrypted data on the
underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted
data out of the underlying :class:`~trio.abc.Stream` and decrypts it
before returning it.
You should read the standard library's :mod:`ssl` documentation carefully
before attempting to use this class, and probably other general
documentation on SSL/TLS as well. SSL/TLS is subtle and quick to
anger. Really. I'm not kidding.
Args:
transport_stream (~trio.abc.Stream): The stream used to transport
encrypted data. Required.
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for
this connection. Required. Usually created by calling
:func:`ssl.create_default_context`.
server_hostname (str, bytes, or None): The name of the server being
connected to. Used for `SNI
<https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for
validating the server's certificate (if hostname checking is
enabled). This is effectively mandatory for clients, and actually
mandatory if ``ssl_context.check_hostname`` is ``True``.
server_side (bool): Whether this stream is acting as a client or
server. Defaults to False, i.e. client mode.
https_compatible (bool): There are two versions of SSL/TLS commonly
encountered in the wild: the standard version, and the version used
for HTTPS (HTTP-over-SSL/TLS).
Standard-compliant SSL/TLS implementations always send a
cryptographically signed ``close_notify`` message before closing the
connection. This is important because if the underlying transport
were simply closed, then there wouldn't be any way for the other
side to know whether the connection was intentionally closed by the
peer that they negotiated a cryptographic connection to, or by some
`man-in-the-middle
<https://en.wikipedia.org/wiki/Man-in-the-middle_attack>`__ attacker
who can't manipulate the cryptographic stream, but can manipulate
the transport layer (a so-called "truncation attack").
However, this part of the standard is widely ignored by real-world
HTTPS implementations, which means that if you want to interoperate
with them, then you NEED to ignore it too.
Fortunately this isn't as bad as it sounds, because the HTTP
protocol already includes its own equivalent of ``close_notify``, so
doing this again at the SSL/TLS level is redundant. But not all
protocols do! Therefore, by default Trio implements the safer
standard-compliant version (``https_compatible=False``). But if
you're speaking HTTPS or some other protocol where
``close_notify``\s are commonly skipped, then you should set
``https_compatible=True``; with this setting, Trio will neither
expect nor send ``close_notify`` messages.
If you have code that was written to use :class:`ssl.SSLSocket` and
now you're porting it to Trio, then it may be useful to know that a
difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is
that :class:`~ssl.SSLSocket` implements the
``https_compatible=True`` behavior by default.
Attributes:
transport_stream (trio.abc.Stream): The underlying transport stream
that was passed to ``__init__``. An example of when this would be
useful is if you're using :class:`SSLStream` over a
:class:`~trio.SocketStream` and want to call the
:class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt`
method.
Internally, this class is implemented using an instance of
:class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and
attributes are re-exported as methods and attributes on this class.
However, there is one difference: :class:`~ssl.SSLObject` has several
methods that return information about the encrypted connection, like
:meth:`~ssl.SSLSocket.cipher` or
:meth:`~ssl.SSLSocket.selected_alpn_protocol`. If you call them before the
handshake, when they can't possibly return useful data, then
:class:`ssl.SSLObject` returns None, but :class:`trio.SSLStream`
raises :exc:`NeedHandshakeError`.
This also means that if you register a SNI callback using
`~ssl.SSLContext.sni_callback`, then the first argument your callback
receives will be a :class:`ssl.SSLObject`.
"""
# Note: any new arguments here should likely also be added to
# SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers.
def __init__(
self,
transport_stream: T_Stream,
ssl_context: _stdlib_ssl.SSLContext,
*,
server_hostname: str | bytes | None = None,
server_side: bool = False,
https_compatible: bool = False,
) -> None:
self.transport_stream: T_Stream = transport_stream
self._state = _State.OK
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
self._delayed_outgoing: bytes | None = None
self._incoming = _stdlib_ssl.MemoryBIO()
self._ssl_object = ssl_context.wrap_bio(
self._incoming,
self._outgoing,
server_side=server_side,
server_hostname=server_hostname,
)
# Tracks whether we've already done the initial handshake
self._handshook = _Once(self._do_handshake)
# These are used to synchronize access to self.transport_stream
self._inner_send_lock = _sync.StrictFIFOLock()
self._inner_recv_count = 0
self._inner_recv_lock = _sync.Lock()
# These are used to make sure that our caller doesn't attempt to make
# multiple concurrent calls to send_all/wait_send_all_might_not_block
# or to receive_some.
self._outer_send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SSLStream",
)
self._outer_recv_conflict_detector = ConflictDetector(
"another task is currently receiving data on this SSLStream",
)
self._estimated_receive_size = STARTING_RECEIVE_SIZE
_forwarded: ClassVar = {
"context",
"server_side",
"server_hostname",
"session",
"session_reused",
"getpeercert",
"selected_npn_protocol",
"cipher",
"shared_ciphers",
"compression",
"pending",
"get_channel_binding",
"selected_alpn_protocol",
"version",
}
_after_handshake: ClassVar = {
"session_reused",
"getpeercert",
"selected_npn_protocol",
"cipher",
"shared_ciphers",
"compression",
"get_channel_binding",
"selected_alpn_protocol",
"version",
}
def __getattr__(self, name: str) -> Any:
if name in self._forwarded:
if name in self._after_handshake and not self._handshook.done:
raise NeedHandshakeError(f"call do_handshake() before calling {name!r}")
return getattr(self._ssl_object, name)
else:
raise AttributeError(name)
def __setattr__(self, name: str, value: object) -> None:
if name in self._forwarded:
setattr(self._ssl_object, name, value)
else:
super().__setattr__(name, value)
def __dir__(self) -> list[str]:
return list(super().__dir__()) + list(self._forwarded)
def _check_status(self) -> None:
if self._state is _State.OK:
return
elif self._state is _State.BROKEN:
raise trio.BrokenResourceError
elif self._state is _State.CLOSED:
raise trio.ClosedResourceError
else: # pragma: no cover
raise AssertionError()
# This is probably the single trickiest function in Trio. It has lots of
# comments, though, just make sure to think carefully if you ever have to
# touch it. The big comment at the top of this file will help explain
# too.
async def _retry(
self,
fn: Callable[..., T],
*args: object,
ignore_want_read: bool = False,
is_handshake: bool = False,
) -> T | None:
await trio.lowlevel.checkpoint_if_cancelled()
yielded = False
finished = False
while not finished:
# WARNING: this code needs to be very careful with when it
# calls 'await'! There might be multiple tasks calling this
# function at the same time trying to do different operations,
# so we need to be careful to:
#
# 1) interact with the SSLObject, then
# 2) await on exactly one thing that lets us make forward
# progress, then
# 3) loop or exit
#
# In particular we don't want to yield while interacting with
# the SSLObject (because it's shared state, so someone else
# might come in and mess with it while we're suspended), and
# we don't want to yield *before* starting the operation that
# will help us make progress, because then someone else might
# come in and leapfrog us.
# Call the SSLObject method, and get its result.
#
# NB: despite what the docs say, SSLWantWriteError can't
# happen "Writes to memory BIOs will always succeed if
# memory is available: that is their size can grow
# indefinitely."
# https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3)
want_read = False
ret = None
try:
ret = fn(*args)
except _stdlib_ssl.SSLWantReadError:
want_read = True
except (_stdlib_ssl.SSLError, _stdlib_ssl.CertificateError) as exc:
self._state = _State.BROKEN
raise trio.BrokenResourceError from exc
else:
finished = True
if ignore_want_read:
want_read = False
finished = True
to_send = self._outgoing.read()
# Some versions of SSL_do_handshake have a bug in how they handle
# the TLS 1.3 handshake on the server side: after the handshake
# finishes, they automatically send session tickets, even though
# the client may not be expecting data to arrive at this point and
# sending it could cause a deadlock or lost data. This applies at
# least to OpenSSL 1.1.1c and earlier, and the OpenSSL devs
# currently have no plans to fix it:
#
# https://github.com/openssl/openssl/issues/7948
# https://github.com/openssl/openssl/issues/7967
#
# The correct behavior is to wait to send session tickets on the
# first call to SSL_write. (This is what BoringSSL does.) So, we
# use a heuristic to detect when OpenSSL has tried to send session
# tickets, and we manually delay sending them until the
# appropriate moment. For more discussion see:
#
# https://github.com/python-trio/trio/issues/819#issuecomment-517529763
if (
is_handshake
and not want_read
and self._ssl_object.server_side
and self._ssl_object.version() == "TLSv1.3"
):
assert self._delayed_outgoing is None
self._delayed_outgoing = to_send
to_send = b""
# Outputs from the above code block are:
#
# - to_send: bytestring; if non-empty then we need to send
# this data to make forward progress
#
# - want_read: True if we need to receive_some some data to make
# forward progress
#
# - finished: False means that we need to retry the call to
# fn(*args) again, after having pushed things forward. True
# means we still need to do whatever was said (in particular
# send any data in to_send), but once we do then we're
# done.
#
# - ret: the operation's return value. (Meaningless unless
# finished is True.)
#
# Invariant: want_read and finished can't both be True at the
# same time.
#
# Now we need to move things forward. There are two things we
# might have to do, and any given operation might require
# either, both, or neither to proceed:
#
# - send the data in to_send
#
# - receive_some some data and put it into the incoming BIO
#
# Our strategy is: if there's data to send, send it;
# *otherwise* if there's data to receive_some, receive_some it.
#
# If both need to happen, then we only send. Why? Well, we
# know that *right now* we have to both send and receive_some
# before the operation can complete. But as soon as we yield,
# that information becomes potentially stale e.g. while
# we're sending, some other task might go and receive_some the
# data we need and put it into the incoming BIO. And if it
# does, then we *definitely don't* want to do a receive_some
# there might not be any more data coming, and we'd deadlock!
# We could do something tricky to keep track of whether a
# receive_some happens while we're sending, but the case where
# we have to do both is very unusual (only during a
# renegotiation), so it's better to keep things simple. So we
# do just one potentially-blocking operation, then check again
# for fresh information.
#
# And we prioritize sending over receiving because, if there
# are multiple tasks that want to receive_some, then it
# doesn't matter what order they go in. But if there are
# multiple tasks that want to send, then they each have
# different data, and the data needs to get put onto the wire
# in the same order that it was retrieved from the outgoing
# BIO. So if we have data to send, that *needs* to be the
# *very* *next* *thing* we do, to make sure no-one else sneaks
# in before us. Or if we can't send immediately because
# someone else is, then we at least need to get in line
# immediately.
if to_send:
# NOTE: This relies on the lock being strict FIFO fair!
async with self._inner_send_lock:
yielded = True
try:
if self._delayed_outgoing is not None:
to_send = self._delayed_outgoing + to_send
self._delayed_outgoing = None
await self.transport_stream.send_all(to_send)
except:
# Some unknown amount of our data got sent, and we
# don't know how much. This stream is doomed.
self._state = _State.BROKEN
raise
elif want_read:
# It's possible that someone else is already blocked in
# transport_stream.receive_some. If so then we want to
# wait for them to finish, but we don't want to call
# transport_stream.receive_some again ourselves; we just
# want to loop around and check if their contribution
# helped anything. So we make a note of how many times
# some task has been through here before taking the lock,
# and if it's changed by the time we get the lock, then we
# skip calling transport_stream.receive_some and loop
# around immediately.
recv_count = self._inner_recv_count
async with self._inner_recv_lock:
yielded = True
if recv_count == self._inner_recv_count:
data = await self.transport_stream.receive_some()
if not data:
self._incoming.write_eof()
else:
self._estimated_receive_size = max(
self._estimated_receive_size,
len(data),
)
self._incoming.write(data)
self._inner_recv_count += 1
if not yielded:
await trio.lowlevel.cancel_shielded_checkpoint()
return ret
async def _do_handshake(self) -> None:
try:
await self._retry(self._ssl_object.do_handshake, is_handshake=True)
except:
self._state = _State.BROKEN
raise
async def do_handshake(self) -> None:
"""Ensure that the initial handshake has completed.
The SSL protocol requires an initial handshake to exchange
certificates, select cryptographic keys, and so forth, before any
actual data can be sent or received. You don't have to call this
method; if you don't, then :class:`SSLStream` will automatically
perform the handshake as needed, the first time you try to send or
receive data. But if you want to trigger it manually for example,
because you want to look at the peer's certificate before you start
talking to them then you can call this method.
If the initial handshake is already in progress in another task, this
waits for it to complete and then returns.
If the initial handshake has already completed, this returns
immediately without doing anything (except executing a checkpoint).
.. warning:: If this method is cancelled, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
future attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
self._check_status()
await self._handshook.ensure(checkpoint=True)
# Most things work if we don't explicitly force do_handshake to be called
# before calling receive_some or send_all, because openssl will
# automatically perform the handshake on the first SSL_{read,write}
# call. BUT, allowing openssl to do this will disable Python's hostname
# checking!!! See:
# https://bugs.python.org/issue30141
# So we *definitely* have to make sure that do_handshake is called
# before doing anything else.
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""Read some data from the underlying transport, decrypt it, and
return it.
See :meth:`trio.abc.ReceiveStream.receive_some` for details.
.. warning:: If this method is cancelled while the initial handshake
or a renegotiation are in progress, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
future attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
with self._outer_recv_conflict_detector:
self._check_status()
try:
await self._handshook.ensure(checkpoint=False)
except trio.BrokenResourceError as exc:
# For some reason, EOF before handshake sometimes raises
# SSLSyscallError instead of SSLEOFError (e.g. on my linux
# laptop, but not on appveyor). Thanks openssl.
if self._https_compatible and (
isinstance(exc.__cause__, _stdlib_ssl.SSLSyscallError)
or _is_eof(exc.__cause__)
):
await trio.lowlevel.checkpoint()
return b""
else:
raise
if max_bytes is None:
# If we somehow have more data already in our pending buffer
# than the estimate receive size, bump up our size a bit for
# this read only.
max_bytes = max(self._estimated_receive_size, self._incoming.pending)
else:
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
try:
received = await self._retry(self._ssl_object.read, max_bytes)
assert received is not None
return received
except trio.BrokenResourceError as exc:
# This isn't quite equivalent to just returning b"" in the
# first place, because we still end up with self._state set to
# BROKEN. But that's actually fine, because after getting an
# EOF on TLS then the only thing you can do is close the
# stream, and closing doesn't care about the state.
if self._https_compatible and _is_eof(exc.__cause__):
await trio.lowlevel.checkpoint()
return b""
else:
raise
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Encrypt some data and then send it on the underlying transport.
See :meth:`trio.abc.SendStream.send_all` for details.
.. warning:: If this method is cancelled, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
with self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
# SSLObject interprets write(b"") as an EOF for some reason, which
# is not what we want.
if not data:
await trio.lowlevel.checkpoint()
return
await self._retry(self._ssl_object.write, data)
async def unwrap(self) -> tuple[Stream, bytes | bytearray]:
"""Cleanly close down the SSL/TLS encryption layer, allowing the
underlying stream to be used for unencrypted communication.
You almost certainly don't need this.
Returns:
A pair ``(transport_stream, trailing_bytes)``, where
``transport_stream`` is the underlying transport stream, and
``trailing_bytes`` is a byte string. Since :class:`SSLStream`
doesn't necessarily know where the end of the encrypted data will
be, it can happen that it accidentally reads too much from the
underlying stream. ``trailing_bytes`` contains this extra data; you
should process it as if it was returned from a call to
``transport_stream.receive_some(...)``.
"""
with self._outer_recv_conflict_detector, self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
await self._retry(self._ssl_object.unwrap)
transport_stream = self.transport_stream
self._state = _State.CLOSED
self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use
return (transport_stream, self._incoming.read())
async def aclose(self) -> None:
"""Gracefully shut down this connection, and close the underlying
transport.
If ``https_compatible`` is False (the default), then this attempts to
first send a ``close_notify`` and then close the underlying stream by
calling its :meth:`~trio.abc.AsyncResource.aclose` method.
If ``https_compatible`` is set to True, then this simply closes the
underlying stream and marks this stream as closed.
"""
if self._state is _State.CLOSED:
await trio.lowlevel.checkpoint()
return
if self._state is _State.BROKEN or self._https_compatible:
self._state = _State.CLOSED
await self.transport_stream.aclose()
return
try:
# https_compatible=False, so we're in spec-compliant mode and have
# to send close_notify so that the other side gets a cryptographic
# assurance that we've called aclose. Of course, we can't do
# anything cryptographic until after we've completed the
# handshake:
await self._handshook.ensure(checkpoint=False)
# Then, we call SSL_shutdown *once*, because we want to send a
# close_notify but *not* wait for the other side to send back a
# response. In principle it would be more polite to wait for the
# other side to reply with their own close_notify. However, if
# they aren't paying attention (e.g., if they're just sending
# data and not receiving) then we will never notice our
# close_notify and we'll be waiting forever. Eventually we'll time
# out (hopefully), but it's still kind of nasty. And we can't
# require the other side to always be receiving, because (a)
# backpressure is kind of important, and (b) I bet there are
# broken TLS implementations out there that don't receive all the
# time. (Like e.g. anyone using Python ssl in synchronous mode.)
#
# The send-then-immediately-close behavior is explicitly allowed
# by the TLS specs, so we're ok on that.
#
# Subtlety: SSLObject.unwrap will immediately call it a second
# time, and the second time will raise SSLWantReadError because
# there hasn't been time for the other side to respond
# yet. (Unless they spontaneously sent a close_notify before we
# called this, and it's either already been processed or gets
# pulled out of the buffer by Python's second call.) So the way to
# do what we want is to ignore SSLWantReadError on this call.
#
# Also, because the other side might have already sent
# close_notify and closed their connection then it's possible that
# our attempt to send close_notify will raise
# BrokenResourceError. This is totally legal, and in fact can happen
# with two well-behaved Trio programs talking to each other, so we
# don't want to raise an error. So we suppress BrokenResourceError
# here. (This is safe, because literally the only thing this call
# to _retry will do is send the close_notify alert, so that's
# surely where the error comes from.)
#
# FYI in some cases this could also raise SSLSyscallError which I
# think is because SSL_shutdown is terrible. (Check out that note
# at the bottom of the man page saying that it sometimes gets
# raised spuriously.) I haven't seen this since we switched to
# immediately closing the socket, and I don't know exactly what
# conditions cause it and how to respond, so for now we're just
# letting that happen. But if you start seeing it, then hopefully
# this will give you a little head start on tracking it down,
# because whoa did this puzzle us at the 2017 PyCon sprints.
#
# Also, if someone else is blocked in send/receive, then we aren't
# going to be able to do a clean shutdown. If that happens, we'll
# just do an unclean shutdown.
with contextlib.suppress(trio.BrokenResourceError, trio.BusyResourceError):
await self._retry(self._ssl_object.unwrap, ignore_want_read=True)
except:
# Failure! Kill the stream and move on.
await aclose_forcefully(self.transport_stream)
raise
else:
# Success! Gracefully close the underlying stream.
await self.transport_stream.aclose()
finally:
self._state = _State.CLOSED
async def wait_send_all_might_not_block(self) -> None:
"""See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`."""
# This method's implementation is deceptively simple.
#
# First, we take the outer send lock, because of Trio's standard
# semantics that wait_send_all_might_not_block and send_all
# conflict.
with self._outer_send_conflict_detector:
self._check_status()
# Then we take the inner send lock. We know that no other tasks
# are calling self.send_all or self.wait_send_all_might_not_block,
# because we have the outer_send_lock. But! There might be another
# task calling self.receive_some -> transport_stream.send_all, in
# which case if we were to call
# transport_stream.wait_send_all_might_not_block directly we'd
# have two tasks doing write-related operations on
# transport_stream simultaneously, which is not allowed. We
# *don't* want to raise this conflict to our caller, because it's
# purely an internal affair all they did was call
# wait_send_all_might_not_block and receive_some at the same time,
# which is totally valid. And waiting for the lock is OK, because
# a call to send_all certainly wouldn't complete while the other
# task holds the lock.
async with self._inner_send_lock:
# Now we have the lock, which creates another potential
# problem: what if a call to self.receive_some attempts to do
# transport_stream.send_all now? It'll have to wait for us to
# finish! But that's OK, because we release the lock as soon
# as the underlying stream becomes writable, and the
# self.receive_some call wasn't going to make any progress
# until then anyway.
#
# Of course, this does mean we might return *before* the
# stream is logically writable, because immediately after we
# return self.receive_some might write some data and make it
# non-writable again. But that's OK too,
# wait_send_all_might_not_block only guarantees that it
# doesn't return late.
await self.transport_stream.wait_send_all_might_not_block()
# this is necessary for Sphinx, see also `_abc.py`
SSLStream.__module__ = SSLStream.__module__.replace("._ssl", "")
@final
class SSLListener(Listener[SSLStream[T_Stream]]):
"""A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers.
:class:`SSLListener` wraps around another Listener, and converts
all incoming connections to encrypted connections by wrapping them
in a :class:`SSLStream`.
Args:
transport_listener (~trio.abc.Listener): The listener whose incoming
connections will be wrapped in :class:`SSLStream`.
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` that will be
used for incoming connections.
https_compatible (bool): Passed on to :class:`SSLStream`.
Attributes:
transport_listener (trio.abc.Listener): The underlying listener that was
passed to ``__init__``.
"""
def __init__(
self,
transport_listener: Listener[T_Stream],
ssl_context: _stdlib_ssl.SSLContext,
*,
https_compatible: bool = False,
) -> None:
self.transport_listener = transport_listener
self._ssl_context = ssl_context
self._https_compatible = https_compatible
async def accept(self) -> SSLStream[T_Stream]:
"""Accept the next connection and wrap it in an :class:`SSLStream`.
See :meth:`trio.abc.Listener.accept` for details.
"""
transport_stream = await self.transport_listener.accept()
return SSLStream(
transport_stream,
self._ssl_context,
server_side=True,
https_compatible=self._https_compatible,
)
async def aclose(self) -> None:
"""Close the transport listener."""
await self.transport_listener.aclose()
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,123 @@
# Platform-specific subprocess bits'n'pieces.
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING
import trio
from .. import _core, _subprocess
from .._abc import ReceiveStream, SendStream # noqa: TCH001
_wait_child_exiting_error: ImportError | None = None
_create_child_pipe_error: ImportError | None = None
if TYPE_CHECKING:
# internal types for the pipe representations used in type checking only
class ClosableSendStream(SendStream):
def close(self) -> None: ...
class ClosableReceiveStream(ReceiveStream):
def close(self) -> None: ...
# Fallback versions of the functions provided -- implementations
# per OS are imported atop these at the bottom of the module.
async def wait_child_exiting(process: _subprocess.Process) -> None:
"""Block until the child process managed by ``process`` is exiting.
It is invalid to call this function if the process has already
been waited on; that is, ``process.returncode`` must be None.
When this function returns, it indicates that a call to
:meth:`subprocess.Popen.wait` will immediately be able to
return the process's exit status. The actual exit status is not
consumed by this call, since :class:`~subprocess.Popen` wants
to be able to do that itself.
"""
raise NotImplementedError from _wait_child_exiting_error # pragma: no cover
def create_pipe_to_child_stdin() -> tuple[ClosableSendStream, int]:
"""Create a new pipe suitable for sending data from this
process to the standard input of a child we're about to spawn.
Returns:
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
:class:`~trio.abc.SendStream` and ``subprocess_end`` is
something suitable for passing as the ``stdin`` argument of
:class:`subprocess.Popen`.
"""
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]:
"""Create a new pipe suitable for receiving data into this
process from the standard output or error stream of a child
we're about to spawn.
Returns:
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
:class:`~trio.abc.ReceiveStream` and ``subprocess_end`` is
something suitable for passing as the ``stdin`` argument of
:class:`subprocess.Popen`.
"""
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
try:
if sys.platform == "win32":
from .windows import wait_child_exiting
elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")):
from .kqueue import wait_child_exiting
else:
# as it's an exported symbol, noqa'd
from .waitid import wait_child_exiting # noqa: F401
except ImportError as ex: # pragma: no cover
_wait_child_exiting_error = ex
try:
if TYPE_CHECKING:
# Not worth type checking these definitions
pass
elif os.name == "posix":
def create_pipe_to_child_stdin():
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(wfd), rfd
def create_pipe_from_child_output():
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(rfd), wfd
elif os.name == "nt":
import msvcrt
# This isn't exported or documented, but it's also not
# underscore-prefixed, and seems kosher to use. The asyncio docs
# for 3.5 included an example that imported socketpair from
# windows_utils (before socket.socketpair existed on Windows), and
# when asyncio.windows_utils.socketpair was removed in 3.7, the
# removal was mentioned in the release notes.
from asyncio.windows_utils import pipe as windows_pipe
from .._windows_pipes import PipeReceiveStream, PipeSendStream
def create_pipe_to_child_stdin():
# for stdin, we want the write end (our end) to use overlapped I/O
rh, wh = windows_pipe(overlapped=(False, True))
return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY)
def create_pipe_from_child_output():
# for stdout/err, it's the read end that's overlapped
rh, wh = windows_pipe(overlapped=(True, False))
return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0)
else: # pragma: no cover
raise ImportError("pipes not implemented on this platform")
except ImportError as ex: # pragma: no cover
_create_child_pipe_error = ex
@@ -0,0 +1,48 @@
from __future__ import annotations
import select
import sys
from typing import TYPE_CHECKING
from .. import _core, _subprocess
assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING
async def wait_child_exiting(process: _subprocess.Process) -> None:
kqueue = _core.current_kqueue()
try:
from select import KQ_NOTE_EXIT
except ImportError: # pragma: no cover
# pypy doesn't define KQ_NOTE_EXIT:
# https://bitbucket.org/pypy/pypy/issues/2921/
# I verified this value against both Darwin and FreeBSD
KQ_NOTE_EXIT = 0x80000000
def make_event(flags: int) -> select.kevent:
return select.kevent(
process.pid,
filter=select.KQ_FILTER_PROC,
flags=flags,
fflags=KQ_NOTE_EXIT,
)
try:
kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0)
except ProcessLookupError: # pragma: no cover
# This can supposedly happen if the process is in the process
# of exiting, and it can even be the case that kqueue says the
# process doesn't exist before waitpid(WNOHANG) says it hasn't
# exited yet. See the discussion in https://chromium.googlesource.com/
# chromium/src/base/+/master/process/kill_mac.cc .
# We haven't actually seen this error occur since we added
# locking to prevent multiple calls to wait_child_exiting()
# for the same process simultaneously, but given the explanation
# in Chromium it seems we should still keep the check.
return
def abort(_: _core.RaiseCancelT) -> _core.Abort:
kqueue.control([make_event(select.KQ_EV_DELETE)], 0)
return _core.Abort.SUCCEEDED
await _core.wait_kevent(process.pid, select.KQ_FILTER_PROC, abort)
@@ -0,0 +1,113 @@
import errno
import math
import os
import sys
from typing import TYPE_CHECKING
from .. import _core, _subprocess
from .._sync import CapacityLimiter, Event
from .._threads import to_thread_run_sync
assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING
try:
from os import waitid
def sync_wait_reapable(pid: int) -> None:
waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT)
except ImportError:
# pypy doesn't define os.waitid so we need to pull it out ourselves
# using cffi: https://bitbucket.org/pypy/pypy/issues/2922/
import cffi
waitid_ffi = cffi.FFI()
# Believe it or not, siginfo_t starts with fields in the
# same layout on both Linux and Darwin. The Linux structure
# is bigger so that's what we use to size `pad`; while
# there are a few extra fields in there, most of it is
# true padding which would not be written by the syscall.
waitid_ffi.cdef(
"""
typedef struct siginfo_s {
int si_signo;
int si_errno;
int si_code;
int si_pid;
int si_uid;
int si_status;
int pad[26];
} siginfo_t;
int waitid(int idtype, int id, siginfo_t* result, int options);
""",
)
waitid_cffi = waitid_ffi.dlopen(None).waitid # type: ignore[attr-defined]
def sync_wait_reapable(pid: int) -> None:
P_PID = 1
WEXITED = 0x00000004
if sys.platform == "darwin": # pragma: no cover
# waitid() is not exposed on Python on Darwin but does
# work through CFFI; note that we typically won't get
# here since Darwin also defines kqueue
WNOWAIT = 0x00000020
else:
WNOWAIT = 0x01000000
result = waitid_ffi.new("siginfo_t *")
while waitid_cffi(P_PID, pid, result, WEXITED | WNOWAIT) < 0:
got_errno = waitid_ffi.errno
if got_errno == errno.EINTR:
continue
raise OSError(got_errno, os.strerror(got_errno))
# adapted from
# https://github.com/python-trio/trio/issues/4#issuecomment-398967572
waitid_limiter = CapacityLimiter(math.inf)
async def _waitid_system_task(pid: int, event: Event) -> None:
"""Spawn a thread that waits for ``pid`` to exit, then wake any tasks
that were waiting on it.
"""
# abandon_on_cancel=True: if this task is cancelled, then we abandon the
# thread to keep running waitpid in the background. Since this is
# always run as a system task, this will only happen if the whole
# call to trio.run is shutting down.
try:
await to_thread_run_sync(
sync_wait_reapable,
pid,
abandon_on_cancel=True,
limiter=waitid_limiter,
)
except OSError:
# If waitid fails, waitpid will fail too, so it still makes
# sense to wake up the callers of wait_process_exiting(). The
# most likely reason for this error in practice is a child
# exiting when wait() is not possible because SIGCHLD is
# ignored.
pass
finally:
event.set()
async def wait_child_exiting(process: "_subprocess.Process") -> None:
# Logic of this function:
# - The first time we get called, we create an Event and start
# an instance of _waitid_system_task that will set the Event
# when waitid() completes. If that Event is set before
# we get cancelled, we're good.
# - Otherwise, a following call after the cancellation must
# reuse the Event created during the first call, lest we
# create an arbitrary number of threads waiting on the same
# process.
if process._wait_for_exit_data is None:
process._wait_for_exit_data = event = Event()
_core.spawn_system_task(_waitid_system_task, process.pid, event)
assert isinstance(process._wait_for_exit_data, Event)
await process._wait_for_exit_data.wait()
@@ -0,0 +1,11 @@
from typing import TYPE_CHECKING
from .._wait_for_object import WaitForSingleObject
if TYPE_CHECKING:
from .. import _subprocess
async def wait_child_exiting(process: "_subprocess.Process") -> None:
# _handle is not in Popen stubs, though it is present on Windows.
await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined]
+876
View File
@@ -0,0 +1,876 @@
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Protocol
import attrs
import trio
from . import _core
from ._core import (
Abort,
ParkingLot,
RaiseCancelT,
add_parking_lot_breaker,
enable_ki_protection,
remove_parking_lot_breaker,
)
from ._util import final
if TYPE_CHECKING:
from types import TracebackType
from ._core import Task
from ._core._parking_lot import ParkingLotStatistics
@attrs.frozen
class EventStatistics:
"""An object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this event's
:meth:`trio.Event.wait` method.
"""
tasks_waiting: int
@final
@attrs.define(repr=False, eq=False)
class Event:
"""A waitable boolean value useful for inter-task synchronization,
inspired by :class:`threading.Event`.
An event object has an internal boolean flag, representing whether
the event has happened yet. The flag is initially False, and the
:meth:`wait` method waits until the flag is True. If the flag is
already True, then :meth:`wait` returns immediately. (If the event has
already happened, there's nothing to wait for.) The :meth:`set` method
sets the flag to True, and wakes up any waiters.
This behavior is useful because it helps avoid race conditions and
lost wakeups: it doesn't matter whether :meth:`set` gets called just
before or after :meth:`wait`. If you want a lower-level wakeup
primitive that doesn't have this protection, consider :class:`Condition`
or :class:`trio.lowlevel.ParkingLot`.
.. note:: Unlike `threading.Event`, `trio.Event` has no
`~threading.Event.clear` method. In Trio, once an `Event` has happened,
it cannot un-happen. If you need to represent a series of events,
consider creating a new `Event` object for each one (they're cheap!),
or other synchronization methods like :ref:`channels <channels>` or
`trio.lowlevel.ParkingLot`.
"""
_tasks: set[Task] = attrs.field(factory=set, init=False)
_flag: bool = attrs.field(default=False, init=False)
def is_set(self) -> bool:
"""Return the current value of the internal flag."""
return self._flag
@enable_ki_protection
def set(self) -> None:
"""Set the internal flag value to True, and wake any waiting tasks."""
if not self._flag:
self._flag = True
for task in self._tasks:
_core.reschedule(task)
self._tasks.clear()
async def wait(self) -> None:
"""Block until the internal flag value becomes True.
If it's already True, then this method returns immediately.
"""
if self._flag:
await trio.lowlevel.checkpoint()
else:
task = _core.current_task()
self._tasks.add(task)
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def statistics(self) -> EventStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this event's
:meth:`wait` method.
"""
return EventStatistics(tasks_waiting=len(self._tasks))
class _HasAcquireRelease(Protocol):
"""Only classes with acquire() and release() can use the mixin's implementations."""
async def acquire(self) -> object: ...
def release(self) -> object: ...
class AsyncContextManagerMixin:
@enable_ki_protection
async def __aenter__(self: _HasAcquireRelease) -> None:
await self.acquire()
@enable_ki_protection
async def __aexit__(
self: _HasAcquireRelease,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.release()
@attrs.frozen
class CapacityLimiterStatistics:
"""An object containing debugging information.
Currently the following fields are defined:
* ``borrowed_tokens``: The number of tokens currently borrowed from
the sack.
* ``total_tokens``: The total number of tokens in the sack. Usually
this will be larger than ``borrowed_tokens``, but it's possibly for
it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased.
* ``borrowers``: A list of all tasks or other entities that currently
hold a token.
* ``tasks_waiting``: The number of tasks blocked on this
:class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or
:meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods.
"""
borrowed_tokens: int
total_tokens: int | float
borrowers: list[Task | object]
tasks_waiting: int
# Can be a generic type with a default of Task if/when PEP 696 is released
# and implemented in type checkers. Making it fully generic would currently
# introduce a lot of unnecessary hassle.
@final
class CapacityLimiter(AsyncContextManagerMixin):
"""An object for controlling access to a resource with limited capacity.
Sometimes you need to put a limit on how many tasks can do something at
the same time. For example, you might want to use some threads to run
multiple blocking I/O operations in parallel... but if you use too many
threads at once, then your system can become overloaded and it'll actually
make things slower. One popular solution is to impose a policy like "run
up to 40 threads at the same time, but no more". But how do you implement
a policy like this?
That's what :class:`CapacityLimiter` is for. You can think of a
:class:`CapacityLimiter` object as a sack that starts out holding some fixed
number of tokens::
limit = trio.CapacityLimiter(40)
Then tasks can come along and borrow a token out of the sack::
# Borrow a token:
async with limit:
# We are holding a token!
await perform_expensive_operation()
# Exiting the 'async with' block puts the token back into the sack
And crucially, if you try to borrow a token but the sack is empty, then
you have to wait for another task to finish what it's doing and put its
token back first before you can take it and continue.
Another way to think of it: a :class:`CapacityLimiter` is like a sofa with a
fixed number of seats, and if they're all taken then you have to wait for
someone to get up before you can sit down.
By default, :func:`trio.to_thread.run_sync` uses a
:class:`CapacityLimiter` to limit the number of threads running at once;
see `trio.to_thread.current_default_thread_limiter` for details.
If you're familiar with semaphores, then you can think of this as a
restricted semaphore that's specialized for one common use case, with
additional error checking. For a more traditional semaphore, see
:class:`Semaphore`.
.. note::
Don't confuse this with the `"leaky bucket"
<https://en.wikipedia.org/wiki/Leaky_bucket>`__ or `"token bucket"
<https://en.wikipedia.org/wiki/Token_bucket>`__ algorithms used to
limit bandwidth usage on networks. The basic idea of using tokens to
track a resource limit is similar, but this is a very simple sack where
tokens aren't automatically created or destroyed over time; they're
just borrowed and then put back.
"""
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float): # noqa: PYI041
self._lot = ParkingLot()
self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
self._pending_borrowers: dict[Task, Task | object] = {}
# invoke the property setter for validation
self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
def __repr__(self) -> str:
return f"<trio.CapacityLimiter at {id(self):#x}, {len(self._borrowers)}/{self._total_tokens} with {len(self._lot)} waiting>"
@property
def total_tokens(self) -> int | float:
"""The total capacity available.
You can change :attr:`total_tokens` by assigning to this attribute. If
you make it larger, then the appropriate number of waiting tasks will
be woken immediately to take the new tokens. If you decrease
total_tokens below the number of tasks that are currently using the
resource, then all current tasks will be allowed to finish as normal,
but no new tasks will be allowed in until the total number of tasks
drops below the new total_tokens.
"""
return self._total_tokens
@total_tokens.setter
def total_tokens(self, new_total_tokens: int | float) -> None: # noqa: PYI041
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
raise TypeError("total_tokens must be an int or math.inf")
if new_total_tokens < 1:
raise ValueError("total_tokens must be >= 1")
self._total_tokens = new_total_tokens
self._wake_waiters()
def _wake_waiters(self) -> None:
available = self._total_tokens - len(self._borrowers)
for woken in self._lot.unpark(count=available):
self._borrowers.add(self._pending_borrowers.pop(woken))
@property
def borrowed_tokens(self) -> int:
"""The amount of capacity that's currently in use."""
return len(self._borrowers)
@property
def available_tokens(self) -> int | float:
"""The amount of capacity that's available to use."""
return self.total_tokens - self.borrowed_tokens
@enable_ki_protection
def acquire_nowait(self) -> None:
"""Borrow a token from the sack, without blocking.
Raises:
WouldBlock: if no tokens are available.
RuntimeError: if the current task already holds one of this sack's
tokens.
"""
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
Args:
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
used to record who is borrowing this token. This is used by
:func:`trio.to_thread.run_sync` to allow threads to "hold
tokens", with the intention in the future of using it to `allow
deadlock detection and other useful things
<https://github.com/python-trio/trio/issues/182>`__
Raises:
WouldBlock: if no tokens are available.
RuntimeError: if ``borrower`` already holds one of this sack's
tokens.
"""
if borrower in self._borrowers:
raise RuntimeError(
"this borrower is already holding one of this CapacityLimiter's tokens",
)
if len(self._borrowers) < self._total_tokens and not self._lot:
self._borrowers.add(borrower)
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self) -> None:
"""Borrow a token from the sack, blocking if necessary.
Raises:
RuntimeError: if the current task already holds one of this sack's
tokens.
"""
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
async def acquire_on_behalf_of(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
Args:
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
used to record who is borrowing this token; see
:meth:`acquire_on_behalf_of_nowait` for details.
Raises:
RuntimeError: if ``borrower`` task already holds one of this sack's
tokens.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_on_behalf_of_nowait(borrower)
except trio.WouldBlock:
task = trio.lowlevel.current_task()
self._pending_borrowers[task] = borrower
try:
await self._lot.park()
except trio.Cancelled:
self._pending_borrowers.pop(task)
raise
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self) -> None:
"""Put a token back into the sack.
Raises:
RuntimeError: if the current task has not acquired one of this
sack's tokens.
"""
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
def release_on_behalf_of(self, borrower: Task | object) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
RuntimeError: if the given borrower has not acquired one of this
sack's tokens.
"""
if borrower not in self._borrowers:
raise RuntimeError(
"this borrower isn't holding any of this CapacityLimiter's tokens",
)
self._borrowers.remove(borrower)
self._wake_waiters()
def statistics(self) -> CapacityLimiterStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``borrowed_tokens``: The number of tokens currently borrowed from
the sack.
* ``total_tokens``: The total number of tokens in the sack. Usually
this will be larger than ``borrowed_tokens``, but it's possibly for
it to be smaller if :attr:`total_tokens` was recently decreased.
* ``borrowers``: A list of all tasks or other entities that currently
hold a token.
* ``tasks_waiting``: The number of tasks blocked on this
:class:`CapacityLimiter`\'s :meth:`acquire` or
:meth:`acquire_on_behalf_of` methods.
"""
return CapacityLimiterStatistics(
borrowed_tokens=len(self._borrowers),
total_tokens=self._total_tokens,
# Use a list instead of a frozenset just in case we start to allow
# one borrower to hold multiple tokens in the future
borrowers=list(self._borrowers),
tasks_waiting=len(self._lot),
)
@final
class Semaphore(AsyncContextManagerMixin):
"""A `semaphore <https://en.wikipedia.org/wiki/Semaphore_(programming)>`__.
A semaphore holds an integer value, which can be incremented by
calling :meth:`release` and decremented by calling :meth:`acquire` but
the value is never allowed to drop below zero. If the value is zero, then
:meth:`acquire` will block until someone calls :meth:`release`.
If you're looking for a :class:`Semaphore` to limit the number of tasks
that can access some resource simultaneously, then consider using a
:class:`CapacityLimiter` instead.
This object's interface is similar to, but different from, that of
:class:`threading.Semaphore`.
A :class:`Semaphore` object can be used as an async context manager; it
blocks on entry but not on exit.
Args:
initial_value (int): A non-negative integer giving semaphore's initial
value.
max_value (int or None): If given, makes this a "bounded" semaphore that
raises an error if the value is about to exceed the given
``max_value``.
"""
def __init__(self, initial_value: int, *, max_value: int | None = None):
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an int")
if initial_value < 0:
raise ValueError("initial value must be >= 0")
if max_value is not None:
if not isinstance(max_value, int):
raise TypeError("max_value must be None or an int")
if max_value < initial_value:
raise ValueError("max_values must be >= initial_value")
# Invariants:
# bool(self._lot) implies self._value == 0
# (or equivalently: self._value > 0 implies not self._lot)
self._lot = trio.lowlevel.ParkingLot()
self._value = initial_value
self._max_value = max_value
def __repr__(self) -> str:
if self._max_value is None:
max_value_str = ""
else:
max_value_str = f", max_value={self._max_value}"
return f"<trio.Semaphore({self._value}{max_value_str}) at {id(self):#x}>"
@property
def value(self) -> int:
"""The current value of the semaphore."""
return self._value
@property
def max_value(self) -> int | None:
"""The maximum allowed value. May be None to indicate no limit."""
return self._max_value
@enable_ki_protection
def acquire_nowait(self) -> None:
"""Attempt to decrement the semaphore value, without blocking.
Raises:
WouldBlock: if the value is zero.
"""
if self._value > 0:
assert not self._lot
self._value -= 1
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self) -> None:
"""Decrement the semaphore value, blocking if necessary to avoid
letting it drop below zero.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_nowait()
except trio.WouldBlock:
await self._lot.park()
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self) -> None:
"""Increment the semaphore value, possibly waking a task blocked in
:meth:`acquire`.
Raises:
ValueError: if incrementing the value would cause it to exceed
:attr:`max_value`.
"""
if self._lot:
assert self._value == 0
self._lot.unpark(count=1)
else:
if self._max_value is not None and self._value == self._max_value:
raise ValueError("semaphore released too many times")
self._value += 1
def statistics(self) -> ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this semaphore's
:meth:`acquire` method.
"""
return self._lot.statistics()
@attrs.frozen
class LockStatistics:
"""An object containing debugging information for a Lock.
Currently the following fields are defined:
* ``locked`` (boolean): indicating whether the lock is held.
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
or None if the lock is not held.
* ``tasks_waiting`` (int): The number of tasks blocked on this lock's
:meth:`trio.Lock.acquire` method.
"""
locked: bool
owner: Task | None
tasks_waiting: int
@attrs.define(eq=False, repr=False, slots=False)
class _LockImpl(AsyncContextManagerMixin):
_lot: ParkingLot = attrs.field(factory=ParkingLot, init=False)
_owner: Task | None = attrs.field(default=None, init=False)
def __repr__(self) -> str:
if self.locked():
s1 = "locked"
s2 = f" with {len(self._lot)} waiters"
else:
s1 = "unlocked"
s2 = ""
return f"<{s1} {self.__class__.__name__} object at {id(self):#x}{s2}>"
def locked(self) -> bool:
"""Check whether the lock is currently held.
Returns:
bool: True if the lock is held, False otherwise.
"""
return self._owner is not None
@enable_ki_protection
def acquire_nowait(self) -> None:
"""Attempt to acquire the lock, without blocking.
Raises:
WouldBlock: if the lock is held.
"""
task = trio.lowlevel.current_task()
if self._owner is task:
raise RuntimeError("attempt to re-acquire an already held Lock")
elif self._owner is None and not self._lot:
# No-one owns it
self._owner = task
add_parking_lot_breaker(task, self._lot)
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self) -> None:
"""Acquire the lock, blocking if necessary.
Raises:
BrokenResourceError: if the owner of the lock exits without releasing.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_nowait()
except trio.WouldBlock:
try:
# NOTE: it's important that the contended acquire path is just
# "_lot.park()", because that's how Condition.wait() acquires the
# lock as well.
await self._lot.park()
except trio.BrokenResourceError:
raise trio.BrokenResourceError(
f"Owner of this lock exited without releasing: {self._owner}",
) from None
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self) -> None:
"""Release the lock.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
task = trio.lowlevel.current_task()
if task is not self._owner:
raise RuntimeError("can't release a Lock you don't own")
remove_parking_lot_breaker(self._owner, self._lot)
if self._lot:
(self._owner,) = self._lot.unpark(count=1)
add_parking_lot_breaker(self._owner, self._lot)
else:
self._owner = None
def statistics(self) -> LockStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``locked``: boolean indicating whether the lock is held.
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
or None if the lock is not held.
* ``tasks_waiting``: The number of tasks blocked on this lock's
:meth:`acquire` method.
"""
return LockStatistics(
locked=self.locked(),
owner=self._owner,
tasks_waiting=len(self._lot),
)
@final
class Lock(_LockImpl):
"""A classic `mutex
<https://en.wikipedia.org/wiki/Lock_(computer_science)>`__.
This is a non-reentrant, single-owner lock. Unlike
:class:`threading.Lock`, only the owner of the lock is allowed to release
it.
A :class:`Lock` object can be used as an async context manager; it
blocks on entry but not on exit.
"""
@final
class StrictFIFOLock(_LockImpl):
r"""A variant of :class:`Lock` where tasks are guaranteed to acquire the
lock in strict first-come-first-served order.
An example of when this is useful is if you're implementing something like
:class:`trio.SSLStream` or an HTTP/2 server using `h2
<https://hyper-h2.readthedocs.io/>`__, where you have multiple concurrent
tasks that are interacting with a shared state machine, and at
unpredictable moments the state machine requests that a chunk of data be
sent over the network. (For example, when using h2 simply reading incoming
data can occasionally `create outgoing data to send
<https://http2.github.io/http2-spec/#PING>`__.) The challenge is to make
sure that these chunks are sent in the correct order, without being
garbled.
One option would be to use a regular :class:`Lock`, and wrap it around
every interaction with the state machine::
# This approach is sometimes workable but often sub-optimal; see below
async with lock:
state_machine.do_something()
if state_machine.has_data_to_send():
await conn.sendall(state_machine.get_data_to_send())
But this can be problematic. If you're using h2 then *usually* reading
incoming data doesn't create the need to send any data, so we don't want
to force every task that tries to read from the network to sit and wait
a potentially long time for ``sendall`` to finish. And in some situations
this could even potentially cause a deadlock, if the remote peer is
waiting for you to read some data before it accepts the data you're
sending.
:class:`StrictFIFOLock` provides an alternative. We can rewrite our
example like::
# Note: no awaits between when we start using the state machine and
# when we block to take the lock!
state_machine.do_something()
if state_machine.has_data_to_send():
# Notice that we fetch the data to send out of the state machine
# *before* sleeping, so that other tasks won't see it.
chunk = state_machine.get_data_to_send()
async with strict_fifo_lock:
await conn.sendall(chunk)
First we do all our interaction with the state machine in a single
scheduling quantum (notice there are no ``await``\s in there), so it's
automatically atomic with respect to other tasks. And then if and only if
we have data to send, we get in line to send it and
:class:`StrictFIFOLock` guarantees that each task will send its data in
the same order that the state machine generated it.
Currently, :class:`StrictFIFOLock` is identical to :class:`Lock`,
but (a) this may not always be true in the future, especially if Trio ever
implements `more sophisticated scheduling policies
<https://github.com/python-trio/trio/issues/32>`__, and (b) the above code
is relying on a pretty subtle property of its lock. Using a
:class:`StrictFIFOLock` acts as an executable reminder that you're relying
on this property.
"""
@attrs.frozen
class ConditionStatistics:
r"""An object containing debugging information for a Condition.
Currently the following fields are defined:
* ``tasks_waiting`` (int): The number of tasks blocked on this condition's
:meth:`trio.Condition.wait` method.
* ``lock_statistics``: The result of calling the underlying
:class:`Lock`\s :meth:`~Lock.statistics` method.
"""
tasks_waiting: int
lock_statistics: LockStatistics
@final
class Condition(AsyncContextManagerMixin):
"""A classic `condition variable
<https://en.wikipedia.org/wiki/Monitor_(synchronization)>`__, similar to
:class:`threading.Condition`.
A :class:`Condition` object can be used as an async context manager to
acquire the underlying lock; it blocks on entry but not on exit.
Args:
lock (Lock): the lock object to use. If given, must be a
:class:`trio.Lock`. If None, a new :class:`Lock` will be allocated
and used.
"""
def __init__(self, lock: Lock | None = None):
if lock is None:
lock = Lock()
if type(lock) is not Lock:
raise TypeError("lock must be a trio.Lock")
self._lock = lock
self._lot = trio.lowlevel.ParkingLot()
def locked(self) -> bool:
"""Check whether the underlying lock is currently held.
Returns:
bool: True if the lock is held, False otherwise.
"""
return self._lock.locked()
def acquire_nowait(self) -> None:
"""Attempt to acquire the underlying lock, without blocking.
Raises:
WouldBlock: if the lock is currently held.
"""
return self._lock.acquire_nowait()
async def acquire(self) -> None:
"""Acquire the underlying lock, blocking if necessary.
Raises:
BrokenResourceError: if the owner of the underlying lock exits without releasing.
"""
await self._lock.acquire()
def release(self) -> None:
"""Release the underlying lock."""
self._lock.release()
@enable_ki_protection
async def wait(self) -> None:
"""Wait for another task to call :meth:`notify` or
:meth:`notify_all`.
When calling this method, you must hold the lock. It releases the lock
while waiting, and then re-acquires it before waking up.
There is a subtlety with how this method interacts with cancellation:
when cancelled it will block to re-acquire the lock before raising
:exc:`Cancelled`. This may cause cancellation to be less prompt than
expected. The advantage is that it makes code like this work::
async with condition:
await condition.wait()
If we didn't re-acquire the lock before waking up, and :meth:`wait`
were cancelled here, then we'd crash in ``condition.__aexit__`` when
we tried to release the lock we no longer held.
Raises:
RuntimeError: if the calling task does not hold the lock.
BrokenResourceError: if the owner of the lock exits without releasing, when attempting to re-acquire.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to wait")
self.release()
# NOTE: we go to sleep on self._lot, but we'll wake up on
# self._lock._lot. That's all that's required to acquire a Lock.
try:
await self._lot.park()
except:
with trio.CancelScope(shield=True):
await self.acquire()
raise
def notify(self, n: int = 1) -> None:
"""Wake one or more tasks that are blocked in :meth:`wait`.
Args:
n (int): The number of tasks to wake.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to notify")
self._lot.repark(self._lock._lot, count=n)
def notify_all(self) -> None:
"""Wake all tasks that are currently blocked in :meth:`wait`.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to notify")
self._lot.repark_all(self._lock._lot)
def statistics(self) -> ConditionStatistics:
r"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this condition's
:meth:`wait` method.
* ``lock_statistics``: The result of calling the underlying
:class:`Lock`\s :meth:`~Lock.statistics` method.
"""
return ConditionStatistics(
tasks_waiting=len(self._lot),
lock_statistics=self._lock.statistics(),
)
@@ -0,0 +1,246 @@
#!/usr/bin/env python3
"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things:
1. give an error if docstrings are missing.
pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value.
2. filter out specific errors we don't care about.
this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors.
If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file.
"""
from __future__ import annotations
# this file is not run as part of the tests, instead it's run standalone from check.sh
import argparse
import json
import subprocess
import sys
from pathlib import Path
import trio
import trio.testing
# not needed if everything is working, but if somebody does something to generate
# tons of errors, we can be nice and stop them from getting 3*tons of output
printed_diagnostics: set[str] = set()
# TODO: consider checking manually without `--ignoreexternal`, and/or
# removing it from the below call later on.
def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
return subprocess.run(
[
"pyright",
# Specify a platform and version to keep imported modules consistent.
f"--pythonplatform={platform}",
"--pythonversion=3.8",
"--verifytypes=trio",
"--outputjson",
"--ignoreexternal",
],
capture_output=True,
)
def has_docstring_at_runtime(name: str) -> bool:
"""Pyright gives us an object identifier of xx.yy.zz
This function tries to decompose that into its constituent parts, such that we
can resolve it, in order to check whether it has a `__doc__` at runtime and
verifytypes misses it because we're doing overly fancy stuff.
"""
# This assert is solely for stopping isort from removing our imports of trio & trio.testing
# It could also be done with isort:skip, but that'd also disable import sorting and the like.
assert trio.testing
# figure out what part of the name is the module, so we can "import" it
name_parts = name.split(".")
assert name_parts[0] == "trio"
if name_parts[1] == "tests":
return True
# traverse down the remaining identifiers with getattr
obj = trio
try:
for obj_name in name_parts[1:]:
obj = getattr(obj, obj_name)
except AttributeError as exc:
# asynciowrapper does funky getattr stuff
if "AsyncIOWrapper" in str(exc) or name in (
# Symbols not existing on all platforms, so we can't dynamically inspect them.
# Manually confirmed to have docstrings but pyright doesn't see them due to
# export shenanigans. TODO: actually manually confirm that.
# In theory we could verify these at runtime, probably by running the script separately
# on separate platforms. It might also be a decent idea to work the other way around,
# a la test_static_tool_sees_class_members
# darwin
"trio.lowlevel.current_kqueue",
"trio.lowlevel.monitor_kevent",
"trio.lowlevel.wait_kevent",
"trio._core._io_kqueue._KqueueStatistics",
# windows
"trio._socket.SocketType.share",
"trio._core._io_windows._WindowsStatistics",
"trio._core._windows_cffi.Handle",
"trio.lowlevel.current_iocp",
"trio.lowlevel.monitor_completion_key",
"trio.lowlevel.readinto_overlapped",
"trio.lowlevel.register_with_iocp",
"trio.lowlevel.wait_overlapped",
"trio.lowlevel.write_overlapped",
"trio.lowlevel.WaitForSingleObject",
"trio.socket.fromshare",
# linux
# this test will fail on linux, but I don't develop on linux. So the next
# person to do so is very welcome to open a pull request and populate with
# objects
# TODO: these are erroring on all platforms, why?
"trio._highlevel_generic.StapledStream.send_stream",
"trio._highlevel_generic.StapledStream.receive_stream",
"trio._ssl.SSLStream.transport_stream",
"trio._file_io._HasFileNo",
"trio._file_io._HasFileNo.fileno",
):
return True
else:
print(
f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).",
file=sys.stderr,
)
return False
return bool(obj.__doc__)
def check_type(
platform: str,
full_diagnostics_file: Path | None,
expected_errors: list[object],
) -> list[object]:
# convince isort we use the trio import
assert trio
# run pyright, load output into json
res = run_pyright(platform)
current_result = json.loads(res.stdout)
if res.stderr:
print(res.stderr, file=sys.stderr)
if full_diagnostics_file:
with open(full_diagnostics_file, "a") as f:
json.dump(current_result, f, sort_keys=True, indent=4)
errors = []
for symbol in current_result["typeCompleteness"]["symbols"]:
diagnostics = symbol["diagnostics"]
name = symbol["name"]
for diagnostic in diagnostics:
message = diagnostic["message"]
if name in (
"trio._path.PosixPath",
"trio._path.WindowsPath",
) and message.startswith("Type of base class "):
continue
if name.startswith("trio._path.Path"):
if message.startswith("No docstring found for"):
continue
if message.startswith(
"Type is missing type annotation and could be inferred differently by type checkers",
):
continue
# ignore errors about missing docstrings if they're available at runtime
if message.startswith("No docstring found for"):
if has_docstring_at_runtime(symbol["name"]):
continue
else:
# Missing docstring messages include the name of the object.
# Other errors don't, so we add it.
message = f"{name}: {message}"
if message not in expected_errors and message not in printed_diagnostics:
print(f"new error: {message}", file=sys.stderr)
errors.append(message)
printed_diagnostics.add(message)
continue
return errors
def main(args: argparse.Namespace) -> int:
if args.full_diagnostics_file:
full_diagnostics_file = Path(args.full_diagnostics_file)
full_diagnostics_file.write_text("")
else:
full_diagnostics_file = None
errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json"
if errors_by_platform_file.exists():
with open(errors_by_platform_file) as f:
errors_by_platform = json.load(f)
else:
errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []}
changed = False
for platform in "Linux", "Windows", "Darwin":
platform_errors = errors_by_platform[platform] + errors_by_platform["all"]
print("*" * 20, f"\nChecking {platform}...")
errors = check_type(platform, full_diagnostics_file, platform_errors)
new_errors = [e for e in errors if e not in platform_errors]
missing_errors = [e for e in platform_errors if e not in errors]
if new_errors:
print(
f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.",
file=sys.stderr,
)
changed = True
if missing_errors:
print(
f"Congratulations, you have resolved existing errors! Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.",
file=sys.stderr,
)
changed = True
print(missing_errors, file=sys.stderr)
errors_by_platform[platform] = errors
print("*" * 20)
# cut down the size of the json file by a lot, and make it easier to parse for
# humans, by moving errors that appear on all platforms to a separate category
errors_by_platform["all"] = []
for e in errors_by_platform["Linux"].copy():
if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]:
for platform in "Linux", "Windows", "Darwin":
errors_by_platform[platform].remove(e)
errors_by_platform["all"].append(e)
if changed and args.overwrite_file:
with open(errors_by_platform_file, "w") as f:
json.dump(errors_by_platform, f, indent=4, sort_keys=True)
# newline at end of file
f.write("\n")
# True -> 1 -> non-zero exit value -> error
return changed
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite-file",
action="store_true",
default=False,
help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.",
)
parser.add_argument(
"--full-diagnostics-file",
type=Path,
default=None,
help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.",
)
args = parser.parse_args()
assert __name__ == "__main__", "This script should be run standalone"
sys.exit(main(args))
@@ -0,0 +1,24 @@
regular = "hi"
from .. import _deprecate
_deprecate.enable_attribute_deprecations(__name__)
# Make sure that we don't trigger infinite recursion when accessing module
# attributes in between calling enable_attribute_deprecations and defining
# __deprecated_attributes__:
import sys
this_mod = sys.modules[__name__]
assert this_mod.regular == "hi"
assert not hasattr(this_mod, "dep1")
__deprecated_attributes__ = {
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
"dep2": _deprecate.DeprecatedAttribute(
"value2",
"1.2",
issue=1,
instead="instead-string",
),
}
@@ -0,0 +1,54 @@
from __future__ import annotations
import inspect
from typing import NoReturn
import pytest
from ..testing import MockClock, trio_test
RUN_SLOW = True
SKIP_OPTIONAL_IMPORTS = False
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption("--run-slow", action="store_true", help="run slow tests")
parser.addoption(
"--skip-optional-imports",
action="store_true",
help="skip tests that rely on libraries not required by trio itself",
)
def pytest_configure(config: pytest.Config) -> None:
global RUN_SLOW
RUN_SLOW = config.getoption("--run-slow", default=True)
global SKIP_OPTIONAL_IMPORTS
SKIP_OPTIONAL_IMPORTS = config.getoption("--skip-optional-imports", default=False)
@pytest.fixture
def mock_clock() -> MockClock:
return MockClock()
@pytest.fixture
def autojump_clock() -> MockClock:
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)
def skip_if_optional_else_raise(error: ImportError) -> NoReturn:
if SKIP_OPTIONAL_IMPORTS:
pytest.skip(error.msg, allow_module_level=True)
else: # pragma: no cover
raise error
@@ -0,0 +1,72 @@
from __future__ import annotations
import attrs
import pytest
from .. import abc as tabc
from ..lowlevel import Task
def test_instrument_implements_hook_methods() -> None:
attrs = {
"before_run": (),
"after_run": (),
"task_spawned": (Task,),
"task_scheduled": (Task,),
"before_task_step": (Task,),
"after_task_step": (Task,),
"task_exited": (Task,),
"before_io_wait": (3.3,),
"after_io_wait": (3.3,),
}
mayonnaise = tabc.Instrument()
for method_name, args in attrs.items():
assert hasattr(mayonnaise, method_name)
method = getattr(mayonnaise, method_name)
assert callable(method)
method(*args)
async def test_AsyncResource_defaults() -> None:
@attrs.define(slots=False)
class MyAR(tabc.AsyncResource):
record: list[str] = attrs.Factory(list)
async def aclose(self) -> None:
self.record.append("ac")
async with MyAR() as myar:
assert isinstance(myar, MyAR)
assert myar.record == []
assert myar.record == ["ac"]
def test_abc_generics() -> None:
# Pythons below 3.5.2 had a typing.Generic that would throw
# errors when instantiating or subclassing a parameterized
# version of a class with any __slots__. This is why RunVar
# (which has slots) is not generic. This tests that
# the generic ABCs are fine, because while they are slotted
# they don't actually define any slots.
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)
def send_nowait(self, value: object) -> None:
raise RuntimeError
async def send(self, value: object) -> None:
raise RuntimeError # pragma: no cover
def clone(self) -> None:
raise RuntimeError # pragma: no cover
async def aclose(self) -> None:
pass # pragma: no cover
channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)
@@ -0,0 +1,413 @@
from __future__ import annotations
from typing import Union
import pytest
import trio
from trio import EndOfChannel, open_memory_channel
from ..testing import assert_checkpoints, wait_all_tasks_blocked
async def test_channel() -> None:
with pytest.raises(TypeError):
open_memory_channel(1.0)
with pytest.raises(ValueError, match="^max_buffer_size must be >= 0$"):
open_memory_channel(-1)
s, r = open_memory_channel[Union[int, str, None]](2)
repr(s) # smoke test
repr(r) # smoke test
s.send_nowait(1)
with assert_checkpoints():
await s.send(2)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
with assert_checkpoints():
assert await r.receive() == 1
assert r.receive_nowait() == 2
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
s.send_nowait("last")
await s.aclose()
with pytest.raises(trio.ClosedResourceError):
await s.send("too late")
with pytest.raises(trio.ClosedResourceError):
s.send_nowait("too late")
with pytest.raises(trio.ClosedResourceError):
s.clone()
await s.aclose()
assert r.receive_nowait() == "last"
with pytest.raises(EndOfChannel):
await r.receive()
await r.aclose()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
await r.aclose()
async def test_553(autojump_clock: trio.abc.Clock) -> None:
s, r = open_memory_channel[str](1)
with trio.move_on_after(10) as timeout_scope:
await r.receive()
assert timeout_scope.cancelled_caught
await s.send("Test for PR #553")
async def test_channel_multiple_producers() -> None:
async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None:
# We close our handle when we're done with it
async with send_channel:
for j in range(3 * i, 3 * (i + 1)):
await send_channel.send(j)
send_channel, receive_channel = open_memory_channel[int](0)
async with trio.open_nursery() as nursery:
# We hand out clones to all the new producers, and then close the
# original.
async with send_channel:
for i in range(10):
nursery.start_soon(producer, send_channel.clone(), i)
got = [value async for value in receive_channel]
got.sort()
assert got == list(range(30))
async def test_channel_multiple_consumers() -> None:
successful_receivers = set()
received = []
async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None:
async for value in receive_channel:
successful_receivers.add(i)
received.append(value)
async with trio.open_nursery() as nursery:
send_channel, receive_channel = trio.open_memory_channel[int](1)
async with send_channel:
for i in range(5):
nursery.start_soon(consumer, receive_channel, i)
await wait_all_tasks_blocked()
for i in range(10):
await send_channel.send(i)
assert successful_receivers == set(range(5))
assert len(received) == 10
assert set(received) == set(range(10))
async def test_close_basics() -> None:
async def send_block(
s: trio.MemorySendChannel[None],
expect: type[BaseException],
) -> None:
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
await s.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None:
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s2, r2 = open_memory_channel[int](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r2)
await wait_all_tasks_blocked()
await r2.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r2.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r2.receive()
async def test_close_sync() -> None:
async def send_block(
s: trio.MemorySendChannel[None],
expect: type[BaseException],
) -> None:
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
s.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None:
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_receive_channel_clone_and_close() -> None:
s, r = open_memory_channel[None](10)
r2 = r.clone()
r3 = r.clone()
s.send_nowait(None)
await r.aclose()
with r2:
pass
with pytest.raises(trio.ClosedResourceError):
r.clone()
with pytest.raises(trio.ClosedResourceError):
r2.clone()
# Can still send, r3 is still open
s.send_nowait(None)
await r3.aclose()
# But now the receiver is really closed
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
async def test_close_multiple_send_handles() -> None:
# With multiple send handles, closing one handle only wakes senders on
# that handle, but others can continue just fine
s1, r = open_memory_channel[str](0)
s2 = s1.clone()
async def send_will_close() -> None:
with pytest.raises(trio.ClosedResourceError):
await s1.send("nope")
async def send_will_succeed() -> None:
await s2.send("ok")
async with trio.open_nursery() as nursery:
nursery.start_soon(send_will_close)
nursery.start_soon(send_will_succeed)
await wait_all_tasks_blocked()
await s1.aclose()
assert await r.receive() == "ok"
async def test_close_multiple_receive_handles() -> None:
# With multiple receive handles, closing one handle only wakes receivers on
# that handle, but others can continue just fine
s, r1 = open_memory_channel[str](0)
r2 = r1.clone()
async def receive_will_close() -> None:
with pytest.raises(trio.ClosedResourceError):
await r1.receive()
async def receive_will_succeed() -> None:
assert await r2.receive() == "ok"
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_will_close)
nursery.start_soon(receive_will_succeed)
await wait_all_tasks_blocked()
await r1.aclose()
await s.send("ok")
async def test_inf_capacity() -> None:
send, receive = open_memory_channel[int](float("inf"))
# It's accepted, and we can send all day without blocking
with send:
for i in range(10):
send.send_nowait(i)
got = [i async for i in receive]
assert got == list(range(10))
async def test_statistics() -> None:
s, r = open_memory_channel[None](2)
assert s.statistics() == r.statistics()
stats = s.statistics()
assert stats.current_buffer_used == 0
assert stats.max_buffer_size == 2
assert stats.open_send_channels == 1
assert stats.open_receive_channels == 1
assert stats.tasks_waiting_send == 0
assert stats.tasks_waiting_receive == 0
s.send_nowait(None)
assert s.statistics().current_buffer_used == 1
s2 = s.clone()
assert s.statistics().open_send_channels == 2
await s.aclose()
assert s2.statistics().open_send_channels == 1
r2 = r.clone()
assert s2.statistics().open_receive_channels == 2
await r2.aclose()
assert s2.statistics().open_receive_channels == 1
async with trio.open_nursery() as nursery:
s2.send_nowait(None) # fill up the buffer
assert s.statistics().current_buffer_used == 2
nursery.start_soon(s2.send, None)
nursery.start_soon(s2.send, None)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_send == 2
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_send == 0
# empty out the buffer again
try:
while True:
r.receive_nowait()
except trio.WouldBlock:
pass
async with trio.open_nursery() as nursery:
nursery.start_soon(r.receive)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_receive == 1
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_receive == 0
async def test_channel_fairness() -> None:
# We can remove an item we just sent, and send an item back in after, if
# no-one else is waiting.
s, r = open_memory_channel[Union[int, None]](1)
s.send_nowait(1)
assert r.receive_nowait() == 1
s.send_nowait(2)
assert r.receive_nowait() == 2
# But if someone else is waiting to receive, then they "own" the item we
# send, so we can't receive it (even though we run first):
result: int | None = None
async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None:
nonlocal result
result = await r.receive()
async with trio.open_nursery() as nursery:
nursery.start_soon(do_receive, r)
await wait_all_tasks_blocked()
s.send_nowait(2)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
assert result == 2
# And the analogous situation for send: if we free up a space, we can't
# immediately send something in it if someone is already waiting to do
# that
s, r = open_memory_channel[Union[int, None]](1)
s.send_nowait(1)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
async with trio.open_nursery() as nursery:
nursery.start_soon(s.send, 2)
await wait_all_tasks_blocked()
assert r.receive_nowait() == 1
with pytest.raises(trio.WouldBlock):
s.send_nowait(3)
assert (await r.receive()) == 2
async def test_unbuffered() -> None:
s, r = open_memory_channel[int](0)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
with pytest.raises(trio.WouldBlock):
s.send_nowait(1)
async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
with assert_checkpoints():
await s.send(v)
async with trio.open_nursery() as nursery:
nursery.start_soon(do_send, s, 1)
with assert_checkpoints():
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
@@ -0,0 +1,56 @@
from __future__ import annotations
import contextvars
from .. import _core
trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
"trio_testing_contextvar",
)
async def test_contextvars_default() -> None:
trio_testing_contextvar.set("main")
record: list[str] = []
async def child() -> None:
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
assert record == ["main"]
async def test_contextvars_set() -> None:
trio_testing_contextvar.set("main")
record: list[str] = []
async def child() -> None:
trio_testing_contextvar.set("child")
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert record == ["child"]
assert value == "main"
async def test_contextvars_copy() -> None:
trio_testing_contextvar.set("main")
context = contextvars.copy_context()
trio_testing_contextvar.set("second_main")
record: list[str] = []
async def child() -> None:
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
context.run(nursery.start_soon, child)
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert set(record) == {"main", "second_main"}
assert value == "second_main"
@@ -0,0 +1,283 @@
from __future__ import annotations
import inspect
import warnings
import pytest
from .._deprecate import (
TrioDeprecationWarning,
deprecated,
deprecated_alias,
warn_deprecated,
)
from . import module_with_deprecations
@pytest.fixture
def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder:
warnings.simplefilter("always")
# ResourceWarnings about unclosed sockets can occur nondeterministically
# (during GC) which throws off the tests in this file
warnings.simplefilter("ignore", ResourceWarning)
return recwarn
def _here() -> tuple[str, int]:
frame = inspect.currentframe()
assert frame is not None
assert frame.f_back is not None
info = inspect.getframeinfo(frame.f_back)
return (info.filename, info.lineno)
def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None:
def deprecated_thing() -> None:
warn_deprecated("ice", "1.2", issue=1, instead="water")
deprecated_thing()
filename, lineno = _here()
assert len(recwarn_always) == 1
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "ice is deprecated" in got.message.args[0]
assert "Trio 1.2" in got.message.args[0]
assert "water instead" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert got.filename == filename
assert got.lineno == lineno - 1
def test_warn_deprecated_no_instead_or_issue(
recwarn_always: pytest.WarningsRecorder,
) -> None:
# Explicitly no instead or issue
warn_deprecated("water", "1.3", issue=None, instead=None)
assert len(recwarn_always) == 1
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "water is deprecated" in got.message.args[0]
assert "no replacement" in got.message.args[0]
assert "Trio 1.3" in got.message.args[0]
def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None:
def nested1() -> None:
nested2()
def nested2() -> None:
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
filename, lineno = _here()
nested1()
got = recwarn_always.pop(DeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
def old() -> None: # pragma: no cover
pass
def new() -> None: # pragma: no cover
pass
def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None:
warn_deprecated(old, "1.0", issue=1, instead=new)
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.old is deprecated" in got.message.args[0]
assert "test_deprecate.new instead" in got.message.args[0]
@deprecated("1.5", issue=123, instead=new)
def deprecated_old() -> int:
return 3
def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None:
assert deprecated_old() == 3
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
assert "1.5" in got.message.args[0]
assert "test_deprecate.new" in got.message.args[0]
assert "issues/123" in got.message.args[0]
class Foo:
@deprecated("1.0", issue=123, instead="crying")
def method(self) -> int:
return 7
def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None:
f = Foo()
assert f.method() == 7
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
@deprecated("1.2", thing="the thing", issue=None, instead=None)
def deprecated_with_thing() -> int:
return 72
def test_deprecated_decorator_with_explicit_thing(
recwarn_always: pytest.WarningsRecorder,
) -> None:
assert deprecated_with_thing() == 72
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "the thing is deprecated" in got.message.args[0]
def new_hotness() -> str:
return "new hotness"
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None:
assert old_hotness() == "new hotness"
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
assert "1.23" in got.message.args[0]
assert "test_deprecate.new_hotness instead" in got.message.args[0]
assert "issues/1" in got.message.args[0]
assert isinstance(old_hotness.__doc__, str)
assert ".. deprecated:: 1.23" in old_hotness.__doc__
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
assert "issues/1>`__" in old_hotness.__doc__
class Alias:
def new_hotness_method(self) -> str:
return "new hotness method"
old_hotness_method = deprecated_alias(
"Alias.old_hotness_method",
new_hotness_method,
"3.21",
issue=1,
)
def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None:
obj = Alias()
assert obj.old_hotness_method() == "new hotness method"
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
msg = got.message.args[0]
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
assert "test_deprecate.Alias.new_hotness_method instead" in msg
@deprecated("2.1", issue=1, instead="hi")
def docstring_test1() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead="hi")
def docstring_test2() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=1, instead=None)
def docstring_test3() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead=None)
def docstring_test4() -> None: # pragma: no cover
"""Hello!"""
def test_deprecated_docstring_munging() -> None:
assert (
docstring_test1.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test2.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
"""
)
assert (
docstring_test3.__doc__
== """Hello!
.. deprecated:: 2.1
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test4.__doc__
== """Hello!
.. deprecated:: 2.1
"""
)
def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None:
assert module_with_deprecations.regular == "hi"
assert len(recwarn_always) == 0
filename, lineno = _here()
assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined]
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert got.filename == filename
assert got.lineno == lineno + 1
assert "module_with_deprecations.dep1" in got.message.args[0]
assert "Trio 1.1" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert "value1 instead" in got.message.args[0]
assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined]
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "instead-string instead" in got.message.args[0]
with pytest.raises(AttributeError):
module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression"
def test_warning_class() -> None:
with pytest.deprecated_call():
warn_deprecated("foo", "bar", issue=None, instead=None)
# essentially the same as the above check
with pytest.warns(DeprecationWarning):
warn_deprecated("foo", "bar", issue=None, instead=None)
with pytest.warns(TrioDeprecationWarning):
warn_deprecated(
"foo",
"bar",
issue=None,
instead=None,
use_triodeprecationwarning=True,
)
@@ -0,0 +1,64 @@
from typing import Awaitable, Callable
import pytest
import trio
async def test_deprecation_warning_open_nursery() -> None:
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
async with trio.open_nursery(strict_exception_groups=False):
...
assert len(record) == 1
async with trio.open_nursery(strict_exception_groups=True):
...
async with trio.open_nursery():
...
def test_deprecation_warning_run() -> None:
async def foo() -> None: ...
async def foo_nursery() -> None:
# this should not raise a warning, even if it's implied loose
async with trio.open_nursery():
...
async def foo_loose_nursery() -> None:
# this should raise a warning, even if specifying the parameter is redundant
async with trio.open_nursery(strict_exception_groups=False):
...
def helper(fun: Callable[..., Awaitable[None]], num: int) -> None:
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
trio.run(fun, strict_exception_groups=False)
assert len(record) == num
helper(foo, 1)
helper(foo_nursery, 1)
helper(foo_loose_nursery, 2)
def test_deprecation_warning_start_guest_run() -> None:
# "The simplest possible "host" loop."
from .._core._tests.test_guest_mode import trivial_guest_run
async def trio_return(in_host: object) -> str:
await trio.lowlevel.checkpoint()
return "ok"
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
trivial_guest_run(
trio_return,
strict_exception_groups=False,
)
assert len(record) == 1
@@ -0,0 +1,900 @@
from __future__ import annotations
import random
from contextlib import asynccontextmanager
from itertools import count
from typing import TYPE_CHECKING, NoReturn
import attrs
import pytest
from trio._tests.pytest_plugin import skip_if_optional_else_raise
try:
import trustme
from OpenSSL import SSL
except ImportError as error:
skip_if_optional_else_raise(error)
import trio
import trio.testing
from trio import DTLSChannel, DTLSEndpoint
from trio.testing._fake_net import FakeNet, UDPPacket
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
ca = trustme.CA()
server_cert = ca.issue_cert("example.com")
server_ctx = SSL.Context(SSL.DTLS_METHOD)
server_cert.configure_cert(server_ctx)
client_ctx = SSL.Context(SSL.DTLS_METHOD)
ca.configure_trust(client_ctx)
parametrize_ipv6 = pytest.mark.parametrize(
"ipv6",
[False, pytest.param(True, marks=binds_ipv6)],
ids=["ipv4", "ipv6"],
)
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
ipv6 = kwargs.pop("ipv6", False)
family = trio.socket.AF_INET6 if ipv6 else trio.socket.AF_INET
sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
return DTLSEndpoint(sock, **kwargs)
@asynccontextmanager
async def dtls_echo_server(
*,
autocancel: bool = True,
mtu: int | None = None,
ipv6: bool = False,
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
with endpoint(ipv6=ipv6) as server:
localhost = "::1" if ipv6 else "127.0.0.1"
await server.socket.bind((localhost, 0))
async with trio.open_nursery() as nursery:
async def echo_handler(dtls_channel: DTLSChannel) -> None:
print(
"echo handler started: "
f"server {dtls_channel.endpoint.socket.getsockname()!r} "
f"client {dtls_channel.peer_address!r}",
)
if mtu is not None:
dtls_channel.set_ciphertext_mtu(mtu)
try:
print("server starting do_handshake")
await dtls_channel.do_handshake()
print("server finished do_handshake")
async for packet in dtls_channel:
print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
await dtls_channel.send(packet)
except trio.BrokenResourceError: # pragma: no cover
print("echo handler channel broken")
await nursery.start(server.serve, server_ctx, echo_handler)
yield server, server.socket.getsockname()
if autocancel:
nursery.cancel_scope.cancel()
@parametrize_ipv6
async def test_smoke(ipv6: bool) -> None:
async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
with endpoint(ipv6=ipv6) as client_endpoint:
client_channel = client_endpoint.connect(address, client_ctx)
with pytest.raises(trio.NeedHandshakeError):
client_channel.get_cleartext_mtu()
await client_channel.do_handshake()
await client_channel.send(b"hello")
assert await client_channel.receive() == b"hello"
await client_channel.send(b"goodbye")
assert await client_channel.receive() == b"goodbye"
with pytest.raises(
ValueError,
match="^openssl doesn't support sending empty DTLS packets$",
):
await client_channel.send(b"")
client_channel.set_ciphertext_mtu(1234)
cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
client_channel.set_ciphertext_mtu(4321)
assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
client_channel.set_ciphertext_mtu(1234)
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
@slow
async def test_handshake_over_terrible_network(
autojump_clock: trio.testing.MockClock,
) -> None:
HANDSHAKES = 100
r = random.Random(0)
fn = FakeNet()
fn.enable()
# avoid spurious timeouts on slow machines
autojump_clock.autojump_threshold = 0.001
async with dtls_echo_server() as (_, address):
async with trio.open_nursery() as nursery:
async def route_packet(packet: UDPPacket) -> None:
while True:
op = r.choices(
["deliver", "drop", "dupe", "delay"],
weights=[0.7, 0.1, 0.1, 0.1],
)[0]
print(f"{packet.source} -> {packet.destination}: {op}")
if op == "drop":
return
elif op == "dupe":
fn.send_packet(packet)
elif op == "delay":
await trio.sleep(r.random() * 3)
# I wanted to test random packet corruption too, but it turns out
# openssl has a bug in the following scenario:
#
# - client sends ClientHello
# - server sends HelloVerifyRequest with cookie -- but cookie is
# invalid b/c either the ClientHello or HelloVerifyRequest was
# corrupted
# - client re-sends ClientHello with invalid cookie
# - server replies with new HelloVerifyRequest and correct cookie
#
# At this point, the client *should* switch to the new, valid
# cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
# the original, invalid cookie over and over. In theory we could
# work around this by detecting cookie changes and starting over
# with a whole new SSL object, but (a) it doesn't seem worth it, (b)
# when I tried then I ran into another issue where OpenSSL got stuck
# in an infinite loop sending alerts over and over, which I didn't
# dig into because see (a).
#
# elif op == "distort":
# payload = bytearray(packet.payload)
# payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
# packet = attrs.evolve(packet, payload=payload)
else:
assert op == "deliver"
print(
f"{packet.source} -> {packet.destination}: delivered"
f" {packet.payload.hex()}",
)
fn.deliver_packet(packet)
break
def route_packet_wrapper(packet: UDPPacket) -> None:
try: # noqa: SIM105 # suppressible-exception
nursery.start_soon(route_packet, packet)
except RuntimeError: # pragma: no cover
# We're exiting the nursery, so any remaining packets can just get
# dropped
pass
fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing
for i in range(HANDSHAKES):
print("#" * 80)
print("#" * 80)
print("#" * 80)
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
print("client starting do_handshake")
await client.do_handshake()
print("client finished do_handshake")
msg = str(i).encode()
# Make multiple attempts to send data, because the network might
# drop it
while True:
with trio.move_on_after(10) as cscope:
await client.send(msg)
assert await client.receive() == msg
if not cscope.cancelled_caught:
break
async def test_implicit_handshake() -> None:
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
# Implicit handshake
await client.send(b"xyz")
assert await client.receive() == b"xyz"
async def test_full_duplex() -> None:
# Tests simultaneous send/receive, and also multiple methods implicitly invoking
# do_handshake simultaneously.
with endpoint() as server_endpoint, endpoint() as client_endpoint:
await server_endpoint.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as server_nursery:
async def handler(channel: DTLSChannel) -> None:
async with trio.open_nursery() as nursery:
nursery.start_soon(channel.send, b"from server")
nursery.start_soon(channel.receive)
await server_nursery.start(server_endpoint.serve, server_ctx, handler)
client = client_endpoint.connect(
server_endpoint.socket.getsockname(),
client_ctx,
)
async with trio.open_nursery() as nursery:
nursery.start_soon(client.send, b"from client")
nursery.start_soon(client.receive)
server_nursery.cancel_scope.cancel()
async def test_channel_closing() -> None:
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake()
client.close()
with pytest.raises(trio.ClosedResourceError):
await client.send(b"abc")
with pytest.raises(trio.ClosedResourceError):
await client.receive()
# close is idempotent
client.close()
# can also aclose
await client.aclose()
async def test_serve_exits_cleanly_on_close() -> None:
async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
server_endpoint.close()
# Testing that the nursery exits even without being cancelled
# close is idempotent
server_endpoint.close()
async def test_client_multiplex() -> None:
async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
with endpoint() as client_endpoint:
client1 = client_endpoint.connect(address1, client_ctx)
client2 = client_endpoint.connect(address2, client_ctx)
await client1.send(b"abc")
await client2.send(b"xyz")
assert await client2.receive() == b"xyz"
assert await client1.receive() == b"abc"
client_endpoint.close()
with pytest.raises(trio.ClosedResourceError):
await client1.send(b"xxx")
with pytest.raises(trio.ClosedResourceError):
await client2.receive()
with pytest.raises(trio.ClosedResourceError):
client_endpoint.connect(address1, client_ctx)
async def null_handler(_: object) -> None: # pragma: no cover
pass
async with trio.open_nursery() as nursery:
with pytest.raises(trio.ClosedResourceError):
await nursery.start(client_endpoint.serve, server_ctx, null_handler)
async def test_dtls_over_dgram_only() -> None:
with trio.socket.socket() as s:
with pytest.raises(ValueError, match="^DTLS requires a SOCK_DGRAM socket$"):
DTLSEndpoint(s)
async def test_double_serve() -> None:
async def null_handler(_: object) -> None: # pragma: no cover
pass
with endpoint() as server_endpoint:
await server_endpoint.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as nursery:
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
with pytest.raises(trio.BusyResourceError):
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
nursery.cancel_scope.cancel()
async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
with endpoint() as client1, endpoint() as client2:
await client1.socket.bind(("127.0.0.1", 0))
# This should just time out
with trio.move_on_after(100) as cscope:
channel = client2.connect(client1.socket.getsockname(), client_ctx)
await channel.do_handshake()
assert cscope.cancelled_caught
async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
for buffer_size in [10, 20]:
async with dtls_echo_server() as (_, address):
with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
assert client_endpoint.incoming_packets_buffer == buffer_size
client = client_endpoint.connect(address, client_ctx)
for i in range(buffer_size + 15):
await client.send(str(i).encode())
await trio.sleep(1)
stats = client.statistics()
assert stats.incoming_packets_dropped_in_trio == 15
for i in range(buffer_size):
assert await client.receive() == str(i).encode()
await client.send(b"buffer clear now")
assert await client.receive() == b"buffer clear now"
async def test_server_socket_doesnt_crash_on_garbage(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
from trio._dtls import (
ContentType,
HandshakeFragment,
HandshakeType,
ProtocolVersion,
Record,
encode_handshake_fragment,
encode_record,
)
client_hello = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=10,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
),
),
)
client_hello_extended = client_hello + b"\x00"
client_hello_short = client_hello[:-1]
# cuts off in middle of handshake message header
client_hello_really_short = client_hello[:14]
client_hello_corrupt_record_len = bytearray(client_hello)
client_hello_corrupt_record_len[11] = 0xFF
client_hello_fragmented = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=20,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
),
),
)
client_hello_trailing_data_in_record = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=20,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
)
+ b"\x00",
),
)
handshake_empty = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=b"",
),
)
client_hello_truncated_in_cookie = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=bytes(2 + 32 + 1) + b"\xff",
),
)
async with dtls_echo_server() as (_, address):
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
for bad_packet in [
b"",
b"xyz",
client_hello_extended,
client_hello_short,
client_hello_really_short,
client_hello_corrupt_record_len,
client_hello_fragmented,
client_hello_trailing_data_in_record,
handshake_empty,
client_hello_truncated_in_cookie,
]:
await sock.sendto(bad_packet, address)
await trio.sleep(1)
async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
from trio._dtls import BadPacket, decode_client_hello_untrusted
with trio.CancelScope() as cscope:
# the first 11 bytes of ClientHello aren't protected by the cookie, so only test
# corrupting bytes after that.
offset_to_corrupt = count(11)
def route_packet(packet: UDPPacket) -> None:
try:
_, cookie, _ = decode_client_hello_untrusted(packet.payload)
except BadPacket:
pass
else:
if len(cookie) != 0:
# this is a challenge response packet
# let's corrupt the next offset so the handshake should fail
payload = bytearray(packet.payload)
offset = next(offset_to_corrupt)
if offset >= len(payload):
# We've tried all offsets. Clamp offset to the end of the
# payload, and terminate the test.
offset = len(payload) - 1
cscope.cancel()
payload[offset] ^= 0x01
packet = attrs.evolve(packet, payload=payload)
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing
async with dtls_echo_server() as (_, address):
while True:
with endpoint() as client:
channel = client.connect(address, client_ctx)
await channel.do_handshake()
assert cscope.cancelled_caught
async def test_client_cancels_handshake_and_starts_new_one(
autojump_clock: trio.abc.Clock,
) -> None:
# if a client disappears during the handshake, and then starts a new handshake from
# scratch, then the first handler's channel should fail, and a new handler get
# started
fn = FakeNet()
fn.enable()
with endpoint() as server, endpoint() as client:
await server.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as nursery:
first_time = True
async def handler(channel: DTLSChannel) -> None:
nonlocal first_time
if first_time:
first_time = False
print("handler: first time, cancelling connect")
connect_cscope.cancel()
await trio.sleep(0.5)
print("handler: handshake should fail now")
with pytest.raises(trio.BrokenResourceError):
await channel.do_handshake()
else:
print("handler: not first time, sending hello")
await channel.send(b"hello")
await nursery.start(server.serve, server_ctx, handler)
print("client: starting first connect")
with trio.CancelScope() as connect_cscope:
channel = client.connect(server.socket.getsockname(), client_ctx)
await channel.do_handshake()
assert connect_cscope.cancelled_caught
print("client: starting second connect")
channel = client.connect(server.socket.getsockname(), client_ctx)
assert await channel.receive() == b"hello"
# Give handlers a chance to finish
await trio.sleep(10)
nursery.cancel_scope.cancel()
async def test_swap_client_server() -> None:
with endpoint() as a, endpoint() as b:
await a.socket.bind(("127.0.0.1", 0))
await b.socket.bind(("127.0.0.1", 0))
async def echo_handler(channel: DTLSChannel) -> None:
async for packet in channel:
await channel.send(packet)
async def crashing_echo_handler(channel: DTLSChannel) -> None:
with pytest.raises(trio.BrokenResourceError):
await echo_handler(channel)
async with trio.open_nursery() as nursery:
await nursery.start(a.serve, server_ctx, crashing_echo_handler)
await nursery.start(b.serve, server_ctx, echo_handler)
b_to_a = b.connect(a.socket.getsockname(), client_ctx)
await b_to_a.send(b"b as client")
assert await b_to_a.receive() == b"b as client"
a_to_b = a.connect(b.socket.getsockname(), client_ctx)
await a_to_b.do_handshake()
with pytest.raises(trio.BrokenResourceError):
await b_to_a.send(b"association broken")
await a_to_b.send(b"a as client")
assert await a_to_b.receive() == b"a as client"
nursery.cancel_scope.cancel()
@slow
async def test_openssl_retransmit_doesnt_break_stuff() -> None:
# can't use autojump_clock here, because the point of the test is to wait for
# openssl's built-in retransmit timer to expire, which is hard-coded to use
# wall-clock time.
fn = FakeNet()
fn.enable()
blackholed = True
def route_packet(packet: UDPPacket) -> None:
if blackholed:
print("dropped packet", packet)
return
print("delivered packet", packet)
# packets.append(
# scapy.all.IP(
# src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
# )
# / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
# / packet.payload
# )
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (server_endpoint, address):
with endpoint() as client_endpoint:
async with trio.open_nursery() as nursery:
async def connecter() -> None:
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake(initial_retransmit_timeout=1.5)
await client.send(b"hi")
assert await client.receive() == b"hi"
nursery.start_soon(connecter)
# openssl's default timeout is 1 second, so this ensures that it thinks
# the timeout has expired
await trio.sleep(1.1)
# disable blackholing and send a garbage packet to wake up openssl so it
# notices the timeout has expired
blackholed = False
await server_endpoint.socket.sendto(
b"xxx",
client_endpoint.socket.getsockname(),
)
# now the client task should finish connecting and exit cleanly
# scapy.all.wrpcap("/tmp/trace.pcap", packets)
async def test_initial_retransmit_timeout_configuration(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
blackholed = True
def route_packet(packet: UDPPacket) -> None:
nonlocal blackholed
if blackholed:
blackholed = False
else:
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
for t in [1, 2, 4]:
with endpoint() as client:
before = trio.current_time()
blackholed = True
channel = client.connect(address, client_ctx)
await channel.do_handshake(initial_retransmit_timeout=t)
after = trio.current_time()
assert after - before == t
async def test_explicit_tiny_mtu_is_respected() -> None:
# ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
# be larger than that. (300 is still smaller than any real network though.)
MTU = 300
fn = FakeNet()
fn.enable()
def route_packet(packet: UDPPacket) -> None:
print(f"delivering {packet}")
print(f"payload size: {len(packet.payload)}")
assert len(packet.payload) <= MTU
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server(mtu=MTU) as (server, address):
with endpoint() as client:
channel = client.connect(address, client_ctx)
channel.set_ciphertext_mtu(MTU)
await channel.do_handshake()
await channel.send(b"hi")
assert await channel.receive() == b"hi"
@parametrize_ipv6
async def test_handshake_handles_minimum_network_mtu(
ipv6: bool,
autojump_clock: trio.abc.Clock,
) -> None:
# Fake network that has the minimum allowable MTU for whatever protocol we're using.
fn = FakeNet()
fn.enable()
mtu = 1280 - 48 if ipv6 else 576 - 28
def route_packet(packet: UDPPacket) -> None:
if len(packet.payload) > mtu:
print(f"dropping {packet}")
else:
print(f"delivering {packet}")
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
# See if we can successfully do a handshake -- some of the volleys will get dropped,
# and the retransmit logic should detect this and back off the MTU to something
# smaller until it succeeds.
async with dtls_echo_server(ipv6=ipv6) as (_, address):
with endpoint(ipv6=ipv6) as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
# the handshake mtu backoff shouldn't affect the return value from
# get_cleartext_mtu, b/c that's under the user's control via
# set_ciphertext_mtu
client.set_ciphertext_mtu(9999)
await client.send(b"xyz")
assert await client.receive() == b"xyz"
assert client.get_cleartext_mtu() > 9000 # as vegeta said
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_system_task_cleaned_up_on_gc() -> None:
before_tasks = trio.lowlevel.current_statistics().tasks_living
# We put this into a sub-function so that everything automatically becomes garbage
# when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
# with coverage enabled -- I think we were hitting this bug:
# https://foss.heptapod.net/pypy/pypy/-/issues/3656
async def start_and_forget_endpoint() -> int:
e = endpoint()
# This connection/handshake attempt can't succeed. The only purpose is to force
# the endpoint to set up a receive loop.
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
await s.bind(("127.0.0.1", 0))
c = e.connect(s.getsockname(), client_ctx)
async with trio.open_nursery() as nursery:
nursery.start_soon(c.do_handshake)
await trio.testing.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
during_tasks = trio.lowlevel.current_statistics().tasks_living
return during_tasks
with pytest.warns(ResourceWarning):
during_tasks = await start_and_forget_endpoint()
await trio.testing.wait_all_tasks_blocked()
gc_collect_harder()
await trio.testing.wait_all_tasks_blocked()
after_tasks = trio.lowlevel.current_statistics().tasks_living
assert before_tasks < during_tasks
assert before_tasks == after_tasks
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_before_system_task_starts() -> None:
e = endpoint()
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
await trio.testing.wait_all_tasks_blocked()
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_as_packet_received() -> None:
fn = FakeNet()
fn.enable()
e = endpoint()
await e.socket.bind(("127.0.0.1", 0))
e._ensure_receive_loop()
await trio.testing.wait_all_tasks_blocked()
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
await s.sendto(b"xxx", e.socket.getsockname())
# At this point, the endpoint's receive loop has been marked runnable because it
# just received a packet; closing the endpoint socket won't interrupt that. But by
# the time it wakes up to process the packet, the endpoint will be gone.
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
def test_gc_after_trio_exits() -> None:
async def main() -> DTLSEndpoint:
# We use fakenet just to make sure no real sockets can leak out of the test
# case - on pypy somehow the socket was outliving the gc_collect_harder call
# below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
# when called after trio exits, it doesn't need a real socket.
fn = FakeNet()
fn.enable()
return endpoint()
e = trio.run(main)
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
async def test_already_closed_socket_doesnt_crash() -> None:
with endpoint() as e:
# We close the socket before checkpointing, so the socket will already be closed
# when the system task starts up
e.socket.close()
# Now give it a chance to start up, and hopefully not crash
await trio.testing.wait_all_tasks_blocked()
async def test_socket_closed_while_processing_clienthello(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
# Check what happens if the socket is discovered to be closed when sending a
# HelloVerifyRequest, since that has its own sending logic
async with dtls_echo_server() as (server, address):
def route_packet(packet: UDPPacket) -> None:
fn.deliver_packet(packet)
server.socket.close()
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
with endpoint() as client_endpoint:
with trio.move_on_after(10):
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake()
async def test_association_replaced_while_handshake_running(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
def route_packet(packet: UDPPacket) -> None:
pass
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
c1 = client_endpoint.connect(address, client_ctx)
async with trio.open_nursery() as nursery:
async def doomed_handshake() -> None:
with pytest.raises(trio.BrokenResourceError):
await c1.do_handshake()
nursery.start_soon(doomed_handshake)
await trio.sleep(10)
client_endpoint.connect(address, client_ctx)
async def test_association_replaced_before_handshake_starts() -> None:
fn = FakeNet()
fn.enable()
# This test shouldn't send any packets
def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover
raise AssertionError()
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
c1 = client_endpoint.connect(address, client_ctx)
client_endpoint.connect(address, client_ctx)
with pytest.raises(trio.BrokenResourceError):
await c1.do_handshake()
async def test_send_to_closed_local_port() -> None:
# On Windows, sending a UDP packet to a closed local port can cause a weird
# ECONNRESET error later, inside the receive task. Make sure we're handling it
# properly.
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
async with trio.open_nursery() as nursery:
for i in range(1, 10):
channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
nursery.start_soon(channel.do_handshake)
channel = client_endpoint.connect(address, client_ctx)
await channel.send(b"xxx")
assert await channel.receive() == b"xxx"
nursery.cancel_scope.cancel()
@@ -0,0 +1,574 @@
from __future__ import annotations # isort: split
import __future__ # Regular import, not special!
import enum
import functools
import importlib
import inspect
import json
import socket as stdlib_socket
import sys
import types
from pathlib import Path, PurePath
from types import ModuleType
from typing import TYPE_CHECKING, Protocol
import attrs
import pytest
import trio
import trio.testing
from trio._tests.pytest_plugin import skip_if_optional_else_raise
from .. import _core, _util
from .._core._tests.tutil import slow
from .pytest_plugin import RUN_SLOW
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
mypy_cache_updated = False
try: # If installed, check both versions of this class.
from typing_extensions import Protocol as Protocol_ext
except ImportError: # pragma: no cover
Protocol_ext = Protocol # type: ignore[assignment]
def _ensure_mypy_cache_updated() -> None:
# This pollutes the `empty` dir. Should this be changed?
try:
from mypy.api import run
except ImportError as error:
skip_if_optional_else_raise(error)
global mypy_cache_updated
if not mypy_cache_updated:
# mypy cache was *probably* already updated by the other tests,
# but `pytest -k ...` might run just this test on its own
result = run(
[
"--config-file=",
"--cache-dir=./.mypy_cache",
"--no-error-summary",
"-c",
"import trio",
],
)
assert not result[1] # stderr
assert not result[0] # stdout
mypy_cache_updated = True
def test_core_is_properly_reexported() -> None:
# Each export from _core should be re-exported by exactly one of these
# three modules:
sources = [trio, trio.lowlevel, trio.testing]
for symbol in dir(_core):
if symbol.startswith("_"):
continue
found = 0
for source in sources:
if symbol in dir(source) and getattr(source, symbol) is getattr(
_core,
symbol,
):
found += 1
print(symbol, found)
assert found == 1
def class_is_final(cls: type) -> bool:
"""Check if a class cannot be subclassed."""
try:
# new_class() handles metaclasses properly, type(...) does not.
types.new_class("SubclassTester", (cls,))
except TypeError:
return True
else:
return False
def iter_modules(
module: types.ModuleType,
only_public: bool,
) -> Iterator[types.ModuleType]:
yield module
for name, class_ in module.__dict__.items():
if name.startswith("_") and only_public:
continue
if not isinstance(class_, ModuleType):
continue
if not class_.__name__.startswith(module.__name__): # pragma: no cover
continue
if class_ is module: # pragma: no cover
continue
yield from iter_modules(class_, only_public)
PUBLIC_MODULES = list(iter_modules(trio, only_public=True))
ALL_MODULES = list(iter_modules(trio, only_public=False))
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
# It doesn't make sense for downstream redistributors to run this test, since
# they might be using a newer version of Python with additional symbols which
# won't be reflected in trio.socket, and this shouldn't cause downstream test
# runs to start failing.
@pytest.mark.redistributors_should_skip()
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
sys.version_info.releaselevel == "alpha",
reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
@pytest.mark.filterwarnings(
# https://github.com/pypa/setuptools/issues/3274
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
module = importlib.import_module(modname)
def no_underscores(symbols: Iterable[str]) -> set[str]:
return {symbol for symbol in symbols if not symbol.startswith("_")}
runtime_names = no_underscores(dir(module))
# ignore deprecated module `tests` being invisible
if modname == "trio":
runtime_names.discard("tests")
# Ignore any __future__ feature objects, if imported under that name.
for name in __future__.all_feature_names:
if getattr(module, name, None) is getattr(__future__, name):
runtime_names.remove(name)
if tool == "pylint":
try:
from pylint.lint import PyLinter
except ImportError as error:
skip_if_optional_else_raise(error)
linter = PyLinter()
assert module.__file__ is not None
ast = linter.get_ast(module.__file__, modname)
static_names = no_underscores(ast) # type: ignore[arg-type]
elif tool == "jedi":
if sys.implementation.name != "cpython":
pytest.skip("jedi does not support pypy")
try:
import jedi
except ImportError as error:
skip_if_optional_else_raise(error)
# Simulate typing "import trio; trio.<TAB>"
script = jedi.Script(f"import {modname}; {modname}.")
completions = script.complete()
static_names = no_underscores(c.name for c in completions)
elif tool == "mypy":
if not RUN_SLOW: # pragma: no cover
pytest.skip("use --run-slow to check against mypy")
cache = Path.cwd() / ".mypy_cache"
_ensure_mypy_cache_updated()
trio_cache = next(cache.glob("*/trio"))
_, modname = (modname + ".").split(".", 1)
modname = modname[:-1]
mod_cache = trio_cache / modname if modname else trio_cache
if mod_cache.is_dir(): # pragma: no coverage
mod_cache = mod_cache / "__init__.data.json"
else:
mod_cache = trio_cache / (modname + ".data.json")
assert mod_cache.exists()
assert mod_cache.is_file()
with mod_cache.open() as cache_file:
cache_json = json.loads(cache_file.read())
static_names = no_underscores(
key
for key, value in cache_json["names"].items()
if not key.startswith(".") and value["kind"] == "Gdef"
)
elif tool == "pyright_verifytypes":
if not RUN_SLOW: # pragma: no cover
pytest.skip("use --run-slow to check against pyright")
try:
import pyright # noqa: F401
except ImportError as error:
skip_if_optional_else_raise(error)
import subprocess
res = subprocess.run(
["pyright", f"--verifytypes={modname}", "--outputjson"],
capture_output=True,
)
current_result = json.loads(res.stdout)
static_names = {
x["name"][len(modname) + 1 :]
for x in current_result["typeCompleteness"]["symbols"]
if x["name"].startswith(modname)
}
else: # pragma: no cover
raise AssertionError()
# It's expected that the static set will contain more names than the
# runtime set:
# - static tools are sometimes sloppy and include deleted names
# - some symbols are platform-specific at runtime, but always show up in
# static analysis (e.g. in trio.socket or trio.lowlevel)
# So we check that the runtime names are a subset of the static names.
missing_names = runtime_names - static_names
# ignore warnings about deprecated module tests
missing_names -= {"tests"}
if missing_names: # pragma: no cover
print(f"{tool} can't see the following names in {modname}:")
print()
for name in sorted(missing_names):
print(f" {name}")
raise AssertionError()
# this could be sped up by only invoking mypy once per module, or even once for all
# modules, instead of once per class.
@slow
# see comment on test_static_tool_sees_all_symbols
@pytest.mark.redistributors_should_skip()
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
sys.version_info.releaselevel == "alpha",
reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
def test_static_tool_sees_class_members(
tool: str,
module_name: str,
tmp_path: Path,
) -> None:
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
# ignore hidden, but not dunder, symbols
def no_hidden(symbols: Iterable[str]) -> set[str]:
return {
symbol
for symbol in symbols
if (not symbol.startswith("_")) or symbol.startswith("__")
}
if tool == "jedi" and sys.implementation.name != "cpython":
pytest.skip("jedi does not support pypy")
if tool == "mypy":
cache = Path.cwd() / ".mypy_cache"
_ensure_mypy_cache_updated()
trio_cache = next(cache.glob("*/trio"))
modname = module_name
_, modname = (modname + ".").split(".", 1)
modname = modname[:-1]
mod_cache = trio_cache / modname if modname else trio_cache
if mod_cache.is_dir():
mod_cache = mod_cache / "__init__.data.json"
else:
mod_cache = trio_cache / (modname + ".data.json")
assert mod_cache.exists()
assert mod_cache.is_file()
with mod_cache.open() as cache_file:
cache_json = json.loads(cache_file.read())
# skip a bunch of file-system activity (probably can un-memoize?)
@functools.lru_cache
def lookup_symbol(symbol: str) -> dict[str, str]:
topname, *modname, name = symbol.split(".")
version = next(cache.glob("3.*/"))
mod_cache = version / topname
if not mod_cache.is_dir():
mod_cache = version / (topname + ".data.json")
if modname:
for piece in modname[:-1]:
mod_cache /= piece
next_cache = mod_cache / modname[-1]
if next_cache.is_dir(): # pragma: no coverage
mod_cache = next_cache / "__init__.data.json"
else:
mod_cache = mod_cache / (modname[-1] + ".data.json")
elif mod_cache.is_dir():
mod_cache /= "__init__.data.json"
with mod_cache.open() as f:
return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
errors: dict[str, object] = {}
for class_name, class_ in module.__dict__.items():
if not isinstance(class_, type):
continue
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
continue
# ignore class that does dirty tricks
if class_ is trio.testing.RaisesGroup:
continue
# dir() and inspect.getmembers doesn't display properties from the metaclass
# also ignore some dunder methods that tend to differ but are of no consequence
ignore_names = set(dir(type(class_))) | {
"__annotations__",
"__attrs_attrs__",
"__attrs_own_setattr__",
"__callable_proto_members_only__",
"__class_getitem__",
"__final__",
"__getstate__",
"__match_args__",
"__order__",
"__orig_bases__",
"__parameters__",
"__protocol_attrs__",
"__setstate__",
"__slots__",
"__weakref__",
# ignore errors about dunders inherited from stdlib that tools might
# not see
"__copy__",
"__deepcopy__",
}
if type(class_) is type:
# C extension classes don't have these dunders, but Python classes do
ignore_names.add("__firstlineno__")
ignore_names.add("__static_attributes__")
# pypy seems to have some additional dunders that differ
if sys.implementation.name == "pypy":
ignore_names |= {
"__basicsize__",
"__dictoffset__",
"__itemsize__",
"__sizeof__",
"__weakrefoffset__",
"__unicode__",
}
# inspect.getmembers sees `name` and `value` in Enums, otherwise
# it behaves the same way as `dir`
# runtime_names = no_underscores(dir(class_))
runtime_names = (
no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
)
if tool == "jedi":
try:
import jedi
except ImportError as error:
skip_if_optional_else_raise(error)
script = jedi.Script(
f"from {module_name} import {class_name}; {class_name}.",
)
completions = script.complete()
static_names = no_hidden(c.name for c in completions) - ignore_names
elif tool == "mypy":
# load the cached type information
cached_type_info = cache_json["names"][class_name]
if "node" not in cached_type_info:
cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
assert "node" in cached_type_info
node = cached_type_info["node"]
static_names = no_hidden(k for k in node["names"] if not k.startswith("."))
for symbol in node["mro"][1:]:
node = lookup_symbol(symbol)["node"]
static_names |= no_hidden(
k for k in node["names"] if not k.startswith(".")
)
static_names -= ignore_names
else: # pragma: no cover
raise AssertionError("unknown tool")
missing = runtime_names - static_names
extra = static_names - runtime_names
# using .remove() instead of .delete() to get an error in case they start not
# being missing
if (
tool == "jedi"
and BaseException in class_.__mro__
and sys.version_info >= (3, 11)
):
missing.remove("add_note")
if (
tool == "mypy"
and BaseException in class_.__mro__
and sys.version_info >= (3, 11)
):
extra.remove("__notes__")
if tool == "mypy" and attrs.has(class_):
# e.g. __trio__core__run_CancelScope_AttrsAttributes__
before = len(extra)
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
assert len(extra) == before - 1
# mypy does not see these attributes in Enum subclasses
if (
tool == "mypy"
and enum.Enum in class_.__mro__
and sys.version_info >= (3, 12)
):
# Another attribute, in 3.12+ only.
extra.remove("__signature__")
# TODO: this *should* be visible via `dir`!!
if tool == "mypy" and class_ == trio.Nursery:
extra.remove("cancel_scope")
# These are (mostly? solely?) *runtime* attributes, often set in
# __init__, which doesn't show up with dir() or inspect.getmembers,
# but we get them in the way we query mypy & jedi
EXTRAS = {
trio.DTLSChannel: {"peer_address", "endpoint"},
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
trio.SSLListener: {"transport_listener"},
trio.SSLStream: {"transport_stream"},
trio.SocketListener: {"socket"},
trio.SocketStream: {"socket"},
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
trio.testing.MemorySendStream: {
"close_hook",
"send_all_hook",
"wait_send_all_might_not_block_hook",
},
trio.testing.Matcher: {
"exception_type",
"match",
"check",
},
}
if tool == "mypy" and class_ in EXTRAS:
before = len(extra)
extra -= EXTRAS[class_]
assert len(extra) == before - len(EXTRAS[class_])
# TODO: why is this? Is it a problem?
# see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
if class_ == trio.StapledStream:
extra.remove("receive_stream")
extra.remove("send_stream")
# I have not researched why these are missing, should maybe create an issue
# upstream with jedi
if tool == "jedi" and sys.version_info >= (3, 11):
if class_ in (
trio.DTLSChannel,
trio.MemoryReceiveChannel,
trio.MemorySendChannel,
trio.SSLListener,
trio.SocketListener,
):
missing.remove("__aenter__")
missing.remove("__aexit__")
if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
missing.remove("__aiter__")
missing.remove("__anext__")
if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
# These are from inherited subclasses.
missing -= PurePath.__dict__.keys()
# These are unix-only.
if tool == "mypy" and sys.platform == "win32":
missing -= {"owner", "is_mount", "group"}
if tool == "jedi" and sys.platform == "win32":
extra -= {"owner", "is_mount", "group"}
# not sure why jedi in particular ignores this (static?) method in 3.13
# (especially given the method is from 3.12....)
if (
tool == "jedi"
and sys.version_info >= (3, 13)
and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
):
missing.remove("with_segments")
if missing or extra: # pragma: no cover
errors[f"{module_name}.{class_name}"] = {
"missing": missing,
"extra": extra,
}
# `assert not errors` will not print the full content of errors, even with
# `--verbose`, so we manually print it
if errors: # pragma: no cover
from pprint import pprint
print(f"\n{tool} can't see the following symbols in {module_name}:")
pprint(errors)
assert not errors
def test_nopublic_is_final() -> None:
"""Check all NoPublicConstructor classes are also @final."""
assert class_is_final(_util.NoPublicConstructor) # This is itself final.
for module in ALL_MODULES:
for class_ in module.__dict__.values():
if isinstance(class_, _util.NoPublicConstructor):
assert class_is_final(class_)
def test_classes_are_final() -> None:
# Sanity checks.
assert not class_is_final(object)
assert class_is_final(bool)
for module in PUBLIC_MODULES:
for name, class_ in module.__dict__.items():
if not isinstance(class_, type):
continue
# Deprecated classes are exported with a leading underscore
if name.startswith("_"): # pragma: no cover
continue
# Abstract classes can be subclassed, because that's the whole
# point of ABCs
if inspect.isabstract(class_):
continue
# Same with protocols, but only direct children.
if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__:
continue
# Exceptions are allowed to be subclassed, because exception
# subclassing isn't used to inherit behavior.
if issubclass(class_, BaseException):
continue
# These are classes that are conceptually abstract, but
# inspect.isabstract returns False for boring reasons.
if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType:
continue
# ... insert other special cases here ...
# The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`.
if class_ is trio.Path:
continue
# don't care about the *Statistics classes
if name.endswith("Statistics"):
continue
assert class_is_final(class_)
@@ -0,0 +1,313 @@
import errno
import re
import socket
import sys
import pytest
import trio
from trio.testing._fake_net import FakeNet
# ENOTCONN gives different messages on different platforms
if sys.platform == "linux":
ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$"
elif sys.platform == "darwin":
ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$"
else:
ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$"
def fn() -> FakeNet:
fn = FakeNet()
fn.enable()
return fn
async def test_basic_udp() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
with pytest.raises(
OSError,
match=r"^\[\w+ \d+\] Invalid argument$",
) as exc: # Cannot rebind.
await s1.bind(("192.0.2.1", 0))
assert exc.value.errno == errno.EINVAL
# Cannot bind multiple sockets to the same address
with pytest.raises(
OSError,
match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$",
) as exc:
await s2.bind(("127.0.0.1", port))
assert exc.value.errno == errno.EADDRINUSE
await s2.sendto(b"xyz", s1.getsockname())
data, addr = await s1.recvfrom(10)
assert data == b"xyz"
assert addr == s2.getsockname()
await s1.sendto(b"abc", s2.getsockname())
data, addr = await s2.recvfrom(10)
assert data == b"abc"
assert addr == s1.getsockname()
async def test_msg_trunc() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
await s2.sendto(b"xyz", s1.getsockname())
data, addr = await s1.recvfrom(10)
async def test_recv_methods() -> None:
"""Test all recv methods for codecov"""
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
# receiving on an unbound socket is a bad idea (I think?)
with pytest.raises(NotImplementedError, match="code will most likely hang"):
await s2.recv(10)
await s1.bind(("127.0.0.1", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
# recvfrom
await s2.sendto(b"abc", s1.getsockname())
data, addr = await s1.recvfrom(10)
assert data == b"abc"
assert addr == s2.getsockname()
# recv
await s1.sendto(b"def", s2.getsockname())
data = await s2.recv(10)
assert data == b"def"
# recvfrom_into
assert await s1.sendto(b"ghi", s2.getsockname()) == 3
buf = bytearray(10)
with pytest.raises(NotImplementedError, match="^partial recvfrom_into$"):
(nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2)
(nbytes, addr) = await s2.recvfrom_into(buf)
assert nbytes == 3
assert buf == b"ghi" + b"\x00" * 7
assert addr == s1.getsockname()
# recv_into
assert await s1.sendto(b"jkl", s2.getsockname()) == 3
buf2 = bytearray(10)
nbytes = await s2.recv_into(buf2)
assert nbytes == 3
assert buf2 == b"jkl" + b"\x00" * 7
if sys.platform == "linux" and sys.implementation.name == "cpython":
flags: int = socket.MSG_MORE
else:
flags = 1
# Send seems explicitly non-functional
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
await s2.send(b"mno")
assert exc.value.errno == errno.ENOTCONN
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
await s2.send(b"mno", flags)
# sendto errors
# it's successfully used earlier
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
await s2.sendto(b"mno", flags, s1.getsockname())
with pytest.raises(TypeError, match="wrong number of arguments$"):
await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload]
@pytest.mark.skipif(
sys.platform == "win32",
reason="functions not in socket on windows",
)
async def test_nonwindows_functionality() -> None:
# mypy doesn't support a good way of aborting typechecking on different platforms
if sys.platform != "win32": # pragma: no branch
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s2.bind(("127.0.0.1", 0))
# sendmsg
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
await s2.sendmsg([b"mno"])
assert exc.value.errno == errno.ENOTCONN
assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3
(data, ancdata, msg_flags, addr) = await s2.recvmsg(10)
assert data == b"jkl"
assert ancdata == []
assert msg_flags == 0
assert addr == s1.getsockname()
# TODO: recvmsg
# recvmsg_into
assert await s1.sendto(b"xyzw", s2.getsockname()) == 4
buf1 = bytearray(2)
buf2 = bytearray(3)
ret = await s2.recvmsg_into([buf1, buf2])
(nbytes, ancdata, msg_flags, addr) = ret
assert nbytes == 4
assert buf1 == b"xy"
assert buf2 == b"zw" + b"\x00"
assert ancdata == []
assert msg_flags == 0
assert addr == s1.getsockname()
# recvmsg_into with MSG_TRUNC set
assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5
buf1 = bytearray(2)
ret = await s2.recvmsg_into([buf1])
(nbytes, ancdata, msg_flags, addr) = ret
assert nbytes == 2
assert buf1 == b"xy"
assert ancdata == []
assert msg_flags == socket.MSG_TRUNC
assert addr == s1.getsockname()
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'share'$",
):
await s1.share(0) # type: ignore[attr-defined]
@pytest.mark.skipif(
sys.platform != "win32",
reason="windows-specific fakesocket testing",
)
async def test_windows_functionality() -> None:
# mypy doesn't support a good way of aborting typechecking on different platforms
if sys.platform == "win32": # pragma: no branch
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'sendmsg'$",
):
await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined]
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'recvmsg'$",
):
s2.recvmsg(0) # type: ignore[attr-defined]
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'recvmsg_into'$",
):
s2.recvmsg_into([]) # type: ignore[attr-defined]
with pytest.raises(NotImplementedError):
s1.share(0)
async def test_basic_tcp() -> None:
fn()
with pytest.raises(NotImplementedError):
trio.socket.socket()
async def test_not_implemented_functions() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
# getsockopt
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$",
):
s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
# setsockopt
with pytest.raises(
NotImplementedError,
match="^FakeNet always has IPV6_V6ONLY=True$",
):
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
):
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
):
s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# set_inheritable
s1.set_inheritable(False)
with pytest.raises(
NotImplementedError,
match="^FakeNet can't make inheritable sockets$",
):
s1.set_inheritable(True)
# get_inheritable
assert not s1.get_inheritable()
async def test_getpeername() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
s1.getpeername()
assert exc.value.errno == errno.ENOTCONN
await s1.bind(("127.0.0.1", 0))
with pytest.raises(
AssertionError,
match="^This method seems to assume that self._binding has a remote UDPEndpoint$",
):
s1.getpeername()
async def test_init() -> None:
fn()
with pytest.raises(
NotImplementedError,
match=re.escape(
f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}",
),
):
s1 = trio.socket.socket()
# getsockname on unbound ipv4 socket
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
assert s1.getsockname() == ("0.0.0.0", 0)
# getsockname on bound ipv4 socket
await s1.bind(("0.0.0.0", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
# getsockname on unbound ipv6 socket
s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
assert s2.getsockname() == ("::", 0)
# getsockname on bound ipv6 socket
await s2.bind(("::", 0))
ip, port, *_ = s2.getsockname()
assert ip == "::1"
assert port != 0
assert _ == [0, 0]
@@ -0,0 +1,269 @@
from __future__ import annotations
import importlib
import io
import os
import re
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import sentinel
import pytest
import trio
from trio import _core, _file_io
from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper
if TYPE_CHECKING:
import pathlib
@pytest.fixture
def path(tmp_path: pathlib.Path) -> str:
return os.fspath(tmp_path / "test")
@pytest.fixture
def wrapped() -> mock.Mock:
return mock.Mock(spec_set=io.StringIO)
@pytest.fixture
def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
return trio.wrap_file(wrapped)
def test_wrap_invalid() -> None:
with pytest.raises(TypeError):
trio.wrap_file("")
def test_wrap_non_iobase() -> None:
class FakeFile:
def close(self) -> None: # pragma: no cover
pass
def write(self) -> None: # pragma: no cover
pass
wrapped = FakeFile()
assert not isinstance(wrapped, io.IOBase)
async_file = trio.wrap_file(wrapped)
assert isinstance(async_file, AsyncIOWrapper)
del FakeFile.write
with pytest.raises(TypeError):
trio.wrap_file(FakeFile())
def test_wrapped_property(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
assert async_file.wrapped is wrapped
def test_dir_matches_wrapped(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
# all supported attrs in wrapped should be available in async_file
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
# all supported attrs not in wrapped should not be available in async_file
assert not any(
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
)
def test_unsupported_not_forwarded() -> None:
class FakeFile(io.RawIOBase):
def unsupported_attr(self) -> None: # pragma: no cover
pass
async_file = trio.wrap_file(FakeFile())
assert hasattr(async_file.wrapped, "unsupported_attr")
with pytest.raises(AttributeError):
# B018 "useless expression"
async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
def test_type_stubs_match_lists() -> None:
"""Check the manual stubs match the list of wrapped methods."""
# Fetch the module's source code.
assert _file_io.__spec__ is not None
loader = _file_io.__spec__.loader
assert isinstance(loader, importlib.abc.SourceLoader)
source = io.StringIO(loader.get_source("trio._file_io"))
# Find the class, then find the TYPE_CHECKING block.
for line in source:
if "class AsyncIOWrapper" in line:
break
else: # pragma: no cover - should always find this
pytest.fail("No class definition line?")
for line in source:
if "if TYPE_CHECKING" in line:
break
else: # pragma: no cover - should always find this
pytest.fail("No TYPE CHECKING line?")
# Now we should be at the type checking block.
found: list[tuple[str, str]] = []
for line in source: # pragma: no branch - expected to break early
if line.strip() and not line.startswith(" " * 8):
break # Dedented out of the if TYPE_CHECKING block.
match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
if match is not None:
kind = "async" if match.group(1) is not None else "sync"
found.append((match.group(2), kind))
# Compare two lists so that we can easily see duplicates, and see what is different overall.
expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS]
expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS]
# Ignore order, error if duplicates are present.
found.sort()
expected.sort()
assert found == expected
def test_sync_attrs_forwarded(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name not in dir(async_file):
continue
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
def test_sync_attrs_match_wrapper(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, attr_name)
with pytest.raises(AttributeError):
getattr(wrapped, attr_name)
def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
# I gave up on typing this one
def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
# use read as a representative of all async methods
assert async_file.read.__name__ == "read"
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
assert async_file.read.__doc__ is not None
assert "io.StringIO.read" in async_file.read.__doc__
async def test_async_methods_wrap(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
meth = getattr(async_file, meth_name)
wrapped_meth = getattr(wrapped, meth_name)
value = await meth(sentinel.argument, keyword=sentinel.keyword)
wrapped_meth.assert_called_once_with(
sentinel.argument,
keyword=sentinel.keyword,
)
assert value == wrapped_meth()
wrapped.reset_mock()
async def test_async_methods_match_wrapper(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, meth_name)
with pytest.raises(AttributeError):
getattr(wrapped, meth_name)
async def test_open(path: pathlib.Path) -> None:
f = await trio.open_file(path, "w")
assert isinstance(f, AsyncIOWrapper)
await f.aclose()
async def test_open_context_manager(path: pathlib.Path) -> None:
async with await trio.open_file(path, "w") as f:
assert isinstance(f, AsyncIOWrapper)
assert not f.closed
assert f.closed
async def test_async_iter() -> None:
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
expected = list(async_file.wrapped)
async_file.wrapped.seek(0)
result = [line async for line in async_file]
assert result == expected
async def test_aclose_cancelled(path: pathlib.Path) -> None:
with _core.CancelScope() as cscope:
f = await trio.open_file(path, "w")
cscope.cancel()
with pytest.raises(_core.Cancelled):
await f.write("a")
with pytest.raises(_core.Cancelled):
await f.aclose()
assert f.closed
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
tmp_file = tmp_path / "filename"
tmp_file.touch()
# flake8-async does not like opening files in async mode
with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230
buffered = io.BufferedReader(raw)
async_file = trio.wrap_file(buffered)
detached = await async_file.detach()
assert isinstance(detached, AsyncIOWrapper)
assert detached.wrapped is raw
@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import NoReturn
import attrs
import pytest
from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream
@attrs.define(slots=False)
class RecordSendStream(SendStream):
record: list[str | tuple[str, object]] = attrs.Factory(list)
async def send_all(self, data: object) -> None:
self.record.append(("send_all", data))
async def wait_send_all_might_not_block(self) -> None:
self.record.append("wait_send_all_might_not_block")
async def aclose(self) -> None:
self.record.append("aclose")
@attrs.define(slots=False)
class RecordReceiveStream(ReceiveStream):
record: list[str | tuple[str, int | None]] = attrs.Factory(list)
async def receive_some(self, max_bytes: int | None = None) -> bytes:
self.record.append(("receive_some", max_bytes))
return b""
async def aclose(self) -> None:
self.record.append("aclose")
async def test_StapledStream() -> None:
send_stream = RecordSendStream()
receive_stream = RecordReceiveStream()
stapled = StapledStream(send_stream, receive_stream)
assert stapled.send_stream is send_stream
assert stapled.receive_stream is receive_stream
await stapled.send_all(b"foo")
await stapled.wait_send_all_might_not_block()
assert send_stream.record == [
("send_all", b"foo"),
"wait_send_all_might_not_block",
]
send_stream.record.clear()
await stapled.send_eof()
assert send_stream.record == ["aclose"]
send_stream.record.clear()
async def fake_send_eof() -> None:
send_stream.record.append("send_eof")
send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
await stapled.send_eof()
assert send_stream.record == ["send_eof"]
send_stream.record.clear()
assert receive_stream.record == []
await stapled.receive_some(1234)
assert receive_stream.record == [("receive_some", 1234)]
assert send_stream.record == []
receive_stream.record.clear()
await stapled.aclose()
assert receive_stream.record == ["aclose"]
assert send_stream.record == ["aclose"]
async def test_StapledStream_with_erroring_close() -> None:
# Make sure that if one of the aclose methods errors out, then the other
# one still gets called.
class BrokenSendStream(RecordSendStream):
async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError("send error")
class BrokenReceiveStream(RecordReceiveStream):
async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError("recv error")
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
await stapled.aclose()
assert isinstance(excinfo.value.__context__, ValueError)
assert stapled.send_stream.record == ["aclose"]
assert stapled.receive_stream.record == ["aclose"]
@@ -0,0 +1,410 @@
from __future__ import annotations
import errno
import socket as stdlib_socket
import sys
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Sequence, overload
import attrs
import pytest
import trio
from trio import (
SocketListener,
open_tcp_listeners,
open_tcp_stream,
serve_tcp,
)
from trio.abc import HostnameResolver, SendStream, SocketFactory
from trio.testing import open_stream_to_socket_listener
from .. import socket as tsocket
from .._core._tests.tutil import binds_ipv6
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
if TYPE_CHECKING:
from typing_extensions import Buffer
async def test_open_tcp_listeners_basic() -> None:
listeners = await open_tcp_listeners(0)
assert isinstance(listeners, list)
for obj in listeners:
assert isinstance(obj, SocketListener)
# Binds to wildcard address by default
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
listener = listeners[0]
# Make sure the backlog is at least 2
c1 = await open_stream_to_socket_listener(listener)
c2 = await open_stream_to_socket_listener(listener)
s1 = await listener.accept()
s2 = await listener.accept()
# Note that we don't know which client stream is connected to which server
# stream
await s1.send_all(b"x")
await s2.send_all(b"x")
assert await c1.receive_some(1) == b"x"
assert await c2.receive_some(1) == b"x"
for resource in [c1, c2, s1, s2, *listeners]:
await resource.aclose()
async def test_open_tcp_listeners_specific_port_specific_host() -> None:
# Pick a port
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
host, port = sock.getsockname()
sock.close()
(listener,) = await open_tcp_listeners(port, host=host)
async with listener:
assert listener.socket.getsockname() == (host, port)
@binds_ipv6
async def test_open_tcp_listeners_ipv6_v6only() -> None:
# Check IPV6_V6ONLY is working properly
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
async with ipv6_listener:
_, port, *_ = ipv6_listener.socket.getsockname()
with pytest.raises(
OSError,
match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$",
):
await open_tcp_stream("127.0.0.1", port)
async def test_open_tcp_listeners_rebind() -> None:
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
sockaddr1 = l1.socket.getsockname()
# Plain old rebinding while it's still there should fail, even if we have
# SO_REUSEADDR set
with stdlib_socket.socket() as probe:
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
with pytest.raises(
OSError,
match="(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
):
probe.bind(sockaddr1)
# Now use the first listener to set up some connections in various states,
# and make sure that they don't create any obstacle to rebinding a second
# listener after the first one is closed.
c_established = await open_stream_to_socket_listener(l1)
s_established = await l1.accept()
c_time_wait = await open_stream_to_socket_listener(l1)
s_time_wait = await l1.accept()
# Server-initiated close leaves socket in TIME_WAIT
await s_time_wait.aclose()
await l1.aclose()
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
sockaddr2 = l2.socket.getsockname()
assert sockaddr1 == sockaddr2
assert s_established.socket.getsockname() == sockaddr2
assert c_time_wait.socket.getpeername() == sockaddr2
for resource in [
l1,
l2,
c_established,
s_established,
c_time_wait,
s_time_wait,
]:
await resource.aclose()
class FakeOSError(OSError):
pass
@attrs.define(slots=False)
class FakeSocket(tsocket.SocketType):
_family: AddressFamily = attrs.field(converter=AddressFamily)
_type: SocketKind = attrs.field(converter=SocketKind)
_proto: int
closed: bool = False
poison_listen: bool = False
backlog: int | None = None
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily:
return self._family
@property
def proto(self) -> int: # pragma: no cover
return self._proto
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
def getsockopt(
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
return True
raise AssertionError() # pragma: no cover
@overload
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
pass
async def bind(self, address: Any) -> None:
pass
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
assert self.backlog is None
assert backlog is not None
self.backlog = backlog
if self.poison_listen:
raise FakeOSError("whoops")
def close(self) -> None:
self.closed = True
@attrs.define(slots=False)
class FakeSocketFactory(SocketFactory):
poison_after: int
sockets: list[tsocket.SocketType] = attrs.Factory(list)
raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno
def socket(
self,
family: AddressFamily | int | None = None,
type_: SocketKind | int | None = None,
proto: int = 0,
) -> tsocket.SocketType:
assert family is not None
assert type_ is not None
if isinstance(family, int) and not isinstance(family, AddressFamily):
family = AddressFamily(family) # pragma: no cover
if family in self.raise_on_family:
raise OSError(self.raise_on_family[family], "nope")
sock = FakeSocket(family, type_, proto)
self.poison_after -= 1
if self.poison_after == 0:
sock.poison_listen = True
self.sockets.append(sock)
return sock
@attrs.define(slots=False)
class FakeHostnameResolver(HostnameResolver):
family_addr_pairs: Sequence[tuple[AddressFamily, str]]
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
assert isinstance(port, int)
return [
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
for family, addr in self.family_addr_pairs
]
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError()
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
# If we were trying to bind to multiple hosts and one of them failed, they
# call get cleaned up before returning
fsf = FakeSocketFactory(3)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver(
[
(tsocket.AF_INET, "1.1.1.1"),
(tsocket.AF_INET, "2.2.2.2"),
(tsocket.AF_INET, "3.3.3.3"),
],
),
)
with pytest.raises(FakeOSError):
await open_tcp_listeners(80, host="example.org")
assert len(fsf.sockets) == 3
for sock in fsf.sockets:
# property only exists on FakeSocket
assert sock.closed # type: ignore[attr-defined]
async def test_open_tcp_listeners_port_checking() -> None:
for host in ["127.0.0.1", None]:
with pytest.raises(TypeError):
await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
async def test_serve_tcp() -> None:
async def handler(stream: SendStream) -> None:
await stream.send_all(b"x")
async with trio.open_nursery() as nursery:
# nursery.start is incorrectly typed, awaiting #2773
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
stream = await open_stream_to_socket_listener(listeners[0])
async with stream:
assert await stream.receive_some(1) == b"x"
nursery.cancel_scope.cancel()
@pytest.mark.parametrize(
"try_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
@pytest.mark.parametrize(
"fail_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
async def test_open_tcp_listeners_some_address_families_unavailable(
try_families: set[AddressFamily],
fail_families: set[AddressFamily],
) -> None:
fsf = FakeSocketFactory(
10,
raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families},
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(family, "foo") for family in try_families]),
)
should_succeed = try_families - fail_families
if not should_succeed:
with pytest.raises(OSError, match="This system doesn't support") as exc_info:
await open_tcp_listeners(80, host="example.org")
# open_listeners always creates an exceptiongroup with the
# unsupported address families, regardless of the value of
# strict_exception_groups or number of unsupported families.
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)
for subexc in exc_info.value.__cause__.exceptions:
assert "nope" in str(subexc)
else:
listeners = await open_tcp_listeners(80)
for listener in listeners:
should_succeed.remove(listener.socket.family)
assert not should_succeed
async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
fsf = FakeSocketFactory(
10,
raise_on_family={
tsocket.AF_INET: errno.EAFNOSUPPORT,
tsocket.AF_INET6: errno.EINVAL,
},
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]),
)
with pytest.raises(OSError, match="nope") as exc_info:
await open_tcp_listeners(80, host="example.org")
assert exc_info.value.errno == errno.EINVAL
assert exc_info.value.__cause__ is None
assert "nope" in str(exc_info.value)
# We used to have an elaborate test that opened a real TCP listening socket
# and then tried to measure its backlog by making connections to it. And most
# of the time, it worked. But no matter what we tried, it was always fragile,
# because it had to do things like use timeouts to guess when the listening
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
# effectively is no backlog), sometimes the host might not be enough resources
# to give us the full requested backlog... it was a mess. So now we just check
# that the backlog argument is passed through correctly.
async def test_open_tcp_listeners_backlog() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for given, expected in [
(None, 0xFFFF),
(99999999, 0xFFFF),
(10, 10),
(1, 1),
]:
listeners = await open_tcp_listeners(0, backlog=given)
assert listeners
for listener in listeners:
# `backlog` only exists on FakeSocket
assert listener.socket.backlog == expected # type: ignore[attr-defined]
async def test_open_tcp_listeners_backlog_float_error() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for should_fail in (0.0, 2.18, 3.14, 9.75):
with pytest.raises(
TypeError,
match=f"backlog must be an int or None, not {should_fail!r}",
):
await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type]
@@ -0,0 +1,686 @@
from __future__ import annotations
import socket
import sys
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Sequence
import attrs
import pytest
import trio
from trio._highlevel_open_tcp_stream import (
close_all,
format_host_port,
open_tcp_stream,
reorder_for_rfc_6555_section_5_4,
)
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
from trio.testing import Matcher, RaisesGroup
if TYPE_CHECKING:
from trio.testing import MockClock
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def test_close_all() -> None:
class CloseMe(SocketType):
closed = False
def close(self) -> None:
self.closed = True
class CloseKiller(SocketType):
def close(self) -> None:
raise OSError("os error text")
c: CloseMe = CloseMe()
with close_all() as to_close:
to_close.add(c)
assert c.closed
c = CloseMe()
with pytest.raises(RuntimeError):
with close_all() as to_close:
to_close.add(c)
raise RuntimeError
assert c.closed
c = CloseMe()
with pytest.raises(OSError, match="os error text"):
with close_all() as to_close:
to_close.add(CloseKiller())
to_close.add(c)
assert c.closed
def test_reorder_for_rfc_6555_section_5_4() -> None:
def fake4(
i: int,
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (
AF_INET,
SOCK_STREAM,
IPPROTO_TCP,
"",
(f"10.0.0.{i}", 80),
)
def fake6(
i: int,
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
for fake in fake4, fake6:
# No effect on homogeneous lists
targets = [fake(0), fake(1), fake(2)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0), fake(1), fake(2)]
# Single item lists also OK
targets = [fake(0)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0)]
# If the list starts out with different families in positions 0 and 1,
# then it's left alone
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
targets = list(orig)
reorder_for_rfc_6555_section_5_4(targets)
assert targets == orig
# If not, it's reordered
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
def test_format_host_port() -> None:
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port("example.com", 443) == "example.com:443"
assert format_host_port(b"example.com", 443) == "example.com:443"
assert format_host_port("::1", "http") == "[::1]:http"
assert format_host_port(b"::1", "http") == "[::1]:http"
# Make sure we can connect to localhost using real kernel sockets
async def test_open_tcp_stream_real_socket_smoketest() -> None:
listen_sock = trio.socket.socket()
await listen_sock.bind(("127.0.0.1", 0))
_, listen_port = listen_sock.getsockname()
listen_sock.listen(1)
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
server_sock, _ = await listen_sock.accept()
await client_stream.send_all(b"x")
assert await server_sock.recv(1) == b"x"
await client_stream.aclose()
server_sock.close()
listen_sock.close()
async def test_open_tcp_stream_input_validation() -> None:
with pytest.raises(ValueError, match="^host must be str or bytes, not None$"):
await open_tcp_stream(None, 80) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
def can_bind_127_0_0_2() -> bool:
with socket.socket() as s:
try:
s.bind(("127.0.0.2", 0))
except OSError:
return False
# s.getsockname() is typed as returning Any
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
async def test_local_address_real() -> None:
with trio.socket.socket() as listener:
await listener.bind(("127.0.0.1", 0))
listener.listen()
# It's hard to test local_address properly, because you need multiple
# local addresses that you can bind to. Fortunately, on most Linux
# systems, you can bind to any 127.*.*.* address, and they all go
# through the loopback interface. So we can use a non-standard
# loopback address. On other systems, the only address we know for
# certain we have is 127.0.0.1, so we can't really test local_address=
# properly -- passing local_address=127.0.0.1 is indistinguishable
# from not passing local_address= at all. But, we can still do a smoke
# test to make sure the local_address= code doesn't crash.
local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"
async with await open_tcp_stream(
*listener.getsockname(),
local_address=local_address,
) as client_stream:
assert client_stream.socket.getsockname()[0] == local_address
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
assert client_stream.socket.getsockopt(
trio.socket.IPPROTO_IP,
trio.socket.IP_BIND_ADDRESS_NO_PORT,
)
server_sock, remote_addr = await listener.accept()
await client_stream.aclose()
server_sock.close()
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
assert remote_addr[0] == local_address
# Trying to connect to an ipv4 address with the ipv6 wildcard
# local_address should fail
with pytest.raises(
OSError,
match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$",
):
await open_tcp_stream(*listener.getsockname(), local_address="::")
# But the ipv4 wildcard address should work
async with await open_tcp_stream(
*listener.getsockname(),
local_address="0.0.0.0",
) as client_stream:
server_sock, remote_addr = await listener.accept()
server_sock.close()
assert remote_addr == client_stream.socket.getsockname()
# Now, thorough tests using fake sockets
@attrs.define(eq=False, slots=False)
class FakeSocket(trio.socket.SocketType):
scenario: Scenario
_family: AddressFamily
_type: SocketKind
_proto: int
ip: str | int | None = None
port: str | int | None = None
succeeded: bool = False
closed: bool = False
failing: bool = False
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily: # pragma: no cover
return self._family
@property
def proto(self) -> int: # pragma: no cover
return self._proto
async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
self.ip = sockaddr[0]
self.port = sockaddr[1]
assert self.ip not in self.scenario.sockets
self.scenario.sockets[self.ip] = self
self.scenario.connect_times[self.ip] = trio.current_time()
delay, result = self.scenario.ip_dict[self.ip]
await trio.sleep(delay)
if result == "error":
raise OSError("sorry")
if result == "postconnect_fail":
self.failing = True
self.succeeded = True
def close(self) -> None:
self.closed = True
# called when SocketStream is constructed
def setsockopt(self, *args: object, **kwargs: object) -> None:
if self.failing:
# raise something that isn't OSError as SocketStream
# ignores those
raise KeyboardInterrupt
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
def __init__(
self,
port: int,
ip_list: Sequence[tuple[str, float, str]],
supported_families: set[AddressFamily],
) -> None:
# ip_list have to be unique
ip_order = [ip for (ip, _, _) in ip_list]
assert len(set(ip_order)) == len(ip_list)
ip_dict: dict[str | int, tuple[float, str]] = {}
for ip, delay, result in ip_list:
assert delay >= 0
assert result in ["error", "success", "postconnect_fail"]
ip_dict[ip] = (delay, result)
self.port = port
self.ip_order = ip_order
self.ip_dict = ip_dict
self.supported_families = supported_families
self.socket_count = 0
self.sockets: dict[str | int, FakeSocket] = {}
self.connect_times: dict[str | int, float] = {}
def socket(
self,
family: AddressFamily | int | None = None,
type_: SocketKind | int | None = None,
proto: int | None = None,
) -> SocketType:
assert isinstance(family, AddressFamily)
assert isinstance(type_, SocketKind)
assert proto is not None
if family not in self.supported_families:
raise OSError("pretending not to support this family")
self.socket_count += 1
return FakeSocket(self, family, type_, proto)
def _ip_to_gai_entry(self, ip: str) -> tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int, int, int] | tuple[str, int],
]:
sockaddr: tuple[str, int] | tuple[str, int, int, int]
if ":" in ip:
family = trio.socket.AF_INET6
sockaddr = (ip, self.port, 0, 0)
else:
family = trio.socket.AF_INET
sockaddr = (ip, self.port)
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = -1,
type: int = -1,
proto: int = -1,
flags: int = -1,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int, int, int] | tuple[str, int],
]
]:
assert host == b"test.example.com"
assert port == self.port
assert family == trio.socket.AF_UNSPEC
assert type == trio.socket.SOCK_STREAM
assert proto == 0
assert flags == 0
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError
def check(self, succeeded: SocketType | None) -> None:
# sockets only go into self.sockets when connect is called; make sure
# all the sockets that were created did in fact go in there.
assert self.socket_count == len(self.sockets)
for ip, socket_ in self.sockets.items():
assert ip in self.ip_dict
if socket_ is not succeeded:
assert socket_.closed
assert socket_.port == self.port
async def run_scenario(
# The port to connect to
port: int,
# A list of
# (ip, delay, result)
# tuples, where delay is in seconds and result is "success" or "error"
# The ip's will be returned from getaddrinfo in this order, and then
# connect() calls to them will have the given result.
ip_list: Sequence[tuple[str, float, str]],
*,
# If False, AF_INET4/6 sockets error out on creation, before connect is
# even called.
ipv4_supported: bool = True,
ipv6_supported: bool = True,
# Normally, we return (winning_sock, scenario object)
# If this is True, we require there to be an exception, and return
# (exception, scenario object)
expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
**kwargs: Any,
) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
supported_families = set()
if ipv4_supported:
supported_families.add(trio.socket.AF_INET)
if ipv6_supported:
supported_families.add(trio.socket.AF_INET6)
scenario = Scenario(port, ip_list, supported_families)
trio.socket.set_custom_hostname_resolver(scenario)
trio.socket.set_custom_socket_factory(scenario)
try:
stream = await open_tcp_stream("test.example.com", port, **kwargs)
assert expect_error == ()
scenario.check(stream.socket)
return (stream.socket, scenario)
except AssertionError: # pragma: no cover
raise
except expect_error as exc:
scenario.check(None)
return (exc, scenario)
async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 0.123
async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 100
async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
82,
[("1.2.3.4", 0.123, "error")],
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert trio.current_time() == 0.123
async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83,
[("1.2.3.4", 100, "error")],
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert trio.current_time() == 100
async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83,
[("1.2.3.4", 1, "postconnect_fail")],
expect_error=KeyboardInterrupt,
)
assert isinstance(exc, KeyboardInterrupt)
# With the default 0.250 second delay, the third attempt will win
async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
"3.3.3.3": 0.500,
}
async def test_early_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 0.1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "2.2.2.2"
assert trio.current_time() == (0.250 + 0.1)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
# 3.3.3.3 was never even started
}
# With a 0.450 second delay, the first attempt will win
async def test_custom_delay(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
happy_eyeballs_delay=0.450,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.1.1.1"
assert trio.current_time() == 1
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.450,
"3.3.3.3": 0.900,
}
async def test_none_default(autojump_clock: MockClock) -> None:
"""Copy of test_basic_fallthrough, but specifying the delay =None"""
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
happy_eyeballs_delay=None,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
"3.3.3.3": 0.500,
}
async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "success"),
# .25 is the default timeout
("4.4.4.4", 0.25, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_all_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "error"),
("4.4.4.4", 0.250, "error"),
],
expect_error=OSError,
)
assert isinstance(exc, OSError)
subexceptions = (Matcher(OSError, match="^sorry$"),) * 4
assert RaisesGroup(
*subexceptions,
match="all attempts to connect to test.example.com:80 failed",
).matches(exc.__cause__)
assert trio.current_time() == (0.1 + 0.2 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_multi_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.5, "error"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10 - 1, "success"),
("4.4.4.4", 10 - 2, "success"),
("5.5.5.5", 0.5, "error"),
],
happy_eyeballs_delay=1,
)
assert not scenario.sockets["1.1.1.1"].succeeded
assert (
scenario.sockets["2.2.2.2"].succeeded
or scenario.sockets["3.3.3.3"].succeeded
or scenario.sockets["4.4.4.4"].succeeded
)
assert not scenario.sockets["5.5.5.5"].succeeded
assert isinstance(sock, FakeSocket)
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
assert trio.current_time() == (0.5 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.5,
"3.3.3.3": 1.5,
"4.4.4.4": 2.5,
"5.5.5.5": 3.5,
}
async def test_does_reorder(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "error"),
# This would win if we tried it first...
("2.2.2.2", 1, "success"),
# But in fact we try this first, because of section 5.4
("::3", 0.5, "success"),
],
happy_eyeballs_delay=1,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.5
assert scenario.connect_times == {
"1.1.1.1": 0,
"::3": 1,
}
async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 10, "success"),
("2.2.2.2", 0, "success"),
("::3", 0.1, "success"),
("4.4.4.4", 0, "success"),
],
happy_eyeballs_delay=1,
ipv4_supported=False,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"::1": 0,
"::3": 1.0,
}
async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 0, "success"),
("2.2.2.2", 10, "success"),
("::3", 0, "success"),
("4.4.4.4", 0.1, "success"),
],
happy_eyeballs_delay=1,
ipv6_supported=False,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"2.2.2.2": 0,
"4.4.4.4": 1.0,
}
async def test_no_hosts(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(80, [], expect_error=OSError)
assert "no results found" in str(exc)
async def test_cancel(autojump_clock: MockClock) -> None:
with trio.move_on_after(5) as cancel_scope:
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "success"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10, "success"),
("4.4.4.4", 10, "success"),
],
expect_error=BaseExceptionGroup,
)
assert isinstance(exc, BaseException)
# What comes out should be 1 or more Cancelled errors that all belong
# to this cancel_scope; this is the easiest way to check that
raise exc
assert cancel_scope.cancelled_caught
assert trio.current_time() == 5
# This should have been called already, but just to make sure, since the
# exception-handling logic in run_scenario is a bit complicated and the
# main thing we care about here is that all the sockets were cleaned up.
scenario.check(succeeded=None)
@@ -0,0 +1,86 @@
import os
import socket
import sys
import tempfile
from typing import TYPE_CHECKING
import pytest
from trio import Path, open_unix_socket
from trio._highlevel_open_unix_stream import close_on_error
assert not TYPE_CHECKING or sys.platform != "win32"
skip_if_not_unix = pytest.mark.skipif(
not hasattr(socket, "AF_UNIX"),
reason="Needs unix socket support",
)
@skip_if_not_unix
def test_close_on_error() -> None:
class CloseMe:
closed = False
def close(self) -> None:
self.closed = True
with close_on_error(CloseMe()) as c:
pass
assert not c.closed
with pytest.raises(RuntimeError):
with close_on_error(CloseMe()) as c:
raise RuntimeError
assert c.closed
@skip_if_not_unix
@pytest.mark.parametrize("filename", [4, 4.5])
async def test_open_with_bad_filename_type(filename: float) -> None:
with pytest.raises(TypeError):
await open_unix_socket(filename) # type: ignore[arg-type]
@skip_if_not_unix
async def test_open_bad_socket() -> None:
# mktemp is marked as insecure, but that's okay, we don't want the file to
# exist
name = tempfile.mktemp()
with pytest.raises(FileNotFoundError):
await open_unix_socket(name)
@skip_if_not_unix
async def test_open_unix_socket() -> None:
for name_type in [Path, str]:
name = tempfile.mktemp()
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
with serv_sock:
serv_sock.bind(name)
try:
serv_sock.listen(1)
# The actual function we're testing
unix_socket = await open_unix_socket(name_type(name))
async with unix_socket:
client, _ = serv_sock.accept()
with client:
await unix_socket.send_all(b"test")
assert client.recv(2048) == b"test"
client.sendall(b"response")
received = await unix_socket.receive_some(2048)
assert received == b"response"
finally:
os.unlink(name)
@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms")
async def test_error_on_no_unix() -> None:
with pytest.raises(
RuntimeError,
match="^Unix sockets are not supported on this platform$",
):
await open_unix_socket("")
@@ -0,0 +1,183 @@
from __future__ import annotations
import errno
from functools import partial
from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn
import attrs
import trio
from trio import Nursery, StapledStream, TaskStatus
from trio.testing import (
Matcher,
MemoryReceiveStream,
MemorySendStream,
MockClock,
RaisesGroup,
memory_stream_pair,
wait_all_tasks_blocked,
)
if TYPE_CHECKING:
import pytest
from trio._channel import MemoryReceiveChannel, MemorySendChannel
from trio.abc import Stream
# types are somewhat tentative - I just bruteforced them until I got something that didn't
# give errors
StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
@attrs.define(eq=False, slots=False)
class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
closed: bool = False
accepted_streams: list[trio.abc.Stream] = attrs.Factory(list)
queued_streams: tuple[
MemorySendChannel[StapledMemoryStream],
MemoryReceiveChannel[StapledMemoryStream],
] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1))
accept_hook: Callable[[], Awaitable[object]] | None = None
async def connect(self) -> StapledMemoryStream:
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
return client
async def accept(self) -> StapledMemoryStream:
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
await self.accept_hook()
stream = await self.queued_streams[1].receive()
self.accepted_streams.append(stream)
return stream
async def aclose(self) -> None:
self.closed = True
await trio.lowlevel.checkpoint()
async def test_serve_listeners_basic() -> None:
listeners = [MemoryListener(), MemoryListener()]
record = []
def close_hook() -> None:
# Make sure this is a forceful close
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
async def handler(stream: StapledMemoryStream) -> None:
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
stream.receive_stream.close_hook = close_hook
async def client(listener: MemoryListener) -> None:
s = await listener.connect()
assert await s.receive_some(10) == b"123"
await s.send_all(b"456")
async def do_tests(parent_nursery: Nursery) -> None:
async with trio.open_nursery() as nursery:
for listener in listeners:
for _ in range(3):
nursery.start_soon(client, listener)
await wait_all_tasks_blocked()
# verifies that all 6 streams x 2 directions each were closed ok
assert len(record) == 12
parent_nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
l2: list[MemoryListener] = await nursery.start(
trio.serve_listeners,
handler,
listeners,
)
assert l2 == listeners
# This is just split into another function because gh-136 isn't
# implemented yet
nursery.start_soon(do_tests, nursery)
for listener in listeners:
assert listener.closed
async def test_serve_listeners_accept_unrecognized_error() -> None:
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
listener = MemoryListener()
async def raise_error() -> NoReturn:
raise error # noqa: B023 # Set from loop
def check_error(e: BaseException) -> bool:
return e is error # noqa: B023
listener.accept_hook = raise_error
with RaisesGroup(Matcher(check=check_error)):
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
async def test_serve_listeners_accept_capacity_error(
autojump_clock: MockClock,
caplog: pytest.LogCaptureFixture,
) -> None:
listener = MemoryListener()
async def raise_EMFILE() -> NoReturn:
raise OSError(errno.EMFILE, "out of file descriptors")
listener.accept_hook = raise_EMFILE
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
# = 10 times total
with trio.move_on_after(0.950):
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
assert len(caplog.records) == 10
for record in caplog.records:
assert "retrying" in record.msg
assert record.exc_info is not None
assert isinstance(record.exc_info[1], OSError)
assert record.exc_info[1].errno == errno.EMFILE
async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
listener = MemoryListener()
async def handler(stream: Stream) -> None:
await trio.sleep(1)
class Done(Exception):
pass
async def connection_watcher(
*,
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
async with trio.open_nursery() as nursery:
task_status.started(nursery)
await wait_all_tasks_blocked()
assert len(nursery.child_tasks) == 10
raise Done
# the exception is wrapped twice because we open two nested nurseries
with RaisesGroup(RaisesGroup(Done)):
async with trio.open_nursery() as nursery:
handler_nursery: trio.Nursery = await nursery.start(connection_watcher)
await nursery.start(
partial(
trio.serve_listeners,
handler,
[listener],
handler_nursery=handler_nursery,
),
)
for _ in range(10):
nursery.start_soon(listener.connect)
@@ -0,0 +1,330 @@
from __future__ import annotations
import errno
import socket as stdlib_socket
import sys
from typing import Sequence
import pytest
from .. import _core, socket as tsocket
from .._highlevel_socket import *
from ..testing import (
assert_checkpoints,
check_half_closeable_stream,
wait_all_tasks_blocked,
)
from .test_socket import setsockopt_tests
async def test_SocketStream_basics() -> None:
# stdlib socket bad (even if connected)
stdlib_a, stdlib_b = stdlib_socket.socketpair()
with stdlib_a, stdlib_b:
with pytest.raises(TypeError):
SocketStream(stdlib_a) # type: ignore[arg-type]
# DGRAM socket bad
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
with pytest.raises(
ValueError,
match="^SocketStream requires a SOCK_STREAM socket$",
):
# TODO: does not raise an error?
SocketStream(sock)
a, b = tsocket.socketpair()
with a, b:
s = SocketStream(a)
assert s.socket is a
# Use a real, connected socket to test socket options, because
# socketpair() might give us a unix socket that doesn't support any of
# these options
with tsocket.socket() as listen_sock:
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
with tsocket.socket() as client_sock:
await client_sock.connect(listen_sock.getsockname())
s = SocketStream(client_sock)
# TCP_NODELAY enabled by default
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
# We can disable it though
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
assert isinstance(res, bytes)
setsockopt_tests(s)
async def test_SocketStream_send_all() -> None:
BIG = 10000000
a_sock, b_sock = tsocket.socketpair()
with a_sock, b_sock:
a = SocketStream(a_sock)
b = SocketStream(b_sock)
# Check a send_all that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
async def sender() -> None:
data = bytearray(BIG)
await a.send_all(data)
# send_all uses memoryviews internally, which temporarily "lock"
# the object they view. If it doesn't clean them up properly, then
# some bytearray operations might raise an error afterwards, which
# would be a pretty weird and annoying side-effect to spring on
# users. So test that this doesn't happen, by forcing the
# bytearray's underlying buffer to be realloc'ed:
data += bytes(BIG)
# (Note: the above line of code doesn't do a very good job at
# testing anything, because:
# - on CPython, the refcount GC generally cleans up memoryviews
# for us even if we're sloppy.
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
# bytearray code conspire so that resizing never fails if
# resizing forces the bytearray's internal buffer to move, then
# all memoryview references are automagically updated (!!).
# See:
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
# But I'm leaving the test here in hopes that if this ever changes
# and we break our implementation of send_all, then we'll get some
# early warning...)
async def receiver() -> None:
# Make sure the sender fills up the kernel buffers and blocks
await wait_all_tasks_blocked()
nbytes = 0
while nbytes < BIG:
nbytes += len(await b.receive_some(BIG))
assert nbytes == BIG
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# We know that we received BIG bytes of NULs so far. Make sure that
# was all the data in there.
await a.send_all(b"e")
assert await b.receive_some(10) == b"e"
await a.send_eof()
assert await b.receive_some(10) == b""
async def fill_stream(s: SocketStream) -> None:
async def sender() -> None:
while True:
await s.send_all(b"x" * 10000)
async def waiter(nursery: _core.Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(waiter, nursery)
async def test_SocketStream_generic() -> None:
async def stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = tsocket.socketpair()
return SocketStream(left), SocketStream(right)
async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = await stream_maker()
await fill_stream(left)
await fill_stream(right)
return left, right
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
async def test_SocketListener() -> None:
# Not a Trio socket
with stdlib_socket.socket() as s:
s.bind(("127.0.0.1", 0))
s.listen(10)
with pytest.raises(TypeError):
SocketListener(s) # type: ignore[arg-type]
# Not a SOCK_STREAM
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(
ValueError,
match="^SocketListener requires a SOCK_STREAM socket$",
) as excinfo:
SocketListener(s)
excinfo.match(r".*SOCK_STREAM")
# Didn't call .listen()
# macOS has no way to check for this, so skip testing it there.
if sys.platform != "darwin":
with tsocket.socket() as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(
ValueError,
match="^SocketListener requires a listening socket$",
) as excinfo:
SocketListener(s)
excinfo.match(r".*listen")
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
assert listener.socket is listen_sock
client_sock = tsocket.socket()
await client_sock.connect(listen_sock.getsockname())
with assert_checkpoints():
server_stream = await listener.accept()
assert isinstance(server_stream, SocketStream)
assert server_stream.socket.getsockname() == listen_sock.getsockname()
assert server_stream.socket.getpeername() == client_sock.getsockname()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
client_sock.close()
await server_stream.aclose()
async def test_SocketListener_socket_closed_underfoot() -> None:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
# Close the socket, not the listener
listen_sock.close()
# SocketListener gives correct error
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
async def test_SocketListener_accept_errors() -> None:
class FakeSocket(tsocket.SocketType):
def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
self._events = iter(events)
type = tsocket.SOCK_STREAM
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
buflen: int,
) -> bytes: ...
def getsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
return True
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer,
) -> None: ...
@overload
def setsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
pass
async def accept(self) -> tuple[SocketType, object]:
await _core.checkpoint()
event = next(self._events)
if isinstance(event, BaseException):
raise event
else:
return event, None
fake_server_sock = FakeSocket([])
fake_listen_sock = FakeSocket(
[
OSError(errno.ECONNABORTED, "Connection aborted"),
OSError(errno.EPERM, "Permission denied"),
OSError(errno.EPROTO, "Bad protocol"),
fake_server_sock,
OSError(errno.EMFILE, "Out of file descriptors"),
OSError(errno.EFAULT, "attempt to write to read-only memory"),
OSError(errno.ENOBUFS, "out of buffers"),
fake_server_sock,
],
)
listener = SocketListener(fake_listen_sock)
with assert_checkpoints():
stream = await listener.accept()
assert stream.socket is fake_server_sock
for code, match in {
errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$",
errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$",
errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$",
}.items():
with assert_checkpoints():
with pytest.raises(OSError, match=match) as excinfo:
await listener.accept()
assert excinfo.value.errno == code
with assert_checkpoints():
stream = await listener.accept()
assert stream.socket is fake_server_sock
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
sock_a, sock_b = tsocket.socketpair()
with sock_a, sock_b:
await sock_b.send(b"x")
sock_b.close()
stream = SocketStream(sock_a)
assert await stream.receive_some(1) == b"x"
assert await stream.receive_some(1) == b""
@@ -0,0 +1,166 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, NoReturn
import attrs
import pytest
import trio
import trio.testing
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_listeners,
open_ssl_over_tcp_stream,
serve_ssl_over_tcp,
)
# using noqa because linters don't understand how pytest fixtures work.
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
if TYPE_CHECKING:
from socket import AddressFamily, SocketKind
from ssl import SSLContext
from trio.abc import Stream
from .._highlevel_socket import SocketListener
from .._ssl import SSLListener
async def echo_handler(stream: Stream) -> None:
async with stream:
try:
while True:
data = await stream.receive_some(10000)
if not data:
break
await stream.send_all(data)
except trio.BrokenResourceError:
pass
# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attrs.define(slots=False)
class FakeHostnameResolver(trio.abc.HostnameResolver):
sockaddr: tuple[str, int] | tuple[str, int, int, int]
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover
raise NotImplementedError
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# using noqa because linters don't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
) -> None:
async with trio.open_nursery() as nursery:
# TODO: this function wraps an SSLListener around a SocketListener, this is illegal
# according to current type hints, and probably for good reason. But there should
# maybe be a different wrapper class/function that could be used instead?
res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var]
await nursery.start(
partial(
serve_ssl_over_tcp,
echo_handler,
0,
SERVER_CTX,
host="127.0.0.1",
),
)
)
(listener,) = res
async with listener:
# listener.transport_listener is of type Listener[Stream]
tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
sockaddr = tp_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)
# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# We have the trust but not the hostname
# (checks custom ssl_context + hostname checking)
stream = await open_ssl_over_tcp_stream(
"xyzzy.example.org",
80,
ssl_context=client_ctx,
)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# This one should work!
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
)
async with stream:
assert isinstance(stream, trio.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"
# Check https_compatible settings are being passed through
assert not stream._https_compatible
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
https_compatible=True,
# also, smoke test happy_eyeballs_delay
happy_eyeballs_delay=1,
)
async with stream:
assert stream._https_compatible
# Stop the echo server
nursery.cancel_scope.cancel()
async def test_open_ssl_over_tcp_listeners() -> None:
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
async with listener:
assert isinstance(listener, trio.SSLListener)
tl = listener.transport_listener
assert isinstance(tl, trio.SocketListener)
assert tl.socket.getsockname()[0] == "127.0.0.1"
assert not listener._https_compatible
(listener,) = await open_ssl_over_tcp_listeners(
0,
SERVER_CTX,
host="127.0.0.1",
https_compatible=True,
)
async with listener:
assert listener._https_compatible
@@ -0,0 +1,275 @@
from __future__ import annotations
import os
import pathlib
from typing import TYPE_CHECKING, Type, Union
import pytest
import trio
from trio._file_io import AsyncIOWrapper
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@pytest.fixture
def path(tmp_path: pathlib.Path) -> trio.Path:
return trio.Path(tmp_path / "test")
def method_pair(
path: str,
method_name: str,
) -> tuple[Callable[[], object], Callable[[], Awaitable[object]]]:
sync_path = pathlib.Path(path)
async_path = trio.Path(path)
return getattr(sync_path, method_name), getattr(async_path, method_name)
@pytest.mark.skipif(os.name == "nt", reason="OS is not posix")
async def test_instantiate_posix() -> None:
assert isinstance(trio.Path(), trio.PosixPath)
@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows")
async def test_instantiate_windows() -> None:
assert isinstance(trio.Path(), trio.WindowsPath)
async def test_open_is_async_context_manager(path: trio.Path) -> None:
async with await path.open("w") as f:
assert isinstance(f, AsyncIOWrapper)
assert f.closed
async def test_magic() -> None:
path = trio.Path("test")
assert str(path) == "test"
assert bytes(path) == b"test"
EitherPathType = Union[Type[trio.Path], Type[pathlib.Path]]
PathOrStrType = Union[EitherPathType, Type[str]]
cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [
(trio.Path, pathlib.Path),
(pathlib.Path, trio.Path),
(trio.Path, trio.Path),
]
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs)
async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None:
a, b = cls_a(""), cls_b("")
assert a == b
assert not a != b # noqa: SIM202 # negate-not-equal-op
a, b = cls_a("a"), cls_b("b")
assert a < b
assert b > a
# this is intentionally testing equivalence with none, due to the
# other=sentinel logic in _forward_magic
assert not a == None # noqa
assert not b == None # noqa
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
# __*div__ does not properly raise NotImplementedError like the other comparison
# magic, so trio.Path's implementation does not get dispatched
cls_pairs_str: list[tuple[PathOrStrType, PathOrStrType]] = [
(trio.Path, pathlib.Path),
(trio.Path, trio.Path),
(trio.Path, str),
(str, trio.Path),
]
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str)
async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None:
a, b = cls_a("a"), cls_b("b")
result = a / b # type: ignore[operator]
# Type checkers think str / str could happen. Check each combo manually in type_tests/.
assert isinstance(result, trio.Path)
assert str(result) == os.path.join("a", "b")
@pytest.mark.parametrize(
("cls_a", "cls_b"),
[(trio.Path, pathlib.Path), (trio.Path, trio.Path)],
)
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
async def test_hash_magic(
cls_a: EitherPathType,
cls_b: EitherPathType,
path: str,
) -> None:
a, b = cls_a(path), cls_b(path)
assert hash(a) == hash(b)
async def test_forwarded_properties(path: trio.Path) -> None:
# use `name` as a representative of forwarded properties
assert "name" in dir(path)
assert path.name == "test"
async def test_async_method_signature(path: trio.Path) -> None:
# use `resolve` as a representative of wrapped methods
assert path.resolve.__name__ == "resolve"
assert path.resolve.__qualname__ == "Path.resolve"
assert path.resolve.__doc__ is not None
assert path.resolve.__qualname__ in path.resolve.__doc__
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
async def test_compare_async_stat_methods(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert result == async_result
async def test_invalid_name_not_wrapped(path: trio.Path) -> None:
with pytest.raises(AttributeError):
getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant"
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
async def test_async_methods_rewrap(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert isinstance(async_result, trio.Path)
assert str(result) == str(async_result)
async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None:
with_name = path.with_name("foo")
with_suffix = path.with_suffix(".py")
assert isinstance(with_name, trio.Path)
assert with_name == tmp_path / "foo"
assert isinstance(with_suffix, trio.Path)
assert with_suffix == tmp_path / "test.py"
async def test_forward_properties_rewrap(path: trio.Path) -> None:
assert isinstance(path.parent, trio.Path)
async def test_forward_methods_without_rewrap(path: trio.Path) -> None:
path = await path.parent.resolve()
assert path.as_uri().startswith("file:///")
async def test_repr() -> None:
path = trio.Path(".")
assert repr(path) == "trio.Path('.')"
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
async def test_path_wraps_path(
path: trio.Path,
meth: Callable[[trio.Path, trio.Path], object],
) -> None:
wrapped = await path.absolute()
result = meth(path, wrapped)
if result is None:
result = path
assert wrapped == result
async def test_path_nonpath() -> None:
with pytest.raises(TypeError):
trio.Path(1) # type: ignore
async def test_open_file_can_open_path(path: trio.Path) -> None:
async with await trio.open_file(path, "w") as f:
assert f.name == os.fspath(path)
async def test_globmethods(path: trio.Path) -> None:
# Populate a directory tree
await path.mkdir()
await (path / "foo").mkdir()
await (path / "foo" / "_bar.txt").write_bytes(b"")
await (path / "bar.txt").write_bytes(b"")
await (path / "bar.dat").write_bytes(b"")
# Path.glob
for _pattern, _results in {
"*.txt": {"bar.txt"},
"**/*.txt": {"_bar.txt", "bar.txt"},
}.items():
entries = set()
for entry in await path.glob(_pattern):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == _results
# Path.rglob
entries = set()
for entry in await path.rglob("*.txt"):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"_bar.txt", "bar.txt"}
async def test_iterdir(path: trio.Path) -> None:
# Populate a directory
await path.mkdir()
await (path / "foo").mkdir()
await (path / "bar.txt").write_bytes(b"")
entries = set()
for entry in await path.iterdir():
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"bar.txt", "foo"}
async def test_classmethods() -> None:
assert isinstance(await trio.Path.home(), trio.Path)
# pathlib.Path has only two classmethods
assert str(await trio.Path.home()) == os.path.expanduser("~")
assert str(await trio.Path.cwd()) == os.getcwd()
# Wrapped method has docstring
assert trio.Path.home.__doc__
@pytest.mark.parametrize(
"wrapper",
[
trio._path._wraps_async,
trio._path._wrap_method,
trio._path._wrap_method_path,
trio._path._wrap_method_path_iterable,
],
)
def test_wrapping_without_docstrings(
wrapper: Callable[[Callable[[], None]], Callable[[], None]],
) -> None:
@wrapper
def func_without_docstring() -> None: ... # pragma: no cover
assert func_without_docstring.__doc__ is None
@@ -0,0 +1,242 @@
from __future__ import annotations
import subprocess
import sys
from typing import Protocol
import pytest
import trio._repl
class RawInput(Protocol):
def __call__(self, prompt: str = "") -> str: ...
def build_raw_input(cmds: list[str]) -> RawInput:
"""
Pass in a list of strings.
Returns a callable that returns each string, each time its called
When there are not more strings to return, raise EOFError
"""
cmds_iter = iter(cmds)
prompts = []
def _raw_helper(prompt: str = "") -> str:
prompts.append(prompt)
try:
return next(cmds_iter)
except StopIteration:
raise EOFError from None
return _raw_helper
def test_build_raw_input() -> None:
"""Quick test of our helper function."""
raw_input = build_raw_input(["cmd1"])
assert raw_input() == "cmd1"
with pytest.raises(EOFError):
raw_input()
# In 3.10 or later, types.FunctionType (used internally) will automatically
# attach __builtins__ to the function objects. However we need to explicitly
# include it for 3.8 & 3.9
def build_locals() -> dict[str, object]:
return {"__builtins__": __builtins__}
async def test_basic_interaction(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Run some basic commands through the interpreter while capturing stdout.
Ensure that the interpreted prints the expected results.
"""
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# evaluate simple expression and recall the value
"x = 1",
"print(f'{x=}')",
# Literal gets printed
"'hello'",
# define and call sync function
"def func():",
" print(x + 1)",
"",
"func()",
# define and call async function
"async def afunc():",
" return 4",
"",
"await afunc()",
# import works
"import sys",
"sys.stdout.write('hello stdout\\n')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"raise SystemExit",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
with pytest.raises(SystemExit):
await trio._repl.run_repl(console)
async def test_KI_interrupts(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"from trio._util import signal_raise",
"import signal, trio, trio.lowlevel",
"async def f():",
" trio.lowlevel.spawn_system_task("
" trio.to_thread.run_sync,"
" signal_raise,signal.SIGINT,"
" )", # just awaiting this kills the test runner?!
" await trio.sleep_forever()",
" print('should not see this')",
"",
"await f()",
"print('AFTER KeyboardInterrupt')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "KeyboardInterrupt" in err
assert "should" not in out
assert "AFTER KeyboardInterrupt" in out
async def test_system_exits_in_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
"print('AFTER BaseExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out
async def test_system_exits_in_nested_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup(",
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
"print('AFTER BaseExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out
async def test_base_exception_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise BaseException",
"print('AFTER BaseException')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "_threads.py" not in err
assert "_repl.py" not in err
assert "AFTER BaseException" in out
async def test_exc_group_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise ExceptionGroup('', [KeyError()])",
"print('AFTER ExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "AFTER ExceptionGroup" in out
async def test_base_exception_capture_from_coroutine(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"async def async_func_raises_base_exception():",
" raise BaseException",
"",
# This will raise, but the statement after should still
# be executed
"await async_func_raises_base_exception()",
"print('AFTER BaseException')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "_threads.py" not in err
assert "_repl.py" not in err
assert "AFTER BaseException" in out
def test_main_entrypoint() -> None:
"""
Basic smoke test when running via the package __main__ entrypoint.
"""
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
assert repl.returncode == 0
@@ -0,0 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import trio
if TYPE_CHECKING:
import pytest
async def scheduler_trace() -> tuple[tuple[str, int], ...]:
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
async def tracer(name: str) -> None:
for i in range(50):
trace.append((name, i))
await trio.lowlevel.checkpoint()
async with trio.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(tracer, str(i))
return tuple(trace)
def test_the_trio_scheduler_is_not_deterministic() -> None:
# At least, not yet. See https://github.com/python-trio/trio/issues/32
traces = [trio.run(scheduler_trace) for _ in range(10)]
assert len(set(traces)) == len(traces)
def test_the_trio_scheduler_is_deterministic_if_seeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
traces = []
for _ in range(10):
state = trio._core._run._r.getstate()
try:
trio._core._run._r.seed(0)
traces.append(trio.run(scheduler_trace))
finally:
trio._core._run._r.setstate(state)
assert len(traces) == 10
assert len(set(traces)) == 1
@@ -0,0 +1,188 @@
from __future__ import annotations
import signal
from typing import TYPE_CHECKING, NoReturn
import pytest
import trio
from trio.testing import RaisesGroup
from .. import _core
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
from .._util import signal_raise
if TYPE_CHECKING:
from types import FrameType
async def test_open_signal_receiver() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL) as receiver:
# Raise it a few times, to exercise signal coalescing, both at the
# call_soon level and at the SignalQueue level
signal_raise(signal.SIGILL)
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert get_pending_signal_count(receiver) == 0
signal_raise(signal.SIGILL)
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert get_pending_signal_count(receiver) == 0
with pytest.raises(RuntimeError):
await receiver.__anext__()
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with pytest.raises(
ValueError,
match="(signal number out of range|invalid signal value)$",
):
with open_signal_receiver(signal.SIGILL, 1234567):
pass # pragma: no cover
# Still restored even if we errored out
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_empty_fail() -> None:
with pytest.raises(TypeError, match="No signals were provided"):
with open_signal_receiver():
pass
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
pass
# Still restored correctly
assert signal.getsignal(signal.SIGILL) is orig
async def test_catch_signals_wrong_thread() -> None:
async def naughty() -> None:
with open_signal_receiver(signal.SIGINT):
pass # pragma: no cover
with pytest.raises(RuntimeError):
await trio.to_thread.run_sync(trio.run, naughty)
async def test_open_signal_receiver_conflict() -> None:
with RaisesGroup(trio.BusyResourceError):
with open_signal_receiver(signal.SIGILL) as receiver:
async with trio.open_nursery() as nursery:
nursery.start_soon(receiver.__anext__)
nursery.start_soon(receiver.__anext__)
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
# processed.
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
ev = trio.Event()
token = _core.current_trio_token()
token.run_sync_soon(ev.set, idempotent=True)
await ev.wait()
async def test_open_signal_receiver_no_starvation() -> None:
# Set up a situation where there are always 2 pending signals available to
# report, and make sure that instead of getting the same signal reported
# over and over, it alternates between reporting both of them.
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
try:
print(signal.getsignal(signal.SIGILL))
previous = None
for _ in range(10):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
if previous is None:
previous = await receiver.__anext__()
else:
got = await receiver.__anext__()
assert got in [signal.SIGILL, signal.SIGFPE]
assert got != previous
previous = got
# Clear out the last signal so that it doesn't get redelivered
while get_pending_signal_count(receiver) != 0:
await receiver.__anext__()
except BaseException: # pragma: no cover
# If there's an unhandled exception above, then exiting the
# open_signal_receiver block might cause the signal to be
# redelivered and give us a core dump instead of a traceback...
import traceback
traceback.print_exc()
async def test_catch_signals_race_condition_on_exit() -> None:
delivered_directly: set[int] = set()
def direct_handler(signo: int, frame: FrameType | None) -> None:
delivered_directly.add(signo)
print(1)
# Test the version where the call_soon *doesn't* have a chance to run
# before we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
print(2)
# Test the version where the call_soon *does* have a chance to run before
# we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 2
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
# Again, but with a SIG_IGN signal:
print(3)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
# test passes if the process reaches this point without dying
print(4)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 1
# test passes if the process reaches this point without dying
# Check exception chaining if there are multiple exception-raising
# handlers
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
raise RuntimeError(signum)
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
with pytest.raises(RuntimeError) as excinfo:
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 2
exc = excinfo.value
signums = {exc.args[0]}
assert isinstance(exc.__context__, RuntimeError)
signums.add(exc.__context__.args[0])
assert signums == {signal.SIGILL, signal.SIGFPE}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,696 @@
from __future__ import annotations
import gc
import os
import random
import signal
import subprocess
import sys
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path as SyncPath
from signal import Signals
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
Callable,
NoReturn,
)
import pytest
import trio
from trio.testing import Matcher, RaisesGroup
from .. import (
Event,
Process,
_core,
fail_after,
move_on_after,
run_process,
sleep,
sleep_forever,
)
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
from ..lowlevel import open_process
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
if TYPE_CHECKING:
from types import FrameType
from typing_extensions import TypeAlias
from .._abc import ReceiveStream
if sys.platform == "win32":
SignalType: TypeAlias = None
else:
SignalType: TypeAlias = Signals
SIGKILL: SignalType
SIGTERM: SignalType
SIGUSR1: SignalType
posix = os.name == "posix"
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
from signal import SIGKILL, SIGTERM, SIGUSR1
else:
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
# Since Windows has very few command-line utilities generally available,
# all of our subprocesses are Python processes running short bits of
# (mostly) cross-platform code.
def python(code: str) -> list[str]:
return [sys.executable, "-u", "-c", "import sys; " + code]
EXIT_TRUE = python("sys.exit(0)")
EXIT_FALSE = python("sys.exit(1)")
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
if posix:
def SLEEP(seconds: int) -> list[str]:
return ["sleep", str(seconds)]
else:
def SLEEP(seconds: int) -> list[str]:
return python(f"import time; time.sleep({seconds})")
def got_signal(proc: Process, sig: SignalType) -> bool:
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
return proc.returncode == -sig
else:
return proc.returncode != 0
@asynccontextmanager # type: ignore[misc] # Any in decorator
async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
proc = await open_process(*args, **kwargs)
try:
yield proc
finally:
proc.kill()
await proc.wait()
@asynccontextmanager # type: ignore[misc] # Any in decorator
async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
async with _core.open_nursery() as nursery:
kwargs.setdefault("check", False)
proc: Process = await nursery.start(partial(run_process, *args, **kwargs))
yield proc
nursery.cancel_scope.cancel()
background_process_param = pytest.mark.parametrize(
"background_process",
[open_process_then_kill, run_process_in_nursery],
ids=["open_process", "run_process in nursery"],
)
BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]]
@background_process_param
async def test_basic(background_process: BackgroundProcessType) -> None:
async with background_process(EXIT_TRUE) as proc:
await proc.wait()
assert isinstance(proc, Process)
assert proc._pidfd is None
assert proc.returncode == 0
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
async with background_process(EXIT_FALSE) as proc:
await proc.wait()
assert proc.returncode == 1
assert repr(proc) == "<trio.Process {!r}: {}>".format(
EXIT_FALSE,
"exited with status 1",
)
@background_process_param
async def test_auto_update_returncode(
background_process: BackgroundProcessType,
) -> None:
async with background_process(SLEEP(9999)) as p:
assert p.returncode is None
assert "running" in repr(p)
p.kill()
p._proc.wait()
assert p.returncode is not None
assert "exited" in repr(p)
assert p._pidfd is None
assert p.returncode is not None
@background_process_param
async def test_multi_wait(background_process: BackgroundProcessType) -> None:
async with background_process(SLEEP(10)) as proc:
# Check that wait (including multi-wait) tolerates being cancelled
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# Now try waiting for real
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
"data = sys.stdin.buffer.read(); "
"sys.stdout.buffer.write(data); "
"sys.stderr.buffer.write(data[::-1])",
)
@background_process_param
async def test_pipes(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
msg = b"the quick brown fox jumps over the lazy dog"
async def feed_input() -> None:
assert proc.stdin is not None
await proc.stdin.send_all(msg)
await proc.stdin.aclose()
async def check_output(stream: ReceiveStream, expected: bytes) -> None:
seen = bytearray()
async for chunk in stream:
seen += chunk
assert seen == expected
assert proc.stdout is not None
assert proc.stderr is not None
async with _core.open_nursery() as nursery:
# fail eventually if something is broken
nursery.cancel_scope.deadline = _core.current_time() + 30.0
nursery.start_soon(feed_input)
nursery.start_soon(check_output, proc.stdout, msg)
nursery.start_soon(check_output, proc.stderr, msg[::-1])
assert not nursery.cancel_scope.cancelled_caught
assert await proc.wait() == 0
@background_process_param
async def test_interactive(background_process: BackgroundProcessType) -> None:
# Test some back-and-forth with a subprocess. This one works like so:
# in: 32\n
# out: 0000...0000\n (32 zeroes)
# err: 1111...1111\n (64 ones)
# in: 10\n
# out: 2222222222\n (10 twos)
# err: 3333....3333\n (20 threes)
# in: EOF
# out: EOF
# err: EOF
async with background_process(
python(
"idx = 0\n"
"while True:\n"
" line = sys.stdin.readline()\n"
" if line == '': break\n"
" request = int(line.strip())\n"
" print(str(idx * 2) * request)\n"
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
" idx += 1\n",
),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
newline = b"\n" if posix else b"\r\n"
async def expect(idx: int, request: int) -> None:
async with _core.open_nursery() as nursery:
async def drain_one(
stream: ReceiveStream,
count: int,
digit: int,
) -> None:
while count > 0:
result = await stream.receive_some(count)
assert result == (f"{digit}".encode() * len(result))
count -= len(result)
assert count == 0
assert await stream.receive_some(len(newline)) == newline
assert proc.stdout is not None
assert proc.stderr is not None
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
assert proc.stdin is not None
assert proc.stdout is not None
assert proc.stderr is not None
with fail_after(5):
await proc.stdin.send_all(b"12")
await sleep(0.1)
await proc.stdin.send_all(b"345" + newline)
await expect(0, 12345)
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
await expect(1, 100)
await expect(2, 200)
await proc.stdin.send_all(b"0" + newline)
await expect(3, 0)
await proc.stdin.send_all(b"999999")
with move_on_after(0.1) as scope:
await expect(4, 0)
assert scope.cancelled_caught
await proc.stdin.send_all(newline)
await expect(4, 999999)
await proc.stdin.aclose()
assert await proc.stdout.receive_some(1) == b""
assert await proc.stderr.receive_some(1) == b""
await proc.wait()
assert proc.returncode == 0
async def test_run() -> None:
data = bytes(random.randint(0, 255) for _ in range(2**18))
result = await run_process(
CAT,
stdin=data,
capture_stdout=True,
capture_stderr=True,
)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == b""
result = await run_process(CAT, capture_stdout=True)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == b""
assert result.stderr is None
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=data,
capture_stdout=True,
capture_stderr=True,
)
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == data[::-1]
# invalid combinations
with pytest.raises(UnicodeError):
await run_process(CAT, stdin="oh no, it's text")
pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$"
with pytest.raises(ValueError, match=pipe_stdout_error):
await run_process(CAT, stdin=subprocess.PIPE)
with pytest.raises(ValueError, match=pipe_stdout_error):
await run_process(CAT, stdout=subprocess.PIPE)
with pytest.raises(
ValueError,
match=pipe_stdout_error.replace("stdout", "stderr", 1),
):
await run_process(CAT, stderr=subprocess.PIPE)
with pytest.raises(
ValueError,
match="^can't specify both stdout and capture_stdout$",
):
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
with pytest.raises(
ValueError,
match="^can't specify both stderr and capture_stderr$",
):
await run_process(CAT, capture_stderr=True, stderr=None)
async def test_run_check() -> None:
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
with pytest.raises(subprocess.CalledProcessError) as excinfo:
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
assert excinfo.value.cmd == cmd
assert excinfo.value.returncode == 1
assert excinfo.value.stderr == b"test\n"
assert excinfo.value.stdout is None
result = await run_process(
cmd,
capture_stdout=True,
capture_stderr=True,
check=False,
)
assert result.args == cmd
assert result.stdout == b""
assert result.stderr == b"test\n"
assert result.returncode == 1
@skip_if_fbsd_pipes_broken
async def test_run_with_broken_pipe() -> None:
result = await run_process(
[sys.executable, "-c", "import sys; sys.stdin.close()"],
stdin=b"x" * 131072,
)
assert result.returncode == 0
assert result.stdout is result.stderr is None
@background_process_param
async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
assert proc.stdio is not None
assert proc.stdout is not None
assert proc.stderr is None
await proc.stdio.send_all(b"1234")
await proc.stdio.send_eof()
output = []
while True:
chunk = await proc.stdio.receive_some(16)
if chunk == b"":
break
output.append(chunk)
assert b"".join(output) == b"12344321"
assert proc.returncode == 0
# equivalent test with run_process()
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=b"1234",
capture_stdout=True,
stderr=subprocess.STDOUT,
)
assert result.returncode == 0
assert result.stdout == b"12344321"
assert result.stderr is None
# this one hits the branch where stderr=STDOUT but stdout
# is not redirected
async with background_process(
CAT,
stdin=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.aclose()
await proc.wait()
assert proc.returncode == 0
if posix:
try:
r, w = os.pipe()
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=w,
stderr=subprocess.STDOUT,
) as proc:
os.close(w)
assert proc.stdio is None
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.send_all(b"1234")
await proc.stdin.aclose()
assert await proc.wait() == 0
assert os.read(r, 4096) == b"12344321"
assert os.read(r, 4096) == b""
finally:
os.close(r)
async def test_errors() -> None:
with pytest.raises(TypeError) as excinfo:
# call-overload on unix, call-arg on windows
await open_process(["ls"], encoding="utf-8") # type: ignore
assert "unbuffered byte streams" in str(excinfo.value)
assert "the 'encoding' option is not supported" in str(excinfo.value)
if posix:
with pytest.raises(TypeError) as excinfo:
await open_process(["ls"], shell=True)
with pytest.raises(TypeError) as excinfo:
await open_process("ls", shell=False)
@background_process_param
async def test_signals(background_process: BackgroundProcessType) -> None:
async def test_one_signal(
send_it: Callable[[Process], None],
signum: signal.Signals | None,
) -> None:
with move_on_after(1.0) as scope:
async with background_process(SLEEP(3600)) as proc:
send_it(proc)
await proc.wait()
assert not scope.cancelled_caught
if posix:
assert signum is not None
assert proc.returncode == -signum
else:
assert proc.returncode != 0
await test_one_signal(Process.kill, SIGKILL)
await test_one_signal(Process.terminate, SIGTERM)
# Test that we can send arbitrary signals.
#
# We used to use SIGINT here, but it turns out that the Python interpreter
# has race conditions that can cause it to explode in weird ways if it
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
# to terminate the target process, and Python doesn't try to do anything
# clever to handle it.
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
@pytest.mark.skipif(not posix, reason="POSIX specific")
@background_process_param
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
if TYPE_CHECKING and sys.platform == "win32":
return
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
# process to exit but then fail with ECHILD. Make sure we
# support this case as the stdlib subprocess module does.
async with background_process(SLEEP(3600)) as proc:
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
nursery.cancel_scope.deadline = _core.current_time() + 1.0
assert not nursery.cancel_scope.cancelled_caught
assert proc.returncode == 0 # exit status unknowable, so...
finally:
signal.signal(signal.SIGCHLD, old_sigchld)
@slow
def test_waitid_eintr() -> None:
# This only matters on PyPy (where we're coding EINTR handling
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
return
if not wait_child_exiting.__module__.endswith("waitid"):
pytest.skip("waitid only")
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
# this import is still checked on win32&darwin and raises [attr-defined].
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
sync_wait_reapable,
)
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
def on_alarm(sig: int, frame: FrameType | None) -> None:
nonlocal got_alarm
got_alarm = True
sleeper.kill()
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
try:
signal.alarm(1)
sync_wait_reapable(sleeper.pid)
assert sleeper.wait(timeout=1) == -9
finally:
if sleeper.returncode is None: # pragma: no cover
# We only get here if something fails in the above;
# if the test passes, wait() will reap the process
sleeper.kill()
sleeper.wait()
signal.signal(signal.SIGALRM, old_sigalrm)
async def test_custom_deliver_cancel() -> None:
custom_deliver_cancel_called = False
async def custom_deliver_cancel(proc: Process) -> None:
nonlocal custom_deliver_cancel_called
custom_deliver_cancel_called = True
proc.terminate()
# Make sure this does get cancelled when the process exits, and that
# the process really exited.
try:
await sleep_forever()
finally:
assert proc.returncode is not None
async with _core.open_nursery() as nursery:
nursery.start_soon(
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert custom_deliver_cancel_called
def test_bad_deliver_cancel() -> None:
async def custom_deliver_cancel(proc: Process) -> None:
proc.terminate()
raise ValueError("foo")
async def do_stuff() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# double wrap from our nursery + the internal nursery
with RaisesGroup(RaisesGroup(Matcher(ValueError, "^foo$"))):
_core.run(do_stuff, strict_exception_groups=True)
async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None:
original_terminate = Process.terminate
def broken_terminate(self: Process) -> NoReturn:
original_terminate(self)
raise OSError("whoops")
monkeypatch.setattr(Process, "terminate", broken_terminate)
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@pytest.mark.skipif(not posix, reason="posix only")
async def test_warn_on_cancel_SIGKILL_escalation(
autojump_clock: MockClock,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(Process, "terminate", lambda *args: None)
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# the background_process_param exercises a lot of run_process cases, but it uses
# check=False, so lets have a test that uses check=True as well
async def test_run_process_background_fail() -> None:
with RaisesGroup(subprocess.CalledProcessError):
async with _core.open_nursery() as nursery:
proc: Process = await nursery.start(run_process, EXIT_FALSE)
assert proc.returncode == 1
@pytest.mark.skipif(
not SyncPath("/dev/fd").exists(),
reason="requires a way to iterate through open files",
)
async def test_for_leaking_fds() -> None:
gc.collect() # address possible flakiness on PyPy
starting_fds = set(SyncPath("/dev/fd").iterdir())
await run_process(EXIT_TRUE)
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
with pytest.raises(subprocess.CalledProcessError):
await run_process(EXIT_FALSE)
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
with pytest.raises(PermissionError):
await run_process(["/dev/fd/0"])
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None:
# There's probably less extreme ways of triggering errors inside the nursery
# in run_process.
async def very_broken_open(*args: object, **kwargs: object) -> str:
return "oops"
monkeypatch.setattr(trio._subprocess, "open_process", very_broken_open)
with RaisesGroup(AttributeError, AttributeError):
await run_process(EXIT_TRUE, capture_stdout=True)
# regression test for #2209
async def test_subprocess_pidfd_unnotified() -> None:
noticed_exit = None
async def wait_and_tell(proc: Process) -> None:
nonlocal noticed_exit
noticed_exit = Event()
await proc.wait()
noticed_exit.set()
proc = await open_process(SLEEP(9999))
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_and_tell, proc)
await wait_all_tasks_blocked()
assert isinstance(noticed_exit, Event)
proc.terminate()
# without giving trio a chance to do so,
with assert_no_checkpoints():
# wait until the process has actually exited;
proc._proc.wait()
# force a call to poll (that closes the pidfd on linux)
proc.poll()
with move_on_after(5):
# Some platforms use threads to wait for exit, so it might take a bit
# for everything to notice
await noticed_exit.wait()
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
@@ -0,0 +1,655 @@
from __future__ import annotations
import re
import weakref
from typing import TYPE_CHECKING, Callable, Union
import pytest
from trio.testing import Matcher, RaisesGroup
from .. import _core
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
from .._sync import *
from .._timeouts import sleep_forever
from ..testing import assert_checkpoints, wait_all_tasks_blocked
if TYPE_CHECKING:
from typing_extensions import TypeAlias
async def test_Event() -> None:
e = Event()
assert not e.is_set()
assert e.statistics().tasks_waiting == 0
e.set()
assert e.is_set()
with assert_checkpoints():
await e.wait()
e = Event()
record = []
async def child() -> None:
record.append("sleeping")
await e.wait()
record.append("woken")
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
nursery.start_soon(child)
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping"]
assert e.statistics().tasks_waiting == 2
e.set()
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping", "woken", "woken"]
async def test_CapacityLimiter() -> None:
with pytest.raises(TypeError):
CapacityLimiter(1.0)
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
CapacityLimiter(-1)
c = CapacityLimiter(2)
repr(c) # smoke test
assert c.total_tokens == 2
assert c.borrowed_tokens == 0
assert c.available_tokens == 2
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == 1
stats = c.statistics()
assert stats.borrowed_tokens == 1
assert stats.total_tokens == 2
assert stats.borrowers == [_core.current_task()]
assert stats.tasks_waiting == 0
# Can't re-acquire when we already have it
with pytest.raises(RuntimeError):
c.acquire_nowait()
assert c.borrowed_tokens == 1
with pytest.raises(RuntimeError):
await c.acquire()
assert c.borrowed_tokens == 1
# We can acquire on behalf of someone else though
with assert_checkpoints():
await c.acquire_on_behalf_of("someone")
# But then we've run out of capacity
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_on_behalf_of_nowait("third party")
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
# Until we release one
c.release_on_behalf_of(_core.current_task())
assert c.statistics().borrowers == ["someone"]
c.release_on_behalf_of("someone")
assert c.borrowed_tokens == 0
with assert_checkpoints():
async with c:
assert c.borrowed_tokens == 1
async with _core.open_nursery() as nursery:
await c.acquire_on_behalf_of("value 1")
await c.acquire_on_behalf_of("value 2")
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
await wait_all_tasks_blocked()
assert c.borrowed_tokens == 2
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of("value 2")
# Fairness:
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_nowait()
c.release_on_behalf_of("value 3")
c.release_on_behalf_of("value 1")
async def test_CapacityLimiter_inf() -> None:
from math import inf
c = CapacityLimiter(inf)
repr(c) # smoke test
assert c.total_tokens == inf
assert c.borrowed_tokens == 0
assert c.available_tokens == inf
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == inf
async def test_CapacityLimiter_change_total_tokens() -> None:
c = CapacityLimiter(2)
with pytest.raises(TypeError):
c.total_tokens = 1.0
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
c.total_tokens = 0
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
c.total_tokens = -10
assert c.total_tokens == 2
async with _core.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(c.acquire_on_behalf_of, i)
await wait_all_tasks_blocked()
assert set(c.statistics().borrowers) == {0, 1}
assert c.statistics().tasks_waiting == 3
c.total_tokens += 2
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
assert c.statistics().tasks_waiting == 1
c.total_tokens -= 3
assert c.borrowed_tokens == 4
assert c.total_tokens == 1
c.release_on_behalf_of(0)
c.release_on_behalf_of(1)
c.release_on_behalf_of(2)
assert set(c.statistics().borrowers) == {3}
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of(3)
assert set(c.statistics().borrowers) == {4}
assert c.statistics().tasks_waiting == 0
# regression test for issue #548
async def test_CapacityLimiter_memleak_548() -> None:
limiter = CapacityLimiter(total_tokens=1)
await limiter.acquire()
async with _core.open_nursery() as n:
n.start_soon(limiter.acquire)
await wait_all_tasks_blocked() # give it a chance to run the task
n.cancel_scope.cancel()
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
# leak memory all the while the limiter is active
assert len(limiter._pending_borrowers) == 0
async def test_Semaphore() -> None:
with pytest.raises(TypeError):
Semaphore(1.0) # type: ignore[arg-type]
with pytest.raises(ValueError, match="^initial value must be >= 0$"):
Semaphore(-1)
s = Semaphore(1)
repr(s) # smoke test
assert s.value == 1
assert s.max_value is None
s.release()
assert s.value == 2
assert s.statistics().tasks_waiting == 0
s.acquire_nowait()
assert s.value == 1
with assert_checkpoints():
await s.acquire()
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
s.release()
assert s.value == 1
with assert_checkpoints():
async with s:
assert s.value == 0
assert s.value == 1
s.acquire_nowait()
record = []
async def do_acquire(s: Semaphore) -> None:
record.append("started")
await s.acquire()
record.append("finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_acquire, s)
await wait_all_tasks_blocked()
assert record == ["started"]
assert s.value == 0
s.release()
# Fairness:
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
assert record == ["started", "finished"]
async def test_Semaphore_bounded() -> None:
with pytest.raises(TypeError):
Semaphore(1, max_value=1.0) # type: ignore[arg-type]
with pytest.raises(ValueError, match="^max_values must be >= initial_value$"):
Semaphore(2, max_value=1)
bs = Semaphore(1, max_value=1)
assert bs.max_value == 1
repr(bs) # smoke test
with pytest.raises(ValueError, match="^semaphore released too many times$"):
bs.release()
assert bs.value == 1
bs.acquire_nowait()
assert bs.value == 0
bs.release()
assert bs.value == 1
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
async def test_Lock_and_StrictFIFOLock(
lockcls: type[Lock | StrictFIFOLock],
) -> None:
l = lockcls() # noqa
assert not l.locked()
# make sure locks can be weakref'ed (gh-331)
r = weakref.ref(l)
assert r() is l
repr(l) # smoke test
# make sure repr uses the right name for subclasses
assert lockcls.__name__ in repr(l)
with assert_checkpoints():
async with l:
assert l.locked()
repr(l) # smoke test (repr branches on locked/unlocked)
assert not l.locked()
l.acquire_nowait()
assert l.locked()
l.release()
assert not l.locked()
with assert_checkpoints():
await l.acquire()
assert l.locked()
l.release()
assert not l.locked()
l.acquire_nowait()
with pytest.raises(RuntimeError):
# Error out if we already own the lock
l.acquire_nowait()
l.release()
with pytest.raises(RuntimeError):
# Error out if we don't own the lock
l.release()
holder_task = None
async def holder() -> None:
nonlocal holder_task
holder_task = _core.current_task()
async with l:
await sleep_forever()
async with _core.open_nursery() as nursery:
assert not l.locked()
nursery.start_soon(holder)
await wait_all_tasks_blocked()
assert l.locked()
# WouldBlock if someone else holds the lock
with pytest.raises(_core.WouldBlock):
l.acquire_nowait()
# Can't release a lock someone else holds
with pytest.raises(RuntimeError):
l.release()
statistics = l.statistics()
print(statistics)
assert statistics.locked
assert statistics.owner is holder_task
assert statistics.tasks_waiting == 0
nursery.start_soon(holder)
await wait_all_tasks_blocked()
statistics = l.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
nursery.cancel_scope.cancel()
statistics = l.statistics()
assert not statistics.locked
assert statistics.owner is None
assert statistics.tasks_waiting == 0
async def test_Condition() -> None:
with pytest.raises(TypeError):
Condition(Semaphore(1)) # type: ignore[arg-type]
with pytest.raises(TypeError):
Condition(StrictFIFOLock) # type: ignore[arg-type]
l = Lock() # noqa
c = Condition(l)
assert not l.locked()
assert not c.locked()
with assert_checkpoints():
await c.acquire()
assert l.locked()
assert c.locked()
c = Condition()
assert not c.locked()
c.acquire_nowait()
assert c.locked()
with pytest.raises(RuntimeError):
c.acquire_nowait()
c.release()
with pytest.raises(RuntimeError):
# Can't wait without holding the lock
await c.wait()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify_all()
finished_waiters = set()
async def waiter(i: int) -> None:
async with c:
await c.wait()
finished_waiters.add(i)
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify()
assert c.locked()
await wait_all_tasks_blocked()
assert finished_waiters == {0}
async with c:
c.notify_all()
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1, 2}
finished_waiters = set()
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify(2)
statistics = c.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
assert statistics.lock_statistics.tasks_waiting == 2
# exiting the context manager hands off the lock to the first task
assert c.statistics().lock_statistics.tasks_waiting == 1
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1}
async with c:
c.notify_all()
# After being cancelled still hold the lock (!)
# (Note that c.__aexit__ checks that we hold the lock as well)
with _core.CancelScope() as scope:
async with c:
scope.cancel()
try:
await c.wait()
finally:
assert c.locked()
from .._channel import open_memory_channel
from .._sync import AsyncContextManagerMixin
# Three ways of implementing a Lock in terms of a channel. Used to let us put
# the channel through the generic lock tests.
class ChannelLock1(AsyncContextManagerMixin):
def __init__(self, capacity: int) -> None:
self.s, self.r = open_memory_channel[None](capacity)
for _ in range(capacity - 1):
self.s.send_nowait(None)
def acquire_nowait(self) -> None:
self.s.send_nowait(None)
async def acquire(self) -> None:
await self.s.send(None)
def release(self) -> None:
self.r.receive_nowait()
class ChannelLock2(AsyncContextManagerMixin):
def __init__(self) -> None:
self.s, self.r = open_memory_channel[None](10)
self.s.send_nowait(None)
def acquire_nowait(self) -> None:
self.r.receive_nowait()
async def acquire(self) -> None:
await self.r.receive()
def release(self) -> None:
self.s.send_nowait(None)
class ChannelLock3(AsyncContextManagerMixin):
def __init__(self) -> None:
self.s, self.r = open_memory_channel[None](0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
self.acquired = False
def acquire_nowait(self) -> None:
assert not self.acquired
self.acquired = True
async def acquire(self) -> None:
if self.acquired:
await self.s.send(None)
else:
self.acquired = True
await _core.checkpoint()
def release(self) -> None:
try:
self.r.receive_nowait()
except _core.WouldBlock:
assert self.acquired
self.acquired = False
lock_factories = [
lambda: CapacityLimiter(1),
lambda: Semaphore(1),
Lock,
StrictFIFOLock,
lambda: ChannelLock1(10),
lambda: ChannelLock1(1),
ChannelLock2,
ChannelLock3,
]
lock_factory_names = [
"CapacityLimiter(1)",
"Semaphore(1)",
"Lock",
"StrictFIFOLock",
"ChannelLock1(10)",
"ChannelLock1(1)",
"ChannelLock2",
"ChannelLock3",
]
generic_lock_test = pytest.mark.parametrize(
"lock_factory",
lock_factories,
ids=lock_factory_names,
)
LockLike: TypeAlias = Union[
CapacityLimiter,
Semaphore,
Lock,
StrictFIFOLock,
ChannelLock1,
ChannelLock2,
ChannelLock3,
]
LockFactory: TypeAlias = Callable[[], LockLike]
# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
LOOPS = 10
WORKERS = 5
in_critical_section = False
acquires = 0
async def worker(lock_like: LockLike) -> None:
nonlocal in_critical_section, acquires
for _ in range(LOOPS):
async with lock_like:
acquires += 1
assert not in_critical_section
in_critical_section = True
await _core.checkpoint()
await _core.checkpoint()
assert in_critical_section
in_critical_section = False
async with _core.open_nursery() as nursery:
lock_like = lock_factory()
for _ in range(WORKERS):
nursery.start_soon(worker, lock_like)
assert not in_critical_section
assert acquires == LOOPS * WORKERS
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
initial_order = []
record = []
LOOPS = 5
async def loopy(name: int, lock_like: LockLike) -> None:
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
async with lock_like:
record.append(name)
lock_like = lock_factory()
async with _core.open_nursery() as nursery:
nursery.start_soon(loopy, 1, lock_like)
nursery.start_soon(loopy, 2, lock_like)
nursery.start_soon(loopy, 3, lock_like)
# The first three could be in any order due to scheduling randomness,
# but after that they should repeat in the same order
for i in range(LOOPS):
assert record[3 * i : 3 * (i + 1)] == initial_order
@generic_lock_test
async def test_generic_lock_acquire_nowait_blocks_acquire(
lock_factory: LockFactory,
) -> None:
lock_like = lock_factory()
record = []
async def lock_taker() -> None:
record.append("started")
async with lock_like:
pass
record.append("finished")
async with _core.open_nursery() as nursery:
lock_like.acquire_nowait()
nursery.start_soon(lock_taker)
await wait_all_tasks_blocked()
assert record == ["started"]
lock_like.release()
async def test_lock_acquire_unowned_lock() -> None:
"""Test that trying to acquire a lock whose owner has exited raises an error.
see https://github.com/python-trio/trio/issues/3035
"""
assert not GLOBAL_PARKING_LOT_BREAKER
lock = trio.Lock()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
owner_str = re.escape(str(lock._lot.broken_by[0]))
with pytest.raises(
trio.BrokenResourceError,
match=f"^Owner of this lock exited without releasing: {owner_str}$",
):
await lock.acquire()
assert not GLOBAL_PARKING_LOT_BREAKER
async def test_lock_multiple_acquire() -> None:
"""Test for error if awaiting on a lock whose owner exits without releasing.
see https://github.com/python-trio/trio/issues/3035"""
assert not GLOBAL_PARKING_LOT_BREAKER
lock = trio.Lock()
with RaisesGroup(
Matcher(
trio.BrokenResourceError,
match="^Owner of this lock exited without releasing: ",
),
):
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
nursery.start_soon(lock.acquire)
assert not GLOBAL_PARKING_LOT_BREAKER
async def test_lock_handover() -> None:
assert not GLOBAL_PARKING_LOT_BREAKER
child_task: Task | None = None
lock = trio.Lock()
# this task acquires the lock
lock.acquire_nowait()
assert GLOBAL_PARKING_LOT_BREAKER == {
_core.current_task(): [
lock._lot,
],
}
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
await wait_all_tasks_blocked()
# hand over the lock to the child task
lock.release()
# check values, and get the identifier out of the dict for later check
assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]
assert lock._lot.broken_by == [child_task]
assert not GLOBAL_PARKING_LOT_BREAKER
@@ -0,0 +1,684 @@
from __future__ import annotations
# XX this should get broken up, like testing.py did
import tempfile
from typing import TYPE_CHECKING
import pytest
from trio.testing import RaisesGroup
from .. import _core, sleep, socket as tsocket
from .._core._tests.tutil import can_bind_ipv6
from .._highlevel_generic import StapledStream, aclose_forcefully
from .._highlevel_socket import SocketListener
from ..testing import *
from ..testing._check_streams import _assert_raises
from ..testing._memory_streams import _UnboundedByteQueue
if TYPE_CHECKING:
from trio import Nursery
from trio.abc import ReceiveStream, SendStream
async def test_wait_all_tasks_blocked() -> None:
record = []
async def busy_bee() -> None:
for _ in range(10):
await _core.checkpoint()
record.append("busy bee exhausted")
async def waiting_for_bee_to_leave() -> None:
await wait_all_tasks_blocked()
record.append("quiet at last!")
async with _core.open_nursery() as nursery:
nursery.start_soon(busy_bee)
nursery.start_soon(waiting_for_bee_to_leave)
nursery.start_soon(waiting_for_bee_to_leave)
# check cancellation
record = []
async def cancelled_while_waiting() -> None:
try:
await wait_all_tasks_blocked()
except _core.Cancelled:
record.append("ok")
async with _core.open_nursery() as nursery:
nursery.start_soon(cancelled_while_waiting)
nursery.cancel_scope.cancel()
assert record == ["ok"]
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
record = []
async def timeout_task() -> None:
record.append("tt start")
await sleep(5)
record.append("tt finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(timeout_task)
await wait_all_tasks_blocked()
assert record == ["tt start"]
mock_clock.jump(10)
await wait_all_tasks_blocked()
assert record == ["tt start", "tt finished"]
async def test_wait_all_tasks_blocked_with_cushion() -> None:
record = []
async def blink() -> None:
record.append("blink start")
await sleep(0.01)
await sleep(0.01)
await sleep(0.01)
record.append("blink end")
async def wait_no_cushion() -> None:
await wait_all_tasks_blocked()
record.append("wait_no_cushion end")
async def wait_small_cushion() -> None:
await wait_all_tasks_blocked(0.02)
record.append("wait_small_cushion end")
async def wait_big_cushion() -> None:
await wait_all_tasks_blocked(0.03)
record.append("wait_big_cushion end")
async with _core.open_nursery() as nursery:
nursery.start_soon(blink)
nursery.start_soon(wait_no_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_big_cushion)
assert record == [
"blink start",
"wait_no_cushion end",
"blink end",
"wait_small_cushion end",
"wait_small_cushion end",
"wait_big_cushion end",
]
################################################################
async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
with assert_checkpoints():
await _core.checkpoint()
with pytest.raises(AssertionError):
with assert_checkpoints():
1 + 1 # noqa: B018 # "useless expression"
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that's not a checkpoint.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_checkpoints():
await partial_yield()
# But both together count as a checkpoint
with assert_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
with assert_no_checkpoints():
1 + 1 # noqa: B018 # "useless expression"
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint()
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that doesn't make *either* version of assert_{no_,}yields happy.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await partial_yield()
# And both together also count as a checkpoint
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
################################################################
async def test_Sequencer() -> None:
record = []
def t(val: object) -> None:
print(val)
record.append(val)
async def f1(seq: Sequencer) -> None:
async with seq(1):
t(("f1", 1))
async with seq(3):
t(("f1", 3))
async with seq(4):
t(("f1", 4))
async def f2(seq: Sequencer) -> None:
async with seq(0):
t(("f2", 0))
async with seq(2):
t(("f2", 2))
seq = Sequencer()
async with _core.open_nursery() as nursery:
nursery.start_soon(f1, seq)
nursery.start_soon(f2, seq)
async with seq(5):
await wait_all_tasks_blocked()
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
seq = Sequencer()
# Catches us if we try to reuse a sequence point:
async with seq(0):
pass
with pytest.raises(RuntimeError):
async with seq(0):
pass # pragma: no cover
async def test_Sequencer_cancel() -> None:
# Killing a blocked task makes everything blow up
record = []
seq = Sequencer()
async def child(i: int) -> None:
with _core.CancelScope() as scope:
if i == 1:
scope.cancel()
try:
async with seq(i):
pass # pragma: no cover
except RuntimeError:
record.append(f"seq({i}) RuntimeError")
async with _core.open_nursery() as nursery:
nursery.start_soon(child, 1)
nursery.start_soon(child, 2)
async with seq(0):
pass # pragma: no cover
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
# Late arrivals also get errors
with pytest.raises(RuntimeError):
async with seq(3):
pass # pragma: no cover
################################################################
async def test__assert_raises() -> None:
with pytest.raises(AssertionError):
with _assert_raises(RuntimeError):
1 + 1 # noqa: B018 # "useless expression"
with pytest.raises(TypeError):
with _assert_raises(RuntimeError):
"foo" + 1 # type: ignore[operator] # noqa: B018 # "useless expression"
with _assert_raises(RuntimeError):
raise RuntimeError
# This is a private implementation detail, but it's complex enough to be worth
# testing directly
async def test__UnboundeByteQueue() -> None:
ubq = _UnboundedByteQueue()
ubq.put(b"123")
ubq.put(b"456")
assert ubq.get_nowait(1) == b"1"
assert ubq.get_nowait(10) == b"23456"
ubq.put(b"789")
assert ubq.get_nowait() == b"789"
with pytest.raises(_core.WouldBlock):
ubq.get_nowait(10)
with pytest.raises(_core.WouldBlock):
ubq.get_nowait()
with pytest.raises(TypeError):
ubq.put("string") # type: ignore[arg-type]
ubq.put(b"abc")
with assert_checkpoints():
assert await ubq.get(10) == b"abc"
ubq.put(b"def")
ubq.put(b"ghi")
with assert_checkpoints():
assert await ubq.get(1) == b"d"
with assert_checkpoints():
assert await ubq.get() == b"efghi"
async def putter(data: bytes) -> None:
await wait_all_tasks_blocked()
ubq.put(data)
async def getter(expect: bytes) -> None:
with assert_checkpoints():
assert await ubq.get() == expect
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"xyz")
nursery.start_soon(putter, b"xyz")
# Two gets at the same time -> BusyResourceError
with RaisesGroup(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"asdf")
nursery.start_soon(getter, b"asdf")
# Closing
ubq.close()
with pytest.raises(_core.ClosedResourceError):
ubq.put(b"---")
assert ubq.get_nowait(10) == b""
assert ubq.get_nowait() == b""
assert await ubq.get(10) == b""
assert await ubq.get() == b""
# close is idempotent
ubq.close()
# close wakes up blocked getters
ubq2 = _UnboundedByteQueue()
async def closer() -> None:
await wait_all_tasks_blocked()
ubq2.close()
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"")
nursery.start_soon(closer)
async def test_MemorySendStream() -> None:
mss = MemorySendStream()
async def do_send_all(data: bytes) -> None:
with assert_checkpoints():
await mss.send_all(data)
await do_send_all(b"123")
assert mss.get_data_nowait(1) == b"1"
assert mss.get_data_nowait() == b"23"
with assert_checkpoints():
await mss.wait_send_all_might_not_block()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait(10)
await do_send_all(b"456")
with assert_checkpoints():
assert await mss.get_data() == b"456"
# Call send_all twice at once; one should get BusyResourceError and one
# should succeed. But we can't let the error propagate, because it might
# cause the other to be cancelled before it can finish doing its thing,
# and we don't know which one will get the error.
resource_busy_count = 0
async def do_send_all_count_resourcebusy() -> None:
nonlocal resource_busy_count
try:
await do_send_all(b"xxx")
except _core.BusyResourceError:
resource_busy_count += 1
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all_count_resourcebusy)
nursery.start_soon(do_send_all_count_resourcebusy)
assert resource_busy_count == 1
with assert_checkpoints():
await mss.aclose()
assert await mss.get_data() == b"xxx"
assert await mss.get_data() == b""
with pytest.raises(_core.ClosedResourceError):
await do_send_all(b"---")
# hooks
assert mss.send_all_hook is None
assert mss.wait_send_all_might_not_block_hook is None
assert mss.close_hook is None
record = []
async def send_all_hook() -> None:
# hook runs after send_all does its work (can pull data out)
assert mss2.get_data_nowait() == b"abc"
record.append("send_all_hook")
async def wait_send_all_might_not_block_hook() -> None:
record.append("wait_send_all_might_not_block_hook")
def close_hook() -> None:
record.append("close_hook")
mss2 = MemorySendStream(
send_all_hook,
wait_send_all_might_not_block_hook,
close_hook,
)
assert mss2.send_all_hook is send_all_hook
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
assert mss2.close_hook is close_hook
await mss2.send_all(b"abc")
await mss2.wait_send_all_might_not_block()
await aclose_forcefully(mss2)
mss2.close()
assert record == [
"send_all_hook",
"wait_send_all_might_not_block_hook",
"close_hook",
"close_hook",
]
async def test_MemoryReceiveStream() -> None:
mrs = MemoryReceiveStream()
async def do_receive_some(max_bytes: int | None) -> bytes:
with assert_checkpoints():
return await mrs.receive_some(max_bytes)
mrs.put_data(b"abc")
assert await do_receive_some(1) == b"a"
assert await do_receive_some(10) == b"bc"
mrs.put_data(b"abc")
assert await do_receive_some(None) == b"abc"
with RaisesGroup(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 10)
nursery.start_soon(do_receive_some, 10)
assert mrs.receive_some_hook is None
mrs.put_data(b"def")
mrs.put_eof()
mrs.put_eof()
assert await do_receive_some(10) == b"def"
assert await do_receive_some(10) == b""
assert await do_receive_some(10) == b""
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"---")
async def receive_some_hook() -> None:
mrs2.put_data(b"xxx")
record = []
def close_hook() -> None:
record.append("closed")
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
assert mrs2.receive_some_hook is receive_some_hook
assert mrs2.close_hook is close_hook
mrs2.put_data(b"yyy")
assert await mrs2.receive_some(10) == b"yyyxxx"
assert await mrs2.receive_some(10) == b"xxx"
assert await mrs2.receive_some(10) == b"xxx"
mrs2.put_data(b"zzz")
mrs2.receive_some_hook = None
assert await mrs2.receive_some(10) == b"zzz"
mrs2.put_data(b"lost on close")
with assert_checkpoints():
await mrs2.aclose()
assert record == ["closed"]
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_MemoryRecvStream_closing() -> None:
mrs = MemoryReceiveStream()
# close with no pending data
mrs.close()
with pytest.raises(_core.ClosedResourceError):
assert await mrs.receive_some(10) == b""
# repeated closes ok
mrs.close()
# put_data now fails
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"123")
mrs2 = MemoryReceiveStream()
# close with pending data
mrs2.put_data(b"xyz")
mrs2.close()
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_memory_stream_pump() -> None:
mss = MemorySendStream()
mrs = MemoryReceiveStream()
# no-op if no data present
memory_stream_pump(mss, mrs)
await mss.send_all(b"123")
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b"123"
await mss.send_all(b"456")
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"4"
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert not memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"56"
mss.close()
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b""
async def test_memory_stream_one_way_pair() -> None:
s, r = memory_stream_one_way_pair()
assert s.send_all_hook is not None
assert s.wait_send_all_might_not_block_hook is None
assert s.close_hook is not None
assert r.receive_some_hook is None
await s.send_all(b"123")
assert await r.receive_some(10) == b"123"
async def receiver(expected: bytes) -> None:
assert await r.receive_some(10) == expected
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"abc")
await wait_all_tasks_blocked()
await s.send_all(b"abc")
# And this fails if we don't pump from close_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
await s.aclose()
s, r = memory_stream_one_way_pair()
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
s.close()
s, r = memory_stream_one_way_pair()
old = s.send_all_hook
s.send_all_hook = None
await s.send_all(b"456")
async def cancel_after_idle(nursery: Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async def check_for_cancel() -> None:
with pytest.raises(_core.Cancelled):
# This should block forever... or until cancelled. Even though we
# sent some data on the send stream.
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(cancel_after_idle, nursery)
nursery.start_soon(check_for_cancel)
s.send_all_hook = old
await s.send_all(b"789")
assert await r.receive_some(10) == b"456789"
async def test_memory_stream_pair() -> None:
a, b = memory_stream_pair()
await a.send_all(b"123")
await b.send_all(b"abc")
assert await b.receive_some(10) == b"123"
assert await a.receive_some(10) == b"abc"
await a.send_eof()
assert await b.receive_some(10) == b""
async def sender() -> None:
await wait_all_tasks_blocked()
await b.send_all(b"xyz")
async def receiver() -> None:
assert await a.receive_some(10) == b"xyz"
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver)
nursery.start_soon(sender)
async def test_memory_streams_with_generic_tests() -> None:
async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
return memory_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, None)
async def half_closeable_stream_maker() -> tuple[
StapledStream[MemorySendStream, MemoryReceiveStream],
StapledStream[MemorySendStream, MemoryReceiveStream],
]:
return memory_stream_pair()
await check_half_closeable_stream(half_closeable_stream_maker, None)
async def test_lockstep_streams_with_generic_tests() -> None:
async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
return lockstep_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
async def two_way_stream_maker() -> tuple[
StapledStream[SendStream, ReceiveStream],
StapledStream[SendStream, ReceiveStream],
]:
return lockstep_stream_pair()
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
async def test_open_stream_to_socket_listener() -> None:
async def check(listener: SocketListener) -> None:
async with listener:
client_stream = await open_stream_to_socket_listener(listener)
async with client_stream:
server_stream = await listener.accept()
async with server_stream:
await client_stream.send_all(b"x")
assert await server_stream.receive_some(1) == b"x"
# Listener bound to localhost
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
sock.listen(10)
await check(SocketListener(sock))
# Listener bound to IPv4 wildcard (needs special handling)
sock = tsocket.socket()
await sock.bind(("0.0.0.0", 0))
sock.listen(10)
await check(SocketListener(sock))
# true on all CI systems
if can_bind_ipv6: # pragma: no branch
# Listener bound to IPv6 wildcard (needs special handling)
sock = tsocket.socket(family=tsocket.AF_INET6)
await sock.bind(("::", 0))
sock.listen(10)
await check(SocketListener(sock))
if hasattr(tsocket, "AF_UNIX"):
# Listener bound to Unix-domain socket
sock = tsocket.socket(family=tsocket.AF_UNIX)
# can't use pytest's tmpdir; if we try then macOS says "OSError:
# AF_UNIX path too long"
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/sock"
await sock.bind(path)
sock.listen(10)
await check(SocketListener(sock))
def test_trio_test() -> None:
async def busy_kitchen(
*,
mock_clock: object,
autojump_clock: object,
) -> None: ... # pragma: no cover
with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"):
trio_test(busy_kitchen)(
mock_clock=MockClock(),
autojump_clock=MockClock(autojump_threshold=0),
)
@@ -0,0 +1,374 @@
from __future__ import annotations
import re
import sys
from types import TracebackType
from typing import Any
import pytest
import trio
from trio.testing import Matcher, RaisesGroup
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def wrap_escape(s: str) -> str:
return "^" + re.escape(s) + "$"
def test_raises_group() -> None:
with pytest.raises(
ValueError,
match=wrap_escape(
f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.',
),
):
RaisesGroup(TypeError())
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (ValueError(),))
with RaisesGroup(SyntaxError):
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (SyntaxError(),))
# multiple exceptions
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# order doesn't matter
with RaisesGroup(SyntaxError, ValueError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# nested exceptions
with RaisesGroup(RaisesGroup(ValueError)):
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
with RaisesGroup(
SyntaxError,
RaisesGroup(ValueError),
RaisesGroup(RuntimeError),
):
raise ExceptionGroup(
"foo",
(
SyntaxError(),
ExceptionGroup("bar", (ValueError(),)),
ExceptionGroup("", (RuntimeError(),)),
),
)
# will error if there's excess exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ValueError(), ValueError()))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (RuntimeError(), ValueError()))
# will error if there's missing exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, ValueError):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("", (ValueError(),))
def test_flatten_subgroups() -> None:
# loose semantics, as with expect*
with RaisesGroup(ValueError, flatten_subgroups=True):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),))
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()])
# mixed loose is possible if you want it to be at least N deep
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup(
"",
(ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),),
)
with pytest.raises(ExceptionGroup):
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup("", (ValueError(),))
# but not the other way around
with pytest.raises(
ValueError,
match="^You cannot specify a nested structure inside a RaisesGroup with",
):
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload]
def test_catch_unwrapped_exceptions() -> None:
# Catches lone exceptions with strict=False
# just as except* would
with RaisesGroup(ValueError, allow_unwrapped=True):
raise ValueError
# expecting multiple unwrapped exceptions is not possible
with pytest.raises(
ValueError,
match="^You cannot specify multiple exceptions with",
):
RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload]
# if users want one of several exception types they need to use a Matcher
# (which the error message suggests)
with RaisesGroup(
Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))),
allow_unwrapped=True,
):
raise ValueError
# Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error.
with pytest.raises(ValueError, match="has no effect when expecting"):
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload]
# But it *can* be used to check for nesting level +- 1 if they move it to
# the nested RaisesGroup. Users should probably use `Matcher`s instead though.
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
raise ExceptionGroup("", [ValueError()])
# with allow_unwrapped=False (default) it will not be caught
with pytest.raises(ValueError, match="^value error text$"):
with RaisesGroup(ValueError):
raise ValueError("value error text")
# allow_unwrapped on it's own won't match against nested groups
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, allow_unwrapped=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
# for that you need both allow_unwrapped and flatten_subgroups
with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
# code coverage
with pytest.raises(TypeError):
with RaisesGroup(ValueError, allow_unwrapped=True):
raise TypeError
def test_match() -> None:
# supports match string
with RaisesGroup(ValueError, match="bar"):
raise ExceptionGroup("bar", (ValueError(),))
# now also works with ^$
with RaisesGroup(ValueError, match="^bar$"):
raise ExceptionGroup("bar", (ValueError(),))
# it also includes notes
with RaisesGroup(ValueError, match="my note"):
e = ExceptionGroup("bar", (ValueError(),))
e.add_note("my note")
raise e
# and technically you can match it all with ^$
# but you're probably better off using a Matcher at that point
with RaisesGroup(ValueError, match="^bar\nmy note$"):
e = ExceptionGroup("bar", (ValueError(),))
e.add_note("my note")
raise e
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, match="foo"):
raise ExceptionGroup("bar", (ValueError(),))
def test_check() -> None:
exc = ExceptionGroup("", (ValueError(),))
with RaisesGroup(ValueError, check=lambda x: x is exc):
raise exc
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, check=lambda x: x is exc):
raise ExceptionGroup("", (ValueError(),))
def test_unwrapped_match_check() -> None:
def my_check(e: object) -> bool: # pragma: no cover
return True
msg = (
"`allow_unwrapped=True` bypasses the `match` and `check` parameters"
" if the exception is unwrapped. If you intended to match/check the"
" exception you should use a `Matcher` object. If you want to match/check"
" the exceptiongroup when the exception *is* wrapped you need to"
" do e.g. `if isinstance(exc.value, ExceptionGroup):"
" assert RaisesGroup(...).matches(exc.value)` afterwards."
)
with pytest.raises(ValueError, match=re.escape(msg)):
RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload]
with pytest.raises(ValueError, match=re.escape(msg)):
RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload]
# Users should instead use a Matcher
rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True)
with rg:
raise ValueError("foo")
with rg:
raise ExceptionGroup("", [ValueError("foo")])
# or if they wanted to match/check the group, do a conditional `.matches()`
with RaisesGroup(ValueError, allow_unwrapped=True) as exc:
raise ExceptionGroup("bar", [ValueError("foo")])
if isinstance(exc.value, ExceptionGroup): # pragma: no branch
assert RaisesGroup(ValueError, match="bar").matches(exc.value)
def test_RaisesGroup_matches() -> None:
rg = RaisesGroup(ValueError)
assert not rg.matches(None)
assert not rg.matches(ValueError())
assert rg.matches(ExceptionGroup("", (ValueError(),)))
def test_message() -> None:
def check_message(message: str, body: RaisesGroup[Any]) -> None:
with pytest.raises(
AssertionError,
match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$",
):
with body:
...
# basic
check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError))
# multiple exceptions
check_message(
"ExceptionGroup(ValueError, ValueError)",
RaisesGroup(ValueError, ValueError),
)
# nested
check_message(
"ExceptionGroup(ExceptionGroup(ValueError))",
RaisesGroup(RaisesGroup(ValueError)),
)
# Matcher
check_message(
"ExceptionGroup(Matcher(ValueError, match='my_str'))",
RaisesGroup(Matcher(ValueError, "my_str")),
)
check_message(
"ExceptionGroup(Matcher(match='my_str'))",
RaisesGroup(Matcher(match="my_str")),
)
# BaseExceptionGroup
check_message(
"BaseExceptionGroup(KeyboardInterrupt)",
RaisesGroup(KeyboardInterrupt),
)
# BaseExceptionGroup with type inside Matcher
check_message(
"BaseExceptionGroup(Matcher(KeyboardInterrupt))",
RaisesGroup(Matcher(KeyboardInterrupt)),
)
# Base-ness transfers to parent containers
check_message(
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))",
RaisesGroup(RaisesGroup(KeyboardInterrupt)),
)
# but not to child containers
check_message(
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))",
RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)),
)
def test_matcher() -> None:
with pytest.raises(
ValueError,
match="^You must specify at least one parameter to match on.$",
):
Matcher() # type: ignore[call-overload]
with pytest.raises(
ValueError,
match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$",
):
Matcher(object) # type: ignore[type-var]
with RaisesGroup(Matcher(ValueError)):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(TypeError)):
raise ExceptionGroup("", (ValueError(),))
def test_matcher_match() -> None:
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("foo"),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("bar"),))
# Can be used without specifying the type
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("foo"),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("bar"),))
# check ^$
with RaisesGroup(Matcher(ValueError, match="^bar$")):
raise ExceptionGroup("", [ValueError("bar")])
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(ValueError, match="^bar$")):
raise ExceptionGroup("", [ValueError("barr")])
def test_Matcher_check() -> None:
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
return isinstance(e, OSError) and e.errno == 5
with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
# specifying exception_type narrows the parameter type to the callable
def check_errno_is_5(e: OSError) -> bool:
return e.errno == 5
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(6, ""),))
def test_matcher_tostring() -> None:
assert str(Matcher(ValueError)) == "Matcher(ValueError)"
assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')"
pattern_no_flags = re.compile("noflag", 0)
assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')"
pattern_flags = re.compile("noflag", re.IGNORECASE)
assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})"
assert (
str(Matcher(ValueError, match="re", check=bool))
== f"Matcher(ValueError, match='re', check={bool!r})"
)
def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
trio.testing._raises_group,
"ExceptionInfo",
trio.testing._raises_group._ExceptionInfo,
)
with trio.testing.RaisesGroup(ValueError) as excinfo:
raise ExceptionGroup("", (ValueError("hello"),))
assert excinfo.type is ExceptionGroup
assert excinfo.value.exceptions[0].args == ("hello",)
assert isinstance(excinfo.tb, TracebackType)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,272 @@
import time
from typing import Awaitable, Callable, Protocol, TypeVar
import outcome
import pytest
import trio
from .. import _core
from .._core._tests.tutil import slow
from .._timeouts import (
TooSlowError,
fail_after,
fail_at,
move_on_after,
move_on_at,
sleep,
sleep_forever,
sleep_until,
)
from ..testing import assert_checkpoints
T = TypeVar("T")
async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
start = time.perf_counter()
result = await outcome.acapture(f)
dur = time.perf_counter() - start
print(dur / expected_dur)
# 1.5 is an arbitrary fudge factor because there's always some delay
# between when we become eligible to wake up and when we actually do. We
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
# now we bumped up the sleep to 1 second, marked the tests as slow, and
# hopefully now the proportional error will be less huge.
#
# We also also for durations that are a hair shorter than expected. For
# example, here's a run on Windows where a 1.0 second sleep was measured
# to take 0.9999999999999858 seconds:
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
# I believe that what happened here is that Windows's low clock resolution
# meant that our calls to time.monotonic() returned exactly the same
# values as the calls inside the actual run loop, but the two subtractions
# returned slightly different values because the run loop's clock adds a
# random floating point offset to both times, which should cancel out, but
# lol floating point we got slightly different rounding errors. (That
# value above is exactly 128 ULPs below 1.0, which would make sense if it
# started as a 1 ULP error at a different dynamic range.)
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
return result.unwrap()
# How long to (attempt to) sleep for when testing. Smaller numbers make the
# test suite go faster.
TARGET = 1.0
@slow
async def test_sleep() -> None:
async def sleep_1() -> None:
await sleep_until(_core.current_time() + TARGET)
await check_takes_about(sleep_1, TARGET)
async def sleep_2() -> None:
await sleep(TARGET)
await check_takes_about(sleep_2, TARGET)
with assert_checkpoints():
await sleep(0)
# This also serves as a test of the trivial move_on_at
with move_on_at(_core.current_time()):
with pytest.raises(_core.Cancelled):
await sleep(0)
@slow
async def test_move_on_after() -> None:
async def sleep_3() -> None:
with move_on_after(TARGET):
await sleep(100)
await check_takes_about(sleep_3, TARGET)
async def test_cannot_wake_sleep_forever() -> None:
# Test an error occurs if you manually wake sleep_forever().
task = trio.lowlevel.current_task()
async def wake_task() -> None:
await trio.lowlevel.checkpoint()
trio.lowlevel.reschedule(task, outcome.Value(None))
async with trio.open_nursery() as nursery:
nursery.start_soon(wake_task)
with pytest.raises(RuntimeError):
await trio.sleep_forever()
class TimeoutScope(Protocol):
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...
@pytest.mark.parametrize("scope", [move_on_after, fail_after])
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
outer.cancel()
try:
await trio.lowlevel.checkpoint()
except trio.Cancelled:
pytest.fail("shield didn't work")
inner.shield = False
with pytest.raises(trio.Cancelled):
await trio.lowlevel.checkpoint()
@slow
async def test_move_on_after_moves_on_even_if_shielded() -> None:
async def task() -> None:
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()
await check_takes_about(task, TARGET)
@slow
async def test_fail_after_fails_even_if_shielded() -> None:
async def task() -> None:
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
TARGET,
shield=True,
):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()
await check_takes_about(task, TARGET)
@slow
async def test_fail() -> None:
async def sleep_4() -> None:
with fail_at(_core.current_time() + TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_4, TARGET)
with fail_at(_core.current_time() + 100):
await sleep(0)
async def sleep_5() -> None:
with fail_after(TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_5, TARGET)
with fail_after(100):
await sleep(0)
async def test_timeouts_raise_value_error() -> None:
# deadlines are allowed to be negative, but not delays.
# neither delays nor deadlines are allowed to be NaN
nan = float("nan")
for fun, val in (
(sleep, -1),
(sleep, nan),
(sleep_until, nan),
):
with pytest.raises(
ValueError,
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
):
await fun(val)
for cm, val in (
(fail_after, -1),
(fail_after, nan),
(fail_at, nan),
(move_on_after, -1),
(move_on_after, nan),
(move_on_at, nan),
):
with pytest.raises(
ValueError,
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
):
with cm(val):
pass # pragma: no cover
async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None:
rcs = move_on_after(5)
assert rcs.relative_deadline == 5
mock_clock.jump(3)
start = _core.current_time()
with rcs as cs:
assert cs.is_relative is None
# This would previously be start+2
assert cs.deadline == start + 5
assert cs.relative_deadline == 5
cs.deadline = start + 3
assert cs.deadline == start + 3
assert cs.relative_deadline == 3
cs.relative_deadline = 4
assert cs.deadline == start + 4
assert cs.relative_deadline == 4
rcs = move_on_after(5)
assert rcs.shield is False
rcs.shield = True
assert rcs.shield is True
mock_clock.jump(3)
start = _core.current_time()
with rcs as cs:
assert cs.deadline == start + 5
assert rcs is cs
async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None:
cs = move_on_after(5)
mock_clock.jump(3)
start = _core.current_time()
match_str = "^unentered relative cancel scope does not have an absolute deadline"
with pytest.warns(DeprecationWarning, match=match_str):
assert cs.deadline == start + 5
mock_clock.jump(1)
# this is hella sketchy, but they *have* been warned
with pytest.warns(DeprecationWarning, match=match_str):
assert cs.deadline == start + 6
with pytest.warns(DeprecationWarning, match=match_str):
cs.deadline = 7
# now transformed into absolute
assert cs.deadline == 7
assert not cs.is_relative
cs = move_on_at(5)
match_str = (
"^unentered non-relative cancel scope does not have a relative deadline$"
)
with pytest.raises(RuntimeError, match=match_str):
assert cs.relative_deadline
with pytest.raises(RuntimeError, match=match_str):
cs.relative_deadline = 7
@pytest.mark.xfail(reason="not implemented")
async def test_fail_access_before_entering() -> None: # pragma: no cover
my_fail_at = fail_at(5)
assert my_fail_at.deadline # type: ignore[attr-defined]
my_fail_after = fail_after(5)
assert my_fail_after.relative_deadline # type: ignore[attr-defined]

Some files were not shown because too many files have changed in this diff Show More