# Source code for marEx._dependencies

"""Dependency management for marEx."""

import importlib
import warnings
from typing import Dict, List


class DependencyTracker:
    """Track which optional dependencies can be imported.

    Availability is probed once, when the tracker is constructed, by actually
    importing each optional module. The results are cached in
    ``self._dependencies`` (dependency name -> ``bool``) and served by the
    query/reporting methods below.
    """

    # Tracked dependency name -> (importable module name, exception types that
    # mean "treat as unavailable"). Insertion order is the order in which the
    # dependencies are probed and reported.
    _OPTIONAL_MODULES = {
        # Performance dependencies
        "jax": ("jax", (ImportError,)),
        "jaxlib": ("jaxlib", (ImportError,)),
        # HPC dependencies
        # ValueError can occur when importing dask_jobqueue in worker threads
        # due to signal handling limitations
        "dask_jobqueue": ("dask_jobqueue", (ImportError, ValueError)),
        "psutil": ("psutil", (ImportError,)),
        # Visualisation dependencies
        "matplotlib": ("matplotlib", (ImportError,)),
        "cartopy": ("cartopy", (ImportError,)),
        "seaborn": ("seaborn", (ImportError,)),
        "cmocean": ("cmocean", (ImportError,)),
        # PIL for image processing (tracked as "pillow", imported as ``PIL``)
        "pillow": ("PIL", (ImportError,)),
    }

    def __init__(self):
        """Initialise the dependency tracker and probe all optional dependencies."""
        self._dependencies: Dict[str, bool] = {}
        self._check_all_dependencies()

    @staticmethod
    def _module_available(module_name: str, expected_errors=(ImportError,)) -> bool:
        """
        Return whether ``module_name`` can be imported.

        Parameters
        ----------
        module_name : str
            Name of the module to import
        expected_errors : tuple of exception types
            Exceptions that indicate the module is unavailable; any other
            exception raised by the import propagates to the caller

        Returns
        -------
        bool
            True if the import succeeded, False if it raised one of
            ``expected_errors``
        """
        try:
            importlib.import_module(module_name)
        except expected_errors:
            return False
        return True

    def _check_all_dependencies(self) -> None:
        """Check availability of all optional dependencies by importing them."""
        for dep_name, (module_name, expected_errors) in self._OPTIONAL_MODULES.items():
            self._dependencies[dep_name] = self._module_available(module_name, expected_errors)

    def has_dependency(self, dep_name: str) -> bool:
        """Check if a specific dependency is available (untracked names are False)."""
        return self._dependencies.get(dep_name, False)

    def require_dependencies(self, dependencies: List[str], feature: str = "This functionality") -> None:
        """
        Require specific dependencies for a feature.

        Parameters
        ----------
        dependencies : List[str]
            List of required dependency names
        feature : str
            Description of the feature requiring the dependencies

        Raises
        ------
        ImportError
            If any required dependencies are missing
        """
        missing = [dep for dep in dependencies if not self.has_dependency(dep)]
        if not missing:
            return

        if len(missing) == 1:
            # A single missing package gets its own targeted install hint.
            dep_name = missing[0]
            install_cmd = self._get_install_command(dep_name)
            raise ImportError(f"{feature} requires {dep_name}. " f"Install with: {install_cmd}")

        # Several missing packages: point at the all-inclusive extra instead.
        dep_list = ", ".join(missing)
        install_cmd = "pip install marEx[full]"
        raise ImportError(f"{feature} requires the following dependencies: {dep_list}. " f"Install with: {install_cmd}")

    def warn_missing_dependency(self, dep_name: str, feature: str = "Some functionality") -> None:
        """
        Issue a warning for a missing optional dependency.

        Does nothing when the dependency is available.

        Parameters
        ----------
        dep_name : str
            Name of the missing dependency
        feature : str
            Description of affected functionality
        """
        if self.has_dependency(dep_name):
            return

        install_cmd = self._get_install_command(dep_name)
        warnings.warn(
            f"{dep_name} not installed. {feature} will be slower or unavailable. "
            f"For best performance, install with: {install_cmd}",
            ImportWarning,
            stacklevel=3,
        )

    def _get_install_command(self, dep_name: str) -> str:
        """Get appropriate install command for a dependency.

        Dependencies without a dedicated entry fall back to ``pip install <dep_name>``.
        """
        specific_commands = {
            "jax": "pip install marEx[performance]",
            "jaxlib": "pip install marEx[performance]",
            "dask_jobqueue": "pip install marEx[hpc]",
            "pillow": "pip install pillow",
            "matplotlib": "pip install marEx[plotting]",
            "cartopy": "pip install marEx[plotting]",
            "seaborn": "pip install marEx[plotting]",
            "cmocean": "pip install marEx[plotting]",
        }

        return specific_commands.get(dep_name, f"pip install {dep_name}")

    def get_missing_dependencies(self) -> List[str]:
        """Get list of missing optional dependencies, in the order they were probed."""
        return [dep for dep, available in self._dependencies.items() if not available]

    def get_installation_profile(self) -> str:
        """Determine the current installation profile based on available dependencies.

        Returns
        -------
        str
            The first profile, in the order "full", "plotting", "hpc",
            "performance", "minimal", whose required dependencies are all
            available
        """
        profiles = {
            "minimal": [],
            "performance": ["jax", "jaxlib"],
            "hpc": ["dask_jobqueue", "psutil"],
            "plotting": ["seaborn", "cmocean"],
            "full": ["jax", "jaxlib", "dask_jobqueue", "seaborn", "cmocean"],
        }

        # Walk from the most to the least comprehensive profile and return the
        # first one that is fully satisfied.
        profile_hierarchy = ["full", "plotting", "hpc", "performance", "minimal"]
        for profile in profile_hierarchy:
            if all(self.has_dependency(dep) for dep in profiles[profile]):
                return profile

        # "minimal" has no requirements, so this is only a defensive fallback.
        return "minimal"

    def print_dependency_status(self) -> None:
        """Print status of all tracked dependencies."""
        current_profile = self.get_installation_profile()
        print("marEx Dependency Status:")
        print("-" * 40)
        print(f"Current Profile: {current_profile}")

        # Group dependencies by category
        categories = {
            "Performance": ["jax", "jaxlib"],
            "HPC": ["dask_jobqueue", "psutil"],
            "Visualisation": ["matplotlib", "cartopy", "seaborn", "cmocean"],
            "Core": ["pillow"],
        }

        for category, deps in categories.items():
            print(f"\n{category}:")
            for dep in deps:
                # Only report dependencies that were actually probed.
                if dep in self._dependencies:
                    status = "✓ Available" if self._dependencies[dep] else "✗ Missing"
                    print(f"  {dep:15} {status}")

        missing = self.get_missing_dependencies()
        if missing:
            print("\nInstallation suggestions:")
            print("  All features:     pip install marEx[full]")
            print("  Performance:      pip install marEx[performance]")
            print("  HPC:              pip install marEx[hpc]")
            print("  Visualisation:    pip install marEx[plotting]")


# Global dependency tracker instance, created once when this module is imported;
# the module-level convenience functions below delegate to it.
_dependency_tracker = DependencyTracker()


# Convenience functions
def has_dependency(dep_name: str) -> bool:
    """Check if a dependency is available.

    Parameters
    ----------
    dep_name : str
        Name of the dependency to look up in the global tracker

    Returns
    -------
    bool
        Whatever the global tracker's ``has_dependency`` reports for ``dep_name``
    """
    return _dependency_tracker.has_dependency(dep_name)


def require_dependencies(dependencies: List[str], feature: str = "This functionality") -> None:
    """Require specific dependencies for a feature.

    Parameters
    ----------
    dependencies : List[str]
        List of required dependency names
    feature : str
        Description of the feature requiring the dependencies
    """
    _dependency_tracker.require_dependencies(dependencies, feature)


def warn_missing_dependency(dep_name: str, feature: str = "Some functionality") -> None:
    """Issue a warning for a missing optional dependency.

    Parameters
    ----------
    dep_name : str
        Name of the dependency to check
    feature : str
        Description of affected functionality
    """
    _dependency_tracker.warn_missing_dependency(dep_name, feature)


def get_installation_profile() -> str:
    """Get the current installation profile.

    Returns
    -------
    str
        The profile name reported by the global tracker's
        ``get_installation_profile``
    """
    return _dependency_tracker.get_installation_profile()


# Check and warn about missing optional dependencies on import
def _check_optional_dependencies_on_import() -> None:
    """Check optional dependencies and issue warnings on package import.

    Only ``jax`` is checked here; a warning is issued through the global
    tracker when it is not available.
    """
    if not _dependency_tracker.has_dependency("jax"):
        _dependency_tracker.warn_missing_dependency("jax", "Some operations")


# Run the check when this module is imported
_check_optional_dependencies_on_import()