"""
The parsing Context used internally for tracking parsed arguments and configuration overrides.
:author: Doug Skrypa
"""
# pylint: disable=R0801
from __future__ import annotations
import sys
from collections import defaultdict
from contextlib import AbstractContextManager
from contextvars import ContextVar
from enum import Enum
from functools import cached_property
from inspect import Parameter as _Parameter, Signature
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Optional, Sequence, Union, cast
from .config import DEFAULT_CONFIG, CommandConfig
from .error_handling import ErrorHandler, NullErrorHandler, extended_error_handler
from .exceptions import NoActiveContext
from .utils import Terminal, _NotSet
if TYPE_CHECKING:
from .command_parameters import CommandParameters
from .commands import Command
from .parameters import ActionFlag, Option, Parameter
from .typing import AnyConfig, Bool, CommandObj, CommandType, OptStr, ParamOrGroup, PathLike, StrSeq # noqa
__all__ = [
    'Context',
    'ctx',
    'get_current_context',
    'get_or_create_context',
    'get_context',
    'get_parsed',
    'get_raw_arg',
]

# Stack of entered Context objects (pushed by Context.__enter__, popped by Context.__exit__).  No default value is
# provided, so the stack is created lazily the first time a Context is entered.
_context_stack: ContextVar[list[Context]] = ContextVar('cli_command_parser.context.stack')
# Shared Terminal instance used as the fallback source of the terminal width
_TERMINAL = Terminal()
# Type alias for the optional ``argv`` arguments accepted throughout this module
Argv = Optional['StrSeq']
class Context(AbstractContextManager):  # Extending AbstractContextManager to make PyCharm's type checker happy
    """
    The parsing context.

    Holds user input while parsing, and holds the parsed values.  Handles config overrides / hierarchy for settings that
    affect parser behavior.

    Contexts may be nested: a context created with a ``parent`` starts from copies of the parent's parsed values and
    provided-counts, and inherits the parent's ``prog``, terminal width, and ``allow_argv_prog`` settings unless they
    are explicitly overridden.
    """

    config: CommandConfig
    prog: OptStr = None
    allow_argv_prog: Bool = True
    _command_obj: CommandObj = None
    _terminal_width: Optional[int]
    _parsed: dict[ParamOrGroup, Any]
    _provided: dict[ParamOrGroup, int]

    def __init__(
        self,
        argv: Argv = None,
        command_cls: Optional[CommandType] = None,
        *,
        parent: Optional[Context] = None,
        config: AnyConfig = None,
        terminal_width: int = None,
        allow_argv_prog: Bool = None,
        command: Optional[CommandObj] = None,
        **kwargs,
    ):
        """
        :param argv: The arguments to parse.  If None, they are taken from :data:`sys.argv`.
        :param command_cls: The Command class associated with this Context, if any
        :param parent: The parent Context, if this is a sub-context
        :param config: A :class:`.CommandConfig` or a mapping of config options.  May not be combined with ``kwargs``.
        :param terminal_width: Explicit terminal width override.  Inherited from ``parent`` when None.
        :param allow_argv_prog: Override for the class-level default (True).  Inherited from ``parent`` when None.
        :param command: The Command object associated with this Context, if any
        :param kwargs: Keyword config options (alternative to ``config``)
        """
        self.command_cls = command_cls
        self.command = command
        self.parent = parent
        self.actions_taken = 0
        self.config = _normalize_config(config, kwargs, parent, command_cls)
        if parent:
            # Sub-contexts re-use the parent's prog and start from copies of its parsed / provided state
            self._set_argv(parent.prog, argv)
            self._parsed = parent._parsed.copy()
            self._provided = parent._provided.copy()
            self._terminal_width = parent._terminal_width if terminal_width is None else terminal_width
            self.allow_argv_prog = parent.allow_argv_prog if allow_argv_prog is None else allow_argv_prog
        else:
            self._set_argv(None, argv)
            self._parsed = {}
            self._provided = defaultdict(int)
            self._terminal_width = terminal_width
            if allow_argv_prog is not None:
                self.allow_argv_prog = allow_argv_prog

    # region Internal Methods

    @classmethod
    def for_prog(cls, prog: PathLike, *args, **kwargs) -> Context:
        """
        Alternate constructor that sets :attr:`.prog` explicitly after initialization.

        :param prog: The program name, or an object with a ``name`` attribute (such as a Path) whose name will be used
        :param args: Positional arguments to pass to the constructor
        :param kwargs: Keyword arguments to pass to the constructor
        :return: The new Context
        """
        self = cls(*args, **kwargs)
        self.prog = getattr(prog, 'name', prog)
        return self

    def _set_argv(self, prog: OptStr, argv: Argv):
        """
        Initialize :attr:`.prog`, :attr:`.argv`, and :attr:`.remaining`.

        If ``prog`` is provided, it is used, and ``argv`` defaults to ``sys.argv[1:]``.  Otherwise, if ``argv`` is None,
        both ``prog`` and ``argv`` are taken from :data:`sys.argv`.  Otherwise, only ``argv`` is set.
        """
        if prog:
            self.prog = prog
            self.argv = sys.argv[1:] if argv is None else argv
        elif argv is None:
            self.prog, *self.argv = sys.argv
        else:
            self.argv = argv
        self.remaining = list(self.argv)

    def _sub_context(
        self, command_cls: CommandType, argv: Argv = None, command: CommandObj = None, **kwargs
    ) -> Context:
        """
        Create a child Context whose parent is this Context.

        :param command_cls: The Command class for the child Context
        :param argv: Arguments for the child Context.  Defaults to this Context's :attr:`.remaining` arguments.
        :param command: The Command object for the child Context.  Defaults to this Context's ``command``.
        :param kwargs: Additional keyword arguments to pass to the constructor
        :return: The new child Context
        """
        return self.__class__(
            self.remaining if argv is None else argv,
            command_cls,
            parent=self,
            command=self.command if command is None else command,
            **kwargs,
        )

    def __repr__(self) -> str:
        command = getattr(self.command_cls, '__name__', None)
        prog, argv, allow_argv_prog = self.prog, self.argv, self.allow_argv_prog
        return f'<{self.__class__.__name__}[{command=!s}, {prog=}, {allow_argv_prog=}, {argv=}]>'

    def __enter__(self) -> Context:
        # Push this Context onto the active-context stack, creating the stack on first use
        try:
            _context_stack.get().append(self)
        except LookupError:
            _context_stack.set([self])
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _context_stack.get().pop()

    def __contains__(self, param: Union[ParamOrGroup, str, Any]) -> bool:
        """
        :param param: A Parameter / group, or the name of a Parameter
        :return: True if a value was parsed for the given Parameter, or (when given a string) for a Parameter with
          that name, False otherwise
        """
        if param in self._parsed:
            return True
        if isinstance(param, str):
            return any(p.name == param for p in self._parsed)
        return False

    # endregion

    @property
    def terminal_width(self) -> int:
        """Returns the current terminal width as the number of characters that fit on a single line."""
        if self._terminal_width is not None:
            return self._terminal_width
        return _TERMINAL.width

    def get_parsed(
        self,
        command: Command = None,
        *,
        exclude: Collection[Parameter] = (),
        recursive: Bool = True,
        default: Any = None,
        include_defaults: Bool = True,
    ) -> dict[str, Any]:
        """
        Returns all of the parsed arguments as a dictionary.

        The :ref:`get_parsed() <advanced:Parsed Args as a Dictionary>` helper function provides an easier way to access
        this functionality.

        :param command: An initialized Command object for which arguments were already parsed.
        :param exclude: Parameter objects that should be excluded from the returned results
        :param recursive: Whether parsed arguments should be recursively gathered from parent Commands
        :param default: The default value to use for parameters that raise :class:`.MissingArgument` when attempting to
          obtain their result values.
        :param include_defaults: Whether default values should be included in the returned results.  If False, only
          user-provided values will be included.
        :return: A dictionary containing all of the arguments that were parsed.  The keys in the returned dict match
          the names assigned to the Parameters in the Command associated with this Context.
        """
        if command is None:
            command = self.command
        with self:
            if recursive and self.parent:
                parsed = self.parent.get_parsed(
                    command, exclude=exclude, recursive=recursive, default=default, include_defaults=include_defaults
                )
            else:
                parsed = {}
            # TODO: Add way to get a nested dict with ParamGroup names as the keys of the nested sections?
            if self.params:
                for param in self.params.iter_params(exclude):
                    if include_defaults or param in self._parsed:
                        parsed[param.name] = param.result(command, default)
        return parsed

    @cached_property
    def params(self) -> Optional[CommandParameters]:
        """
        The :class:`.CommandParameters` object that contains the categorized Parameters from the Command associated
        with this Context.  None if there is no associated Command class.
        """
        if self.command_cls is not None:
            return self.command_cls.__class__.params(self.command_cls)
        return None

    def get_error_handler(self) -> Union[ErrorHandler, NullErrorHandler]:
        """
        Returns the :class:`.ErrorHandler` configured to be used.

        If no error handler was configured, the extended error handler is returned.  If it was explicitly configured as
        None, a :class:`.NullErrorHandler` is returned.
        """
        if (error_handler := self.config.error_handler) is _NotSet:
            return extended_error_handler
        elif error_handler is None:
            return NullErrorHandler()
        else:
            return error_handler

    # region Parsing Methods - Generally not intended to be called by users

    def has_parsed_value(self, param: Parameter) -> bool:
        """Not intended to be called by users.  Returns True if a value was stored for the given Parameter."""
        return param in self._parsed

    def get_parsed_value(self, param: Parameter, default=_NotSet):
        """Not intended to be called by users.  Used by Parameters to access their parsed values."""
        return self._parsed.get(param, default)

    def set_parsed_value(self, param: Parameter, value: Any):
        """Not intended to be called by users.  Used by Parameters during parsing to store parsed values."""
        self._parsed[param] = value

    def pop_parsed_value(self, param: Parameter):
        """Not intended to be called by users.  Used by Parameters during parsing if backtracking is necessary."""
        self._provided[param] = 0
        return self._parsed.pop(param)

    def roll_back_parsed_values(self, param: Parameter, count: int):
        """
        Not intended to be called by users.  Used during parsing as part of backtracking.

        Removes the last ``count`` stored values for the given Parameter, decrements its provided count by ``count``,
        and returns the removed values.
        """
        values = self._parsed[param]
        # Use an explicit split index: with ``count == 0``, ``values[:-count]`` would be ``values[:0]`` (discarding
        # every value) and ``values[-count:]`` would return all of them, instead of removing nothing.
        split = max(len(values) - count, 0)
        self._parsed[param] = values[:split]
        self._provided[param] -= count
        return values[split:]

    def record_action(self, param: ParamOrGroup, val_count: int = 1):
        """
        Not intended to be called by users.  Used by Parameters during parsing to indicate that they were provided.
        """
        self._provided[param] += val_count

    def num_provided(self, param: ParamOrGroup) -> int:
        """Not intended to be called by users.  Used by Parameters during parsing to handle nargs."""
        return self._provided[param]

    def get_missing(self) -> list[Parameter]:
        """Not intended to be called by users.  Used during parsing to determine if any Parameters are missing."""
        return [p for p in self.params.required_check_params() if not self._provided[p]]

    def missing_options_with_env_var(self) -> Iterator[Option]:
        """Yields Option parameters that have an environment variable configured, and did not have any CLI values."""
        yield from (p for p in self.params.options if p.env_var and not self._provided[p])

    # endregion

    # region Actions

    @cached_property
    def _parsed_action_flags(self) -> tuple[int, list[ActionFlag], list[ActionFlag]]:
        """
        Not intended to be accessed by users.  Returns a tuple containing the total number of action flags provided, the
        action flags to run before main, and the action flags to run after main.
        """
        try:
            before_main, after_main = self.params.split_action_flags  # Each part is already sorted
        except AttributeError:  # self.command_cls is None
            return 0, [], []

        parsed = self._parsed
        before_main = [p for p in before_main if p in parsed] if before_main else []
        after_main = [p for p in after_main if p in parsed] if after_main else []
        return len(before_main) + len(after_main), before_main, after_main

    @property
    def action_flag_count(self) -> int:
        """Not intended to be accessed by users.  Returns the count of parsed action flags."""
        return self._parsed_action_flags[0]

    @cached_property
    def all_action_flags(self) -> list[ActionFlag]:
        """Not intended to be accessed by users.  Returns all parsed action flags."""
        _, before_main, after_main = self._parsed_action_flags
        return before_main + after_main

    @cached_property
    def categorized_action_flags(self) -> dict[ActionPhase, Sequence[ActionFlag]]:
        """
        Not intended to be accessed by users.  Returns a dict of parsed action flags, categorized by the
        :class:`ActionPhase` during which they will run.  Before-main flags that are ``always_available`` are placed in
        :attr:`ActionPhase.PRE_INIT`.
        """
        _, before_main, after_main = self._parsed_action_flags
        init_actions, before_actions = [], []
        for flag in before_main:
            if flag.always_available:
                init_actions.append(flag)
            else:
                before_actions.append(flag)

        return {
            ActionPhase.PRE_INIT: init_actions,
            ActionPhase.BEFORE_MAIN: before_actions,
            ActionPhase.AFTER_MAIN: after_main,
        }

    def iter_action_flags(self, phase: ActionPhase) -> Iterator[ActionFlag]:
        """
        Not intended to be called by users.  Iterator that yields action flags to be executed during the specified
        phase while incrementing the counter of actions taken.

        :param phase: The current :class:`ActionPhase`
        """
        for action_flag in self.categorized_action_flags[phase]:
            self.actions_taken += 1
            yield action_flag

    # endregion
def _normalize_config(
config: AnyConfig, kwargs: dict[str, Any], parent: Context | None, command: CommandType | None
) -> CommandConfig:
if config is not None:
if kwargs:
raise TypeError(f'Cannot combine {config=} with keyword config arguments={kwargs}')
elif isinstance(config, CommandConfig):
return config
kwargs = config
if parent:
for key, val in parent.config._data.items():
kwargs.setdefault(key, val)
return CommandConfig(parent=command.__class__.config(command) if command is not None else None, **kwargs)
class ActionPhase(Enum):
    """
    The phases during which parsed action flags are run.  Used as the keys of
    :attr:`Context.categorized_action_flags`.
    """

    PRE_INIT = 0  # Before-main action flags that are always available
    BEFORE_MAIN = 1  # Remaining action flags that run before main
    AFTER_MAIN = 2  # Action flags that run after main
class ContextProxy:
    """
    Proxy for the currently active :class:`Context` object.  Allows usage similar to the ``request`` object in Flask.

    This class should not be instantiated by users - use the common :data:`ctx` instance.
    """

    __slots__ = ()

    # region Generic Proxy Methods

    def __getattr__(self, attr: str):
        return getattr(get_current_context(), attr)

    def __setattr__(self, attr: str, value):
        return setattr(get_current_context(), attr, value)

    # NOTE: Defining __eq__ without __hash__ makes ContextProxy instances unhashable
    def __eq__(self, other) -> bool:
        return get_current_context() == other

    def __contains__(self, item) -> bool:
        return item in get_current_context()

    def __enter__(self) -> Context:
        # The current context is already active, so there's no need to re-enter it - it can just be returned
        return get_current_context()

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    # endregion

    # region Proxied Parsing Methods

    def has_parsed_value(self, param: Parameter) -> bool:
        return get_current_context().has_parsed_value(param)

    def get_parsed_value(self, param: Parameter, default: Any = _NotSet):
        """
        Proxy for :meth:`Context.get_parsed_value`.

        :param param: The Parameter whose parsed value should be returned
        :param default: The value to return if no value was parsed for the given Parameter
        """
        return get_current_context().get_parsed_value(param, default)

    def set_parsed_value(self, param: Parameter, value: Any):
        get_current_context().set_parsed_value(param, value)

    def record_action(self, param: ParamOrGroup, val_count: int = 1):
        get_current_context().record_action(param, val_count)

    def num_provided(self, param: ParamOrGroup) -> int:
        return get_current_context().num_provided(param)

    # endregion

    # region Properties with Inactive Handlers

    @property
    def terminal_width(self) -> int:
        """The active Context's terminal width, or the detected terminal width if no Context is active."""
        if context := get_current_context(True):
            return context.terminal_width
        else:
            return _TERMINAL.width

    @property
    def config(self) -> CommandConfig:
        """The active Context's config, or :data:`DEFAULT_CONFIG` if no Context is active."""
        if context := get_current_context(True):
            return context.config
        else:
            return DEFAULT_CONFIG

    # endregion
# Module-level proxy that forwards to whichever Context is currently active.  The cast only tells type checkers to
# treat it as a Context; at runtime it is the ContextProxy instance itself.
ctx: Context = cast('Context', ContextProxy())
# region Public / Semi-Public Functions
def get_current_context(silent: bool = False) -> Optional[Context]:
    """
    Return the innermost currently active parsing :class:`Context`.

    :param silent: When True, return ``None`` instead of raising if there is no active :class:`Context`
    :return: The most recently entered :class:`Context` that has not yet been exited
    :raises: :class:`~.exceptions.NoActiveContext` if there is no active Context and ``silent=False`` (default)
    """
    stack = _context_stack.get(None)
    if stack:
        return stack[-1]
    if silent:
        return None
    raise NoActiveContext('There is no active context') from None
def get_or_create_context(
    command_cls: CommandType, argv: Argv = None, *, command: CommandObj = None, **kwargs
) -> Context:
    """
    Used internally by Commands to re-use an existing user-activated Context, or to create a new Context if there was
    no active Context.

    The active Context is returned as-is only when no ``argv``, ``command``, or extra keyword arguments were given and
    it is already associated with ``command_cls``.  Otherwise, a sub-context of the active Context is created.
    """
    active = get_current_context(True)
    if not active:
        return Context(argv, command_cls, command=command, **kwargs)

    reusable = argv is None and command is None and active.command_cls is command_cls and not kwargs
    if reusable:
        return active
    return active._sub_context(command_cls, argv=argv, command=command, **kwargs)
def get_context(command: Command) -> Context:
    """
    Look up the :class:`Context` associated with an initialized Command.

    :param command: An initialized Command object
    :return: The Context associated with the given Command
    :raises TypeError: if the given object has no associated Context (i.e., it is not a Command)
    """
    try:
        found = command._Command__ctx  # noqa  # Command's name-mangled private attribute
    except AttributeError as e:
        raise TypeError('get_context only supports Command objects') from e
    return found
def get_parsed(
    command: Command, to_call: Callable = None, default: Any = None, include_defaults: Bool = True
) -> dict[str, Any]:
    """
    Provides a way to obtain all of the arguments that were parsed for the given Command as a dictionary.

    If the parsed arguments will be used to call a particular function/method, or to initialize a particular class,
    then that callable can be provided as ``to_call``.  The parsed arguments are then filtered down to only the ones
    whose names appear in its signature.  The callable is never called by this function.

    Any :attr:`VAR_KEYWORD <python:inspect.Parameter.kind>` parameters (i.e., ``**kwargs``) in the callable's signature
    are not used for filtering.  That is, if the command has a Parameter named ``kwargs`` and the callable accepts
    ``**kwargs``, the ``kwargs`` key will not be included in the returned dict.  Parameters of the callable that cannot
    be passed as keyword arguments (i.e., :attr:`POSITIONAL_ONLY or VAR_POSITIONAL <python:inspect.Parameter.kind>`)
    ARE included in the returned dict, so they must be handled by the caller after calling this function.

    :param command: An initialized Command object for which arguments were already parsed.
    :param to_call: A :class:`callable <python:collections.abc.Callable>` (function, method, class, etc.) that should
      be used to filter the parsed arguments.  If provided, then only the keys that match the callable's signature will
      be included in the returned dictionary of parsed arguments.
    :param default: The default value to use for parameters that raise :class:`.MissingArgument` when attempting to
      obtain their result values.
    :param include_defaults: Whether default values should be included in the returned results.  If False, only
      user-provided values will be included.
    :return: A dictionary containing all of the (optionally filtered) arguments that were parsed.  The keys in the
      returned dict match the names assigned to the Parameters in the given Command.
    """
    context = get_context(command)
    parsed = context.get_parsed(command, default=default, include_defaults=include_defaults)
    if to_call is None:
        return parsed

    # Names of every parameter that ``to_call`` accepts, except a ``**kwargs``-style catch-all
    accepted = {
        name
        for name, sig_param in Signature.from_callable(to_call).parameters.items()
        if sig_param.kind != _Parameter.VAR_KEYWORD
    }
    return {key: val for key, val in parsed.items() if key in accepted}
def get_raw_arg(command: Command, parameter: Parameter) -> Any:
    """
    Retrieve the raw parsed argument value(s) provided for the given Parameter.

    :param command: An initialized Command object for which arguments were already parsed
    :param parameter: The Parameter whose raw value(s) should be returned
    :return: The raw stored value(s), or the internal ``_NotSet`` sentinel if no value was stored for the Parameter
    """
    context = get_context(command)
    return context.get_parsed_value(parameter)
# endregion