#! /usr/bin/python3

"""Driver package query/installation tool for Ubuntu"""

# (C) 2012 Canonical Ltd.
# Author: Martin Pitt <martin.pitt@ubuntu.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

import click
import subprocess
import fnmatch
import sys
import os
import logging
import apt_pkg
from typing import Optional, Any, Dict

from functools import cmp_to_key
import UbuntuDrivers.detect
from UbuntuDrivers import kerneldetection

# Optional override of the sysfs root used for hardware detection.
sys_path = os.environ.get("UBUNTU_DRIVERS_SYS_DIR")

# Make sure that the PATH environment variable is set
# See LP: #1854472
if not os.environ.get("PATH"):
    os.environ["PATH"] = "/sbin:/usr/sbin:/bin:/usr/bin"


def parse_log_level(value: Any) -> Any:
    """Normalise a LOGLEVEL setting into a value logging.setLevel() accepts.

    Integers are returned unchanged. Strings are stripped; purely numeric
    strings become ints, anything else is upper-cased so "debug" works like
    "DEBUG". A name logging does not know falls back to logging.WARNING
    instead of raising ValueError while this module is being imported.
    """
    if isinstance(value, int):
        return value
    text = str(value).strip()
    if text.isdigit():
        return int(text)
    # getLevelName() maps a registered name to its int level and returns a
    # "Level X" string for unknown names.
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else logging.WARNING


logger = logging.getLogger()
LOGLEVEL = parse_log_level(os.environ.get("LOGLEVEL", logging.WARNING))
logger.setLevel(level=LOGLEVEL)

CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


class Config:
    """Options gathered from the command line and shared between commands.

    The click subcommand callbacks fill these in, and the command_* helpers
    then read them.
    """

    def __init__(self) -> None:
        # Feature switches, all off unless a flag turns them on.
        self.gpgpu: bool = False
        self.free_only: bool = False
        self.include_dkms: bool = False
        self.recommended: bool = False

        # OEM metapackages are included unless --no-oem is given.
        self.install_oem_meta: bool = True

        # Free-form values; empty string means "not supplied".
        self.package_list: str = ""
        self.driver_string: str = ""


# Decorator that passes the shared Config object into click callbacks.
# ensure=True asks click to create a Config() on the context if none exists
# yet (see click's make_pass_decorator documentation).
pass_config = click.make_pass_decorator(Config, ensure=True)


def command_list(args: Config) -> int:
    """Show all driver packages which apply to the current system.

    Prints one package per line. When a kernel-modules metapackage is known
    for a package, or the package name itself contains "dkms", the module
    provider is appended in parentheses.

    Returns:
        1 if the apt cache cannot be opened or the kernel check asks to
        stop, otherwise 0.
    """
    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    # Give the kernel check a chance to abort before listing anything.
    detector = kerneldetection.KernelDetection(cache)
    if detector.get_kernel_update_warning(args.include_dkms):
        return 1

    candidates = UbuntuDrivers.detect.system_driver_packages(
        apt_cache=cache,
        sys_path=sys_path,
        freeonly=args.free_only,
        include_oem=args.install_oem_meta,
    )

    for pkg in candidates:
        try:
            provider = UbuntuDrivers.detect.get_linux_modules_metapackage(cache, pkg)
        except KeyError:
            # Lookup failed: show the bare package name.
            print(pkg)
            continue
        if not provider and "dkms" in pkg:
            # A DKMS package builds its own modules.
            provider = pkg
        if provider:
            print("%s, (kernel modules provided by %s)" % (pkg, provider))
        else:
            print(pkg)

    return 0


def command_list_oem(args: Config) -> int:
    """Show all OEM enablement packages which apply to this system.

    The packages are printed one per line. When args.package_list is set,
    they are also appended to that file, followed by a trailing newline.

    Returns:
        0 on success or when OEM packages are disabled; 1 if the apt cache
        cannot be opened or the kernel check asks to stop.
    """
    # OEM metapackages were explicitly excluded: nothing to list.
    if not args.install_oem_meta:
        return 0

    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    # First check if kernel needs updating
    if kerneldetection.KernelDetection(cache).get_kernel_update_warning(
        args.include_dkms
    ):
        return 1

    oem_packages = UbuntuDrivers.detect.system_device_specific_metapackages(
        apt_cache=cache, sys_path=sys_path, include_oem=args.install_oem_meta
    )
    if not oem_packages:
        return 0

    listing = "\n".join(oem_packages)
    print(listing)

    if args.package_list:
        with open(args.package_list, "a") as out_file:
            out_file.write(listing)
            out_file.write("\n")

    return 0


def list_gpgpu(args: Config) -> int:
    """Show all GPGPU driver packages which apply to the current system.

    For every detected GPGPU entry that has a metapackage, prints the
    metapackage and the kernel-modules package reported for it.

    Returns:
        1 if the apt cache cannot be opened or the kernel check asks to
        stop, otherwise 0.
    """
    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    # First check if kernel needs updating
    detector = kerneldetection.KernelDetection(cache)
    if detector.get_kernel_update_warning(args.include_dkms):
        return 1

    gpgpu_packages = UbuntuDrivers.detect.system_gpgpu_driver_packages(cache, sys_path)
    for key in gpgpu_packages:
        metapackage = gpgpu_packages[key]["metapackage"]
        if not metapackage:
            continue
        modules = UbuntuDrivers.detect.get_linux_modules_metapackage(cache, metapackage)
        print("%s, (kernel modules provided by %s)" % (metapackage, modules))

    return 0


def command_devices(args: Config) -> Optional[int]:
    """Show all devices which need drivers, and which packages apply to them.

    For each device, prints every property except "drivers" as
    "key: value". It then prints one "driver" line per candidate package,
    tagged distro/third-party, free/non-free, and optionally builtin and
    recommended.

    Returns:
        1 if the apt cache cannot be opened, otherwise None.
    """
    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    device_map = UbuntuDrivers.detect.system_device_drivers(
        apt_cache=cache, sys_path=sys_path, freeonly=args.free_only
    )
    for device, details in device_map.items():
        print("== %s ==" % device)
        for key, value in details.items():
            if key != "drivers":
                print("%-9s: %s" % (key, value))

        for pkg, pkginfo in details["drivers"].items():
            tags = [
                "distro" if pkginfo["from_distro"] else "third-party",
                "free" if pkginfo["free"] else "non-free",
            ]
            if pkginfo.get("builtin"):
                tags.append("builtin")
            if pkginfo.get("recommended"):
                tags.append("recommended")
            print("%-9s: %s - %s" % ("driver", pkg, " ".join(tags)))
        print("")

    return None


def command_install(args: Config) -> Optional[int]:
    """Install drivers that are appropriate for your hardware.

    Installs the package set returned by
    UbuntuDrivers.detect.get_desktop_package_list() with apt-get. When
    args.package_list is set, the installed names are appended to that file.
    For every OEM metapackage (``oem-*-meta``) in the set, ``apt update`` is
    then run against /etc/apt/sources.list.d/<package>.list only, and those
    metapackages are installed once more with apt.

    Args:
        args: Parsed command-line options.

    Returns:
        None when there is nothing to install. Otherwise 1 if the apt cache
        cannot be opened or the kernel check asks to stop, else the exit
        status of the first failing apt command, or of the last one run.
    """
    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    # First check if kernel needs updating
    kernel_detector = kerneldetection.KernelDetection(cache)
    if kernel_detector.get_kernel_update_warning(args.include_dkms):
        return 1

    to_install = UbuntuDrivers.detect.get_desktop_package_list(
        cache,
        sys_path,
        free_only=args.free_only,
        include_oem=args.install_oem_meta,
        driver_string=args.driver_string,
        include_dkms=args.include_dkms,
    )

    if not to_install:
        print("All the available drivers are already installed.")
        return None

    # NVIDIA packages get detect's pre-/post-installation hooks around apt.
    is_nvidia = any("nvidia" in package for package in to_install)
    if is_nvidia:
        UbuntuDrivers.detect.nvidia_desktop_pre_installation_hook(to_install)

    ret = subprocess.call(
        ["apt-get", "install", "-o", "DPkg::options::=--force-confnew", "-y"]
        + to_install
    )
    if ret != 0:
        return ret

    oem_meta_to_install = fnmatch.filter(to_install, "oem-*-meta")

    # create package list (the with-block closes the file)
    if args.package_list:
        with open(args.package_list, "a") as f:
            f.write("\n".join(to_install))
            f.write("\n")

    for package_to_install in oem_meta_to_install:
        sources_list_path = os.path.join(
            os.path.sep, "etc", "apt", "sources.list.d", f"{package_to_install}.list"
        )

        # Update from this metapackage's sources list only (SourceParts is
        # pointed at /dev/null) without discarding other downloaded lists.
        update_ret = subprocess.call(
            [
                "apt",
                "-o",
                f"Dir::Etc::SourceList={sources_list_path}",
                "-o",
                "Dir::Etc::SourceParts=/dev/null",
                "--no-list-cleanup",
                "update",
            ]
        )

        if update_ret != 0:
            return update_ret

    if is_nvidia:
        UbuntuDrivers.detect.nvidia_desktop_post_installation_hook()

    # All updates completed successfully, now let's upgrade the packages
    if oem_meta_to_install:
        ret = subprocess.call(
            ["apt", "install", "-o", "DPkg::Options::=--force-confnew", "-y"]
            + oem_meta_to_install
        )

    return ret


def install_gpgpu(args: Config) -> int:
    """Install GPGPU drivers that are appropriate for your hardware.

    args.driver_string may name one driver (e.g. ``390`` or ``nvidia:390``)
    or several comma-separated ones (e.g. ``nvidia:390,amdgpu``). It is
    empty when only ``--gpgpu`` was given.

    Returns:
        1 if the apt cache cannot be opened or the kernel check asks to
        stop. When nothing is selected for installation, 1 if specific
        drivers were requested and 0 otherwise. Else apt-get's exit status.
    """
    # Not finding an explicitly requested driver is an error. Finding
    # nothing for a bare --gpgpu is not.
    not_found_exit_status = 1 if args.driver_string else 0

    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    # First check if kernel needs updating
    kernel_detector = kerneldetection.KernelDetection(cache)
    if kernel_detector.get_kernel_update_warning(args.include_dkms):
        return 1

    packages = UbuntuDrivers.detect.system_gpgpu_driver_packages(cache, sys_path)
    to_install = UbuntuDrivers.detect.gpgpu_install_filter(
        cache, args.include_dkms, packages, args.driver_string, get_recommended=False
    )
    if not to_install:
        print("No drivers found for installation.")
        return not_found_exit_status

    ret = subprocess.call(
        [
            "apt-get",
            "install",
            "-o",
            "DPkg::options::=--force-confnew",
            "--no-install-recommends",
            "-y",
        ]
        + to_install
    )

    # create package list
    if ret == 0 and args.package_list:
        with open(args.package_list, "a") as f:
            f.write("\n".join(to_install))
            f.write("\n")

    return ret


def command_debug(args: Config) -> int:
    """Print all available information and debug data about drivers.

    Enables DEBUG logging to stdout, then prints the system modaliases. It
    then prints one line per matching driver package, showing:

    - installed and candidate versions;
    - whether auto_install_filter() selected it;
    - the detection metadata (distro/third party, free/non-free, and any
      modalias, path, vendor, model and support level).

    Returns:
        1 if the apt cache cannot be opened, otherwise 0.
    """
    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

    # Detection runs right after this header so its log output lands under it.
    print("=== log messages from detection ===")
    modaliases = UbuntuDrivers.detect.system_modaliases()

    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        return 1

    depcache = apt_pkg.DepCache(cache)
    packages = UbuntuDrivers.detect.system_driver_packages(
        cache, sys_path, freeonly=args.free_only, include_oem=args.install_oem_meta
    )
    auto_packages = UbuntuDrivers.detect.auto_install_filter(
        cache, args.include_dkms, packages
    )

    print("=== modaliases in the system ===")
    for modalias in modaliases:
        print(modalias)

    # (info key, printed label) pairs, shown only when the key is present.
    optional_fields = (
        ("modalias", "modalias: "),
        ("syspath", "path: "),
        ("vendor", "vendor: "),
        ("model", "model: "),
    )

    print("=== matching driver packages ===")
    for name, details in packages.items():
        apt_package = cache[name]
        try:
            installed = apt_package.current_ver.ver_str
        except AttributeError:
            installed = "<none>"
        try:
            candidate_ver = depcache.get_candidate_ver(apt_package)
            available = candidate_ver.ver_str  # type: ignore[union-attr]
        except AttributeError:
            available = "<none>"
        auto = " (auto-install)" if name in auto_packages else ""

        fragments = [
            "[distro]" if details["from_distro"] else "[third party]",
            "free" if details["free"] else "non-free",
        ]
        for key, label in optional_fields:
            if key in details:
                fragments.append(label + details[key])
        support = details.get("support")
        if support:
            fragments.append("support level: " + support)
        info_str = "".join("  " + fragment for fragment in fragments)

        print(
            "%s: installed: %s   available: %s%s%s "
            % (name, installed, available, auto, info_str)
        )

    return 0


def format_welcome_page(data: Dict[str, Any]) -> str:
    """Render the text shown when ubuntu-drivers runs without a subcommand.

    Args:
        data: Mapping produced by gather_welcome_page_data(). Keys read here
            are "cache_error", "nvidia_drivers", "oem_packages",
            "nvidia_status_error" and "nvidia_status". The last one is a
            mapping with "loaded", "next_boot_module_path",
            "current_module_path", "module_missing", "needs_reboot" and
            "next_boot_kernel".

    Returns:
        str: The welcome page as one newline-joined string.
    """
    out = [
        "\n=== Welcome to ubuntu-drivers ===",
        "This tool helps you install and manage hardware drivers for your Ubuntu system.",
    ]

    cache_error = data["cache_error"]
    if cache_error:
        out.append(f"Warning: Could not access package cache: {cache_error}")
        return "\n".join(out)

    nvidia_drivers = data["nvidia_drivers"]
    oem_packages = data["oem_packages"]

    out.append("\n--- Installed OEM / NVIDIA Drivers ---")
    if not (nvidia_drivers or oem_packages):
        out.append("No OEM or NVIDIA drivers are currently installed.")
    else:
        if nvidia_drivers:
            out.append("NVIDIA Drivers:")
            out.extend(f"  • {name}" for name in nvidia_drivers)
        if oem_packages:
            out.append("OEM Enablement Packages:")
            out.extend(f"  • {name}" for name in oem_packages)

    out.append(
        "\nFor more information, use 'ubuntu-drivers --help' or 'ubuntu-drivers list'"
    )
    out.append("To install drivers, use 'sudo ubuntu-drivers install'")

    # The module-status section only applies when NVIDIA drivers are installed.
    if not nvidia_drivers:
        return "\n".join(out)

    status_error = data["nvidia_status_error"]
    if status_error:
        out.append("\n--- NVIDIA Module Status ---")
        out.append(f"⚠️  Could not check NVIDIA module status: {status_error}")
        return "\n".join(out)

    status = data["nvidia_status"]
    if not status:
        return "\n".join(out)

    out.append("\n--- NVIDIA Module Status ---")
    if status["loaded"]:
        out.append("✓ NVIDIA module is currently loaded")
    else:
        # e.g. right after a driver install, before the module has been loaded
        out.append("ℹ️  NVIDIA module is not currently loaded")
        if status["next_boot_module_path"]:
            # NOTE(review): the next-boot path is printed here and again below.
            out.append("   📍 The NVIDIA module should load on next boot")
            out.append(
                f"   📍 Next boot module path: {status['next_boot_module_path']}"
            )

    current_path = status["current_module_path"]
    if current_path:
        out.append(f"  Module path for current kernel: {current_path}")

    if status["module_missing"]:
        out.append("\n❌  ERROR: NVIDIA module is missing for the next boot kernel!")
        out.append(f"   Next boot kernel: {status['next_boot_kernel']}")
        out.append("   Please re-run ubuntu-drivers to install the missing module.")
        # No reboot advice while the next kernel lacks its NVIDIA module.
        return "\n".join(out)

    next_path = status["next_boot_module_path"]
    if next_path:
        out.append(f"   ✓ Next boot module path: {next_path}")
        if current_path and current_path != next_path:
            out.append(
                "   📍 Module paths differ - you should reboot to use the latest kernel & driver:"
            )
            out.append(f"      Current:  {current_path}")
            out.append(f"      Next boot: {next_path}")

    if status["needs_reboot"]:
        out.append(
            "\n⚠️  Kernel version difference detected - reboot to use the latest kernel:"
        )
        out.append(f"   Current kernel: {os.uname().release}")
        out.append(f"   Next boot kernel: {status['next_boot_kernel']}")
    elif not status["next_boot_kernel"]:
        out.append(
            "   ⚠️  Could not determine next boot kernel (try again with `sudo ubuntu-drivers`)"
        )
    else:
        out.append("   ✓ Current and next boot kernel versions match")

    return "\n".join(out)


def show_welcome_page() -> None:
    """Print a summary of the installed OEM and NVIDIA drivers to stdout."""
    print(format_welcome_page(UbuntuDrivers.detect.gather_welcome_page_data()))


#
# main
#


@click.group(context_settings=CONTEXT_SETTINGS, invoke_without_command=True)
@pass_config
def greet(config: Config, **kwargs: Any) -> None:
    # Top-level command group. With invoke_without_command=True this also runs
    # when no subcommand is given; in that case show the welcome page.
    ctx = click.get_current_context()
    if ctx.invoked_subcommand:
        return
    show_welcome_page()


def _join_driver_args(drivers: Any) -> str:
    """Combine the positional DRIVER arguments into one driver spec string.

    A single argument is returned unchanged. Separate arguments are joined
    with "," unless a comma already sits at the boundary, so
    ``nvidia:390 amdgpu`` and ``nvidia:390, amdgpu`` both give
    ``nvidia:390,amdgpu``. Empty arguments are skipped.
    """
    spec = ""
    for arg in drivers:
        if not arg:
            continue
        if spec and not spec.endswith(",") and not arg.startswith(","):
            spec += ","
        spec += arg
    return spec


@greet.command()
@click.argument("driver", nargs=-1)  # add the name argument
@click.option(
    "--gpgpu",
    is_flag=True,
    help="Install “general-purpose computing” drivers for use in a headless server environment. This installs a server (ERD) flavor of the driver (which is required for compatibility with some server applications, such as nvidia-fabricmanager), and also results in a smaller installation footprint by not installing packages that are only useful in graphical environments.",
)
@click.option(
    "--recommended", is_flag=True, help="Only show the recommended driver packages"
)
@click.option("--free-only", is_flag=True, help="Only consider free packages")
@click.option(
    "--package-list",
    nargs=1,
    metavar="PATH",
    help="Create file with list of installed packages (in install mode)",
)
@click.option(
    "--no-oem",
    is_flag=True,
    metavar="install_oem_meta",
    help="Do not include OEM enablement packages (these enable an external archive)",
)
@click.option("--include-dkms", is_flag=True, help="Also consider DKMS packages")
@pass_config
def install(config: Config, **kwargs: Any) -> None:
    """Install a driver [driver[:version][,driver[:version]]]"""

    # Require root
    if os.geteuid() != 0:
        print(
            "Error: 'ubuntu-drivers install' must be run as root. Try using 'sudo'.",
            file=sys.stderr,
        )
        sys.exit(1)

    if kwargs.get("gpgpu"):
        config.gpgpu = True
    if kwargs.get("free_only"):
        config.free_only = True
    if kwargs.get("include_dkms"):
        config.include_dkms = True
    if kwargs.get("package_list"):
        # --package-list takes exactly one PATH (nargs=1), so this is a str.
        config.package_list = kwargs["package_list"]
    if kwargs.get("no_oem"):
        config.install_oem_meta = False
    if kwargs.get("driver"):
        config.driver_string = _join_driver_args(kwargs["driver"])

    ret = install_gpgpu(config) if config.gpgpu else command_install(config)
    # click ignores a command callback's return value in standalone mode, so
    # forward a failing status (e.g. from apt-get) as the process exit code.
    if ret:
        sys.exit(ret)


@greet.command()
@click.argument("list", nargs=-1)
@click.option(
    "--gpgpu",
    is_flag=True,
    help="Install “general-purpose computing” drivers for use in a headless server environment. This installs a server (ERD) flavor of the driver (which is required for compatibility with some server applications, such as nvidia-fabricmanager), and also results in a smaller installation footprint by not installing packages that are only useful in graphical environments.",
)
@click.option(
    "--recommended", is_flag=True, help="Only show the recommended driver packages"
)
@click.option("--free-only", is_flag=True, help="Only consider free packages")
@click.option("--include-dkms", is_flag=True, help="Also consider DKMS packages")
@pass_config
def list(config: Config, **kwargs: Any) -> Optional[int]:
    """Show all driver packages which apply to the current system."""
    # --free-only only arrives through kwargs. Record it on the shared config
    # so the detection call below honours it (it was previously ignored).
    if kwargs.get("free_only"):
        config.free_only = True
    include_dkms: bool = kwargs.get("include_dkms", False)

    apt_pkg.init_config()
    apt_pkg.init_system()

    try:
        cache = apt_pkg.Cache(None)
    except Exception as ex:
        print(ex)
        # click ignores a callback's return value; exit non-zero explicitly.
        sys.exit(1)

    # First check if kernel needs updating
    kernel_detector = kerneldetection.KernelDetection(cache)
    if kernel_detector.get_kernel_update_warning(include_dkms):
        sys.exit(1)

    if kwargs.get("gpgpu"):
        packages = UbuntuDrivers.detect.system_gpgpu_driver_packages(cache, sys_path)
        sort_func = UbuntuDrivers.detect._cmp_gfx_alternatives_gpgpu
    else:
        packages = UbuntuDrivers.detect.system_driver_packages(
            apt_cache=cache,
            sys_path=sys_path,
            freeonly=config.free_only,
            include_oem=config.install_oem_meta,
        )
        sort_func = UbuntuDrivers.detect._cmp_gfx_alternatives

    # Most preferred alternative first (the comparator orders ascending).
    ordered = sorted(packages, key=cmp_to_key(sort_func), reverse=True)
    for package in ordered:
        try:
            linux_modules = UbuntuDrivers.detect.get_linux_modules_metapackage(
                cache, package
            )
            if not linux_modules and "dkms" in package and include_dkms:
                linux_modules = package

            if linux_modules:
                if not include_dkms and "dkms" in linux_modules:
                    continue
                if kwargs.get("recommended"):
                    # This is just a space separated two item line
                    # Such as "nvidia-headless-no-dkms-470-server linux-modules-nvidia-470-server-generic"
                    print("%s %s" % (package, linux_modules))
                    break
                print("%s, (kernel modules provided by %s)" % (package, linux_modules))
            else:
                print(package)
        except KeyError:
            print(package)

    return 0


@greet.command()
@click.argument("list-oem", nargs=-1)
@click.option(
    "--package-list",
    nargs=1,
    metavar="PATH",
    help="Create file with a list of the available packages",
)
@pass_config
def list_oem(config: Config, **kwargs: Any) -> None:
    """Show all OEM enablement packages which apply to this system"""
    package_list = kwargs.get("package_list")
    if package_list:
        # nargs=1 means click passes a single PATH string.
        config.package_list = package_list

    # click ignores a callback's return value; forward failures as exit status.
    status = command_list_oem(config)
    if status:
        sys.exit(status)


@greet.command()
@click.argument("debug", nargs=-1)  # add the name argument
@pass_config
def debug(config: Config, **kwargs: Any) -> None:
    """Print all available information and debug data about drivers."""
    # click ignores a callback's return value; forward failures as exit status.
    status = command_debug(config)
    if status:
        sys.exit(status)


@greet.command()
@click.argument("devices", nargs=-1)  # add the name argument
@click.option("--free-only", is_flag=True, help="Only consider free packages")
@pass_config
def devices(config: Config, **kwargs: Any) -> None:
    """Show all devices which need drivers, and which packages apply to them."""
    if kwargs.get("free_only"):
        config.free_only = True
    # click ignores a callback's return value; forward failures as exit status.
    status = command_devices(config)
    if status:
        sys.exit(status)


if __name__ == "__main__":
    # Run the click command group when executed as a script.
    greet()
