"""
The parsing Context used internally for tracking parsed arguments and configuration overrides.
:author: Doug Skrypa
"""
# pylint: disable=R0801
from __future__ import annotations
import sys
from collections import defaultdict
from contextlib import AbstractContextManager
from contextvars import ContextVar
from enum import Enum
from functools import cached_property
from inspect import Parameter as _Parameter, Signature
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Optional, Sequence, Union, cast
from .config import DEFAULT_CONFIG, CommandConfig
from .error_handling import ErrorHandler, NullErrorHandler, extended_error_handler
from .exceptions import NoActiveContext
from .utils import Terminal, _NotSet
if TYPE_CHECKING:
    from .command_parameters import CommandParameters
    from .commands import Command
    from .parameters import ActionFlag, BaseOption, Option, Parameter
    from .typing import AnyConfig, Bool, CommandObj, CommandType, OptStr, ParamOrGroup, PathLike, StrSeq  # noqa
# Names exported by a star-import of this module
__all__ = ['Context', 'ctx', 'get_current_context', 'get_or_create_context', 'get_context', 'get_parsed', 'get_raw_arg']
# Holds the list (used as a stack) of entered Context objects.  No default is configured, so ``.get()`` raises
# LookupError until a Context has been entered.
_context_stack = ContextVar('cli_command_parser.context.stack')
# Shared Terminal helper; its ``width`` is the fallback when no explicit terminal width was configured
_TERMINAL = Terminal()
# Type alias for an optional sequence of raw argument strings
Argv = Optional['StrSeq']
[docs]
class Context(AbstractContextManager):  # Extending AbstractContextManager to make PyCharm's type checker happy
    """
    The parsing context.
    Holds user input while parsing, and holds the parsed values.  Handles config overrides / hierarchy for settings that
    affect parser behavior.
    """
    # Effective configuration for this Context; assigned in __init__ from the result of _normalize_config()
    config: CommandConfig
    # Program name; populated from the parent Context, from sys.argv[0], or by for_prog()
    prog: OptStr = None
    # NOTE(review): only stored / inherited in the code visible here - presumably controls whether the program
    # name taken from argv may be used; confirm against the consumers of this attribute
    allow_argv_prog: Bool = True
    # NOTE(review): not referenced by the code visible here - verify whether it is still used
    _command_obj: CommandObj = None
    # Explicit terminal width override; None means the detected terminal width should be used
    _terminal_width: Optional[int]
    # Number of values recorded as provided for each Parameter / group
    _provided: dict[ParamOrGroup, int]
    def __init__(
        self,
        argv: Argv = None,
        command_cls: Optional[CommandType] = None,
        *,
        parent: Optional[Context] = None,
        config: AnyConfig = None,
        terminal_width: int = None,
        allow_argv_prog: Bool = None,
        command: Optional[CommandObj] = None,
        **kwargs,
    ):
        """
        :param argv: The raw arguments to parse.  When omitted, they are taken from :data:`sys.argv`.
        :param command_cls: The Command class associated with this Context, if any.
        :param parent: The parent Context, if any.  Parsed values, provided counts, the program name, and any
          settings that were not explicitly specified are copied / inherited from it.
        :param config: A :class:`.CommandConfig` or config options to use.  May not be combined with ``kwargs``.
        :param terminal_width: Explicit terminal width to use instead of the detected / inherited width.
        :param allow_argv_prog: Stored on this Context; inherited from the parent when not specified.
        :param command: The Command object associated with this Context, if any.
        :param kwargs: Individual config options (an alternative to providing ``config``).
        """
        self.command_cls, self.command, self.parent = command_cls, command, parent
        self.actions_taken = 0
        self.config = _normalize_config(config, kwargs, parent, command_cls)
        self._set_argv(parent.prog if parent else None, argv)
        if not parent:
            # Top-level context: start with empty state, and only override the class-level default for
            # allow_argv_prog when a value was explicitly provided
            self._parsed, self._provided = {}, defaultdict(int)
            self._terminal_width = terminal_width
            if allow_argv_prog is not None:
                self.allow_argv_prog = allow_argv_prog
            return

        # Sub-context: begin from copies of the parent's state so that changes here do not affect the parent
        self._parsed, self._provided = parent._parsed.copy(), parent._provided.copy()
        if terminal_width is None:
            terminal_width = parent._terminal_width
        if allow_argv_prog is None:
            allow_argv_prog = parent.allow_argv_prog
        self._terminal_width = terminal_width
        self.allow_argv_prog = allow_argv_prog
    # region Internal Methods
[docs]
    @classmethod
    def for_prog(cls, prog: PathLike, *args, **kwargs) -> Context:
        """
        Alternate constructor that creates a Context from the given arguments, then overrides its program name.

        :param prog: The program name, or an object with a ``name`` attribute (such as a path) to take it from
        :param args: Positional arguments to pass to the normal constructor
        :param kwargs: Keyword arguments to pass to the normal constructor
        :return: The new Context
        """
        context = cls(*args, **kwargs)
        # Objects that expose ``.name`` contribute that value; anything else is stored as-is
        context.prog = getattr(prog, 'name', prog)
        return context
    def _set_argv(self, prog: OptStr, argv: Argv):
        if prog:
            self.prog = prog
            self.argv = sys.argv[1:] if argv is None else argv
        elif argv is None:
            self.prog, *self.argv = sys.argv
        else:
            self.argv = argv
        self.remaining = list(self.argv)
    def _sub_context(
        self, command_cls: CommandType, argv: Argv = None, command: CommandObj = None, **kwargs
    ) -> Context:
        """
        Create a new Context of the same type that has this Context as its parent.

        :param command_cls: The Command class for the new Context
        :param argv: The arguments for the new Context.  Defaults to this Context's remaining arguments.
        :param command: The Command object for the new Context.  Defaults to this Context's command.
        :param kwargs: Additional keyword arguments to pass to the constructor
        :return: The new child Context
        """
        if argv is None:
            argv = self.remaining
        if command is None:
            command = self.command
        return self.__class__(argv, command_cls, parent=self, command=command, **kwargs)
    def __repr__(self) -> str:
        command = getattr(self.command_cls, '__name__', None)
        prog, argv, allow_argv_prog = self.prog, self.argv, self.allow_argv_prog
        return f'<{self.__class__.__name__}[{command=!s}, {prog=}, {allow_argv_prog=}, {argv=}]>'
    def __enter__(self) -> Context:
        """Make this Context the active one by pushing it onto the context stack, creating the stack if needed."""
        try:
            stack = _context_stack.get()
        except LookupError:
            # No stack has been stored in the context variable yet
            _context_stack.set([self])
        else:
            stack.append(self)
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the top-most Context from the context stack.  Nothing is returned, so exceptions propagate."""
        active_contexts = _context_stack.get()
        active_contexts.pop()
    def __contains__(self, param: Union[ParamOrGroup, str, Any]) -> bool:
        try:
            self._parsed[param]
        except KeyError:
            if isinstance(param, str):
                try:
                    next((v for p, v in self._parsed.items() if p.name == param))
                except StopIteration:
                    return False
                else:
                    return True
            return False
        else:
            return True
    # endregion
[docs]
    @property
    def terminal_width(self) -> int:
        """
        The terminal width as the number of characters that fit on a single line.  The explicitly configured
        width is used if there is one; otherwise the width reported by the shared Terminal helper is returned.
        """
        width = self._terminal_width
        return _TERMINAL.width if width is None else width
[docs]
    def get_parsed(
        self,
        command: Command = None,
        *,
        exclude: Collection[Parameter] = (),
        recursive: Bool = True,
        default: Any = None,
        include_defaults: Bool = True,
    ) -> dict[str, Any]:
        """
        Collect the parsed arguments into a dictionary keyed by Parameter name.

        The :ref:`get_parsed() <advanced:Parsed Args as a Dictionary>` helper function provides an easier way to
        access this functionality.

        :param command: An initialized Command object for which arguments were already parsed.  Defaults to the
          Command stored in this Context.
        :param exclude: Parameter objects that should be excluded from the returned results
        :param recursive: Whether parsed arguments should be recursively gathered from parent Contexts first
        :param default: The default value to use for parameters that raise :class:`.MissingArgument` when
          attempting to obtain their result values.
        :param include_defaults: Whether default values should be included in the returned results.  If False,
          only parameters that have a stored parsed value will be included.
        :return: A dictionary containing all of the arguments that were parsed.  The keys in the returned dict
          match the names assigned to the Parameters in the Command associated with this Context.
        """
        if command is None:
            command = self.command
        # This Context is activated while values are collected
        with self:
            parsed = {}
            if recursive and self.parent:
                # Values from parents are gathered first, so entries from this Context take precedence
                parsed = self.parent.get_parsed(
                    command, exclude=exclude, recursive=recursive, default=default, include_defaults=include_defaults
                )
            # TODO: Add way to get a nested dict with ParamGroup names as the keys of the nested sections?
            params = self.params
            if params:
                parsed.update(
                    (param.name, param.result(command, default))
                    for param in params.iter_params(exclude)
                    if include_defaults or param in self._parsed
                )
        return parsed
[docs]
    @cached_property
    def params(self) -> Optional[CommandParameters]:
        """
        The :class:`.CommandParameters` object that contains the categorized Parameters from the Command
        associated with this Context, or None if this Context has no Command class.
        """
        command_cls = self.command_cls
        if command_cls is None:
            return None
        # ``params`` is looked up on the class of the command class, and is called with the command class
        return command_cls.__class__.params(command_cls)
[docs]
    def get_error_handler(self) -> Union[ErrorHandler, NullErrorHandler]:
        """
        Returns the :class:`.ErrorHandler` configured to be used.  A configured value of None results in a new
        :class:`.NullErrorHandler`, and no configured value results in the extended error handler.
        """
        handler = self.config.error_handler
        if handler is None:
            return NullErrorHandler()
        if handler is _NotSet:
            return extended_error_handler
        return handler
    # region Parsing Methods - Generally not intended to be called by users
[docs]
    def has_parsed_value(self, param: Parameter) -> bool:
        """Not intended to be called by users.  Returns True if a parsed value is stored for the given Parameter."""
        parsed = self._parsed
        return param in parsed
[docs]
    def get_parsed_value(self, param: Parameter, default=_NotSet):
        """
        Not intended to be called by users.  Used by Parameters to access their parsed values.

        :param param: The Parameter for which the parsed value should be retrieved
        :param default: The value to return if no value is stored for the given Parameter
        :return: The stored value, or the given default
        """
        try:
            return self._parsed[param]
        except KeyError:
            return default
[docs]
    def set_parsed_value(self, param: Parameter, value: Any):
        """
        Not intended to be called by users.  Used by Parameters during parsing to store parsed values.

        :param param: The Parameter for which a value is being stored
        :param value: The value to store, replacing any value previously stored for that Parameter
        """
        parsed = self._parsed
        parsed[param] = value
[docs]
    def pop_parsed_value(self, param: Parameter):
        """
        Not intended to be called by users.  Used by Parameters during parsing if backtracking is necessary.

        :param param: The Parameter whose stored value should be removed
        :return: The value that was stored for the given Parameter
        :raises: :class:`KeyError` if no value was stored for the given Parameter
        """
        provided, parsed = self._provided, self._parsed
        # The provided count is reset before the value is removed, so it is reset even if nothing was stored
        provided[param] = 0
        return parsed.pop(param)
[docs]
    def roll_back_parsed_values(self, param: Parameter, count: int):
        """
        Not intended to be called by users.  Used during parsing as part of backtracking.

        Removes the last ``count`` values that were stored for the given Parameter, and reduces the number of
        values recorded as provided for it by the same amount.

        :param param: The Parameter whose most recently stored values should be rolled back
        :param count: The number of values to remove from the end of the stored values
        :return: The values that were removed, in their original order
        :raises: :class:`KeyError` if no values are stored for the given Parameter
        """
        values = self._parsed[param]
        # An explicit split index is used instead of ``values[:-count]`` / ``values[-count:]`` because -0 == 0:
        # with negative slicing, a count of 0 would discard every stored value and return all of them.
        # Clamping to 0 keeps the behavior for counts larger than the number of stored values (all are removed).
        split = max(len(values) - count, 0)
        self._parsed[param] = values[:split]
        self._provided[param] -= count
        return values[split:]
[docs]
    def record_action(self, param: ParamOrGroup, val_count: int = 1):
        """
        Not intended to be called by users.  Used by Parameters during parsing to indicate that they were provided.

        :param param: The Parameter or group that was provided
        :param val_count: The number of values to add to the count recorded for the given Parameter or group
        """
        provided = self._provided
        provided[param] = provided[param] + val_count
[docs]
    def num_provided(self, param: ParamOrGroup) -> int:
        """
        Not intended to be called by users.  Used by Parameters during parsing to handle nargs.

        :param param: A Parameter or group
        :return: The number of values recorded as provided for the given Parameter or group
        """
        provided = self._provided
        return provided[param]
[docs]
    def get_missing(self) -> list[Parameter]:
        """
        Not intended to be called by users.  Used during parsing to determine if any Parameters are missing.

        :return: The Parameters subject to the required check for which no values were recorded as provided
        """
        candidates = self.params.required_check_params()
        provided = self._provided
        missing = []
        for param in candidates:
            if not provided[param]:
                missing.append(param)
        return missing
[docs]
    def missing_options_with_env_var(self) -> Iterator[BaseOption]:
        """Yields Option parameters that have an environment variable configured, and did not have any CLI values."""
        provided = self._provided
        for option in self.params.options:
            # The env_var check comes first so that the provided count is only looked up for relevant options
            if not option.env_var:
                continue
            if provided[option]:
                continue
            yield option
    # endregion
    # region Actions
    @cached_property
    def _parsed_action_flags(self) -> tuple[int, list[ActionFlag], list[ActionFlag]]:
        """
        Not intended to be accessed by users.  Returns a tuple containing the total number of action flags provided, the
        action flags to run before main, and the action flags to run after main.
        """
        try:
            before_main, after_main = self.params.split_action_flags  # Each part is already sorted
        except AttributeError:  # self.command_cls is None
            return 0, [], []
        parsed = self._parsed
        before_main = [p for p in before_main if p in parsed] if before_main else []
        after_main = [p for p in after_main if p in parsed] if after_main else []
        return len(before_main) + len(after_main), before_main, after_main
[docs]
    @property
    def action_flag_count(self) -> int:
        """Not intended to be accessed by users.  Returns the count of parsed action flags."""
        total, _before_main, _after_main = self._parsed_action_flags
        return total
[docs]
    @cached_property
    def all_action_flags(self) -> list[ActionFlag]:
        """
        Not intended to be accessed by users.  Returns all parsed action flags, with the flags that run before
        main followed by the flags that run after main.
        """
        _total, before_main, after_main = self._parsed_action_flags
        return [*before_main, *after_main]
[docs]
    @cached_property
    def categorized_action_flags(self) -> dict[ActionPhase, Sequence[ActionFlag]]:
        """
        Not intended to be accessed by users.

        :return: A dict of parsed action flags, categorized by the :class:`ActionPhase` during which they will
          run.  Flags from the before-main group that are marked ``always_available`` are placed in the
          ``PRE_INIT`` phase.
        """
        _total, before_main, after_main = self._parsed_action_flags
        categorized = {ActionPhase.PRE_INIT: [], ActionPhase.BEFORE_MAIN: [], ActionPhase.AFTER_MAIN: after_main}
        for flag in before_main:
            phase = ActionPhase.PRE_INIT if flag.always_available else ActionPhase.BEFORE_MAIN
            categorized[phase].append(flag)
        return categorized
[docs]
    def iter_action_flags(self, phase: ActionPhase) -> Iterator[ActionFlag]:
        """
        Not intended to be called by users.  Iterator that yields action flags to be executed during the specified
        phase while incrementing the counter of actions taken.

        :param phase: The current :class:`ActionPhase`
        """
        flags = self.categorized_action_flags[phase]
        for flag in flags:
            # The counter is incremented before the flag is handed to the caller
            self.actions_taken += 1
            yield flag
 
    # endregion
def _normalize_config(
    config: AnyConfig, kwargs: dict[str, Any], parent: Context | None, command: CommandType | None
) -> CommandConfig:
    """
    Build the :class:`.CommandConfig` that a new :class:`Context` should use.

    :param config: A CommandConfig (returned as-is), a mapping of config options, or None
    :param kwargs: Individual config options.  May not be combined with ``config``.
    :param parent: The parent Context, if any.  Its config values are used for any options not explicitly provided.
    :param command: The Command class, if any.  Its config is used as the parent of the returned config.
    :return: The CommandConfig to use
    :raises: :class:`TypeError` if both ``config`` and keyword config options were provided
    """
    if config is not None:
        if kwargs:
            raise TypeError(f'Cannot combine {config=} with keyword config arguments={kwargs}')
        elif isinstance(config, CommandConfig):
            return config
        # Work on a copy so that filling in the parent's values below never mutates the caller's mapping
        kwargs = dict(config)
    if parent:
        for key, val in parent.config._data.items():
            kwargs.setdefault(key, val)
    return CommandConfig(parent=command.__class__.config(command) if command is not None else None, **kwargs)
[docs]
class ActionPhase(Enum):
    """
    The phases used to categorize parsed action flags.  The members are the keys of the dict returned by
    :attr:`Context.categorized_action_flags`, and the values accepted by :meth:`Context.iter_action_flags`.
    """
    PRE_INIT = 0  # Flags from the before-main group that are marked ``always_available``
    BEFORE_MAIN = 1  # The remaining flags from the before-main group
    AFTER_MAIN = 2  # Flags from the after-main group
    # def __next__(self) -> ActionPhase:
    #     try:
    #         return self._value2member_map_[self._value_ + 1]  # noqa
    #     except KeyError:
    #         raise StopIteration
[docs]
class ContextProxy:
    """
    Proxy for the currently active :class:`Context` object.  Allows usage similar to the ``request`` object in Flask.
    This class should not be instantiated by users - use the common :data:`ctx` instance.
    """
    __slots__ = ()
    # region Generic Proxy Methods
    def __getattr__(self, attr: str):
        """Resolve attributes that are not defined on the proxy itself against the currently active Context."""
        active = get_current_context()
        return getattr(active, attr)
    def __setattr__(self, attr: str, value):
        """Set the given attribute on the currently active Context instead of on the proxy."""
        active = get_current_context()
        return setattr(active, attr, value)
    def __eq__(self, other) -> bool:
        """Compare the currently active Context (not the proxy) with the given object."""
        active = get_current_context()
        return active == other
    def __contains__(self, item) -> bool:
        """Check whether the given item is contained in the currently active Context."""
        active = get_current_context()
        return item in active
    def __enter__(self) -> Context:
        """Return the currently active Context without entering it again."""
        # Re-entering is unnecessary because the current context is already active
        active = get_current_context()
        return active
    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
    # endregion
    # region Proxied Parsing Methods
[docs]
    def has_parsed_value(self, param: Parameter) -> bool:
        """Proxy for :meth:`Context.has_parsed_value` on the currently active Context."""
        active = get_current_context()
        return active.has_parsed_value(param)
[docs]
    def get_parsed_value(self, param: Parameter, default=_NotSet):
        """
        Proxy for :meth:`Context.get_parsed_value` on the currently active Context.

        :param param: The Parameter for which the parsed value should be retrieved
        :param default: The value to return if no value is stored for the given Parameter.  Accepted here so
          that the proxy can be called the same way as :meth:`Context.get_parsed_value`.
        :return: The stored value, or the given default
        """
        return get_current_context().get_parsed_value(param, default)
[docs]
    def set_parsed_value(self, param: Parameter, value: Any):
        """Proxy for :meth:`Context.set_parsed_value` on the currently active Context."""
        active = get_current_context()
        active.set_parsed_value(param, value)
[docs]
    def record_action(self, param: ParamOrGroup, val_count: int = 1):
        """Proxy for :meth:`Context.record_action` on the currently active Context."""
        active = get_current_context()
        active.record_action(param, val_count)
[docs]
    def num_provided(self, param: ParamOrGroup) -> int:
        """Proxy for :meth:`Context.num_provided` on the currently active Context."""
        active = get_current_context()
        return active.num_provided(param)
    # endregion
    # region Properties with Inactive Handlers
[docs]
    @property
    def terminal_width(self) -> int:
        """
        The terminal width from the currently active Context, or the width reported by the shared Terminal
        helper if there is no active Context.
        """
        context = get_current_context(True)
        if not context:
            return _TERMINAL.width
        return context.terminal_width
[docs]
    @property
    def config(self) -> CommandConfig:
        """The config from the currently active Context, or the default config if there is no active Context."""
        context = get_current_context(True)
        if not context:
            return DEFAULT_CONFIG
        return context.config
 
    # endregion
# Shared proxy for the currently active Context.  The cast only affects static type checkers - at runtime, this
# object is a ContextProxy that looks up the active Context each time it is used.
ctx: Context = cast(Context, ContextProxy())
# region Public / Semi-Public Functions
[docs]
def get_current_context(silent: bool = False) -> Optional[Context]:
    """
    Look up the parsing context that is currently active (the most recently entered one).

    :param silent: If True, return ``None`` instead of raising an exception when no :class:`Context` is active
    :return: The active :class:`Context` object
    :raises: :class:`~.exceptions.NoActiveContext` if there is no active Context and ``silent=False`` (default)
    """
    try:
        stack = _context_stack.get()
        active = stack[-1]
    except (LookupError, IndexError):
        # Either no stack was ever stored, or every Context that was entered has already exited
        if not silent:
            raise NoActiveContext('There is no active context') from None
        return None
    else:
        return active
[docs]
def get_or_create_context(
    command_cls: CommandType, argv: Argv = None, *, command: CommandObj = None, **kwargs
) -> Context:
    """
    Used internally by Commands to re-use an existing user-activated Context, or to create a new Context if there
    was no active Context.

    The active Context is returned as-is only when it is already associated with the given Command class and no
    arguments, Command object, or other options were provided.  Otherwise, a sub-context of it is created.
    """
    context = get_current_context(True)
    if not context:
        return Context(argv, command_cls, command=command, **kwargs)

    nothing_overridden = argv is None and command is None and not kwargs
    if nothing_overridden and context.command_cls is command_cls:
        return context
    return context._sub_context(command_cls, argv=argv, command=command, **kwargs)
[docs]
def get_context(command: Command) -> Context:
    """
    Retrieve the Context stored on a Command object.

    :param command: An initialized Command object
    :return: The Context associated with the given Command
    :raises: :class:`TypeError` if the given object does not have the expected context attribute
    """
    try:
        context = command._Command__ctx  # noqa
    except AttributeError as e:
        raise TypeError('get_context only supports Command objects') from e
    return context
[docs]
def get_parsed(
    command: Command, to_call: Callable = None, default: Any = None, include_defaults: Bool = True
) -> dict[str, Any]:
    """
    Provides a way to obtain all of the arguments that were parsed for the given Command as a dictionary.

    If the parsed arguments are intended to be used to call a particular function/method, or to initialize a
    particular class, then that callable can be provided as the ``to_call`` parameter to filter the parsed arguments
    to only the ones that would be accepted by it.  It will not be called by this function.

    If the callable accepts any :attr:`VAR_KEYWORD <python:inspect.Parameter.kind>` parameters (i.e., ``**kwargs``),
    then those param names will not be used for filtering.  That is, if the command has a Parameter named ``kwargs``
    and the callable accepts ``**kwargs``, the ``kwargs`` key will not be included in the argument dict returned by
    this function.  If any of the parameters of the given callable cannot be passed as a keyword argument (i.e.,
    :attr:`POSITIONAL_ONLY or VAR_POSITIONAL <python:inspect.Parameter.kind>`), then they must be handled after
    calling this function.  They will be included in the returned dict.

    :param command: An initialized Command object for which arguments were already parsed.
    :param to_call: A :class:`callable <python:collections.abc.Callable>` (function, method, class, etc.) that
      should be used to filter the parsed arguments.  If provided, then only the keys that match the callable's
      signature will be included in the returned dictionary of parsed arguments.
    :param default: The default value to use for parameters that raise :class:`.MissingArgument` when attempting to
      obtain their result values.
    :param include_defaults: Whether default values should be included in the returned results.  If False, only
      user-provided values will be included.
    :return: A dictionary containing all of the (optionally filtered) arguments that were parsed.  The keys in the
      returned dict match the names assigned to the Parameters in the given Command.
    """
    context = get_context(command)
    parsed = context.get_parsed(command, default=default, include_defaults=include_defaults)
    if to_call is None:
        return parsed

    signature = Signature.from_callable(to_call)
    # The name of a ``**kwargs`` style parameter is not a real keyword, so it is not an acceptable key
    accepted = {name for name, param in signature.parameters.items() if param.kind != _Parameter.VAR_KEYWORD}
    return {key: val for key, val in parsed.items() if key in accepted}
[docs]
def get_raw_arg(command: Command, parameter: Parameter) -> Any:
    """
    Retrieve the raw parsed argument value(s) provided for the given Parameter

    :param command: An initialized Command object for which arguments were already parsed
    :param parameter: The Parameter for which the stored value should be retrieved
    :return: The value stored for the given Parameter in the Context associated with the given Command
    """
    context = get_context(command)
    return context.get_parsed_value(parameter)
# endregion