Source code for pytorch_forecasting.utils._dependencies

"""Utilities for managing dependencies.

Copied from sktime/skbase.
"""

from functools import lru_cache


[docs] @lru_cache def _get_installed_packages_private(): """Get a dictionary of installed packages and their versions. Same as _get_installed_packages, but internal to avoid mutating the lru_cache by accident. """ from importlib.metadata import distributions, version dists = distributions() package_names = {dist.metadata["Name"] for dist in dists} package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} # developer note: # we cannot just use distributions naively, # because the same top level package name may appear *twice*, # e.g., in a situation where a virtual env overrides a base env, # such as in deployment environments like databricks. # the "version" contract ensures we always get the version that corresponds # to the importable distribution, i.e., the top one in the sys.path. return package_versions
def _get_installed_packages():
    """Return the installed distributions and their versions.

    Every call returns a fresh shallow copy of the cached mapping built by
    ``_get_installed_packages_private``, so callers may modify the result
    without corrupting the cache.

    Returns
    -------
    dict
        Keys are distribution names as given in their metadata; values are the
        corresponding version strings reported by ``importlib.metadata``
        (not guaranteed to follow a MAJOR.MINOR.PATCH scheme).
    """
    cached_versions = _get_installed_packages_private()
    return dict(cached_versions)


def _check_matplotlib(ref="This feature", raise_error=True):
    """Report whether the ``matplotlib`` distribution is installed.

    Parameters
    ----------
    ref : str, optional (default="This feature")
        Name of the feature that needs matplotlib; placed at the start of the
        error message.
    raise_error : bool, optional (default=True)
        If True, raise ``ImportError`` when matplotlib is missing instead of
        returning False.

    Returns
    -------
    bool
        True if a distribution named ``matplotlib`` is installed, else False.

    Raises
    ------
    ImportError
        If matplotlib is not installed and ``raise_error`` is True.
    """
    has_matplotlib = "matplotlib" in _get_installed_packages()
    if has_matplotlib or not raise_error:
        return has_matplotlib
    message = (
        f"{ref} requires matplotlib."
        " Please install matplotlib with `pip install matplotlib`."
    )
    raise ImportError(message)