#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# boltd mocking
#
# Copyright © 2017 Red Hat, Inc
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
# Authors:
#       Christian J. Kellner <christian@kellner.me>

from __future__ import print_function

import argparse
import atexit
import os
import readline
import shlex
import shutil
import subprocess
import sys
import tempfile
import time
import uuid

try:
    import gi
    from gi.repository import GLib
    from gi.repository import Gio

    gi.require_version("UMockdev", "1.0")
    from gi.repository import UMockdev

except ImportError as e:
    sys.stderr.write("Missing dependencies: %s\n" % str(e))
    sys.exit(1)

# Well-known bus name and object path of the bolt daemon.
DBUS_NAME = "org.freedesktop.bolt"
DBUS_PATH = "/org/freedesktop/bolt"

# D-Bus interface names; all of them share the versioned prefix.
DBUS_IFACE_PREFIX = "org.freedesktop.bolt1."
DBUS_IFACE_MANAGER, DBUS_IFACE_DEVICE, DBUS_IFACE_DOMAIN = (
    DBUS_IFACE_PREFIX + suffix for suffix in ("Manager", "Device", "Domain")
)

# Service file whose "Exec=" line names the installed daemon binary.
SERVICE_FILE = "/usr/share/dbus-1/system-services/%s.service" % DBUS_NAME


class SysfsDev:
    """Base class for objects backed by a directory in the mocked sysfs.

    ``sysname`` is the device name and ``syspath`` the directory holding its
    attribute files.  Both are ``None`` for a disconnected object.

    Subclasses list the attribute files they want to expose in
    ``known_attrs`` and apply :meth:`bridge` as a class decorator, which
    installs one read-only property per listed attribute.
    """

    # Names of the sysfs attributes that bridge() turns into properties.
    known_attrs = []

    def __init__(self, sysname, syspath):
        self.sysname = sysname
        self.syspath = syspath

    @property
    def is_connected(self):
        """Whether the object still has a sysfs path."""
        return self.syspath is not None

    def sysattr_path(self, name):
        """Return the path of attribute file ``name`` below ``syspath``."""
        return os.path.join(self.syspath, name)

    def have_sysattr(self, name):
        """Return True if attribute ``name`` exists as a regular file.

        A disconnected object has no attributes, so the answer is False.
        """
        if not self.is_connected:
            return False
        return os.path.isfile(self.sysattr_path(name))

    def read_sysattr(self, name):
        """Return the stripped content of attribute ``name``.

        Returns ``None`` when the attribute file does not exist or when the
        object is disconnected (there is no path to read from then, and
        joining ``None`` with the name would raise ``TypeError``).
        """
        if not self.is_connected:
            return None
        try:
            with open(self.sysattr_path(name), "r") as fd:
                return fd.read().strip()
        except FileNotFoundError:
            return None

    def dump(self, prefix="", file=sys.stdout):
        """Print a one-line summary (name and path, or "disconnected")."""
        if self.sysname is None:
            print("%s disconnected" % prefix, file=file)
        else:
            print("%s%s @ %s" % (prefix, self.sysname, self.syspath), file=file)

    @staticmethod
    def bridge(klass):
        """Class decorator adding a property for each ``known_attrs`` entry.

        Every property reads the attribute of the same name through
        :meth:`read_sysattr` on access.  Returns ``klass`` itself.
        """

        def install(attr):
            # A helper function gives each getter its own `attr` binding;
            # a plain loop variable would be shared by all closures.
            def getter(self):
                return self.read_sysattr(attr)

            setattr(klass, attr, property(getter))

        for attr in getattr(klass, "known_attrs", []):
            install(attr)
        return klass


@SysfsDev.bridge
class Domain(SysfsDev):
    """A thunderbolt domain in the mocked sysfs.

    Besides the sysfs name/path it tracks the numeric domain ``index``, the
    host ``Device`` attached to it (``None`` until one is set) and a counter
    used to hand out serial numbers for devices in this domain.
    """

    known_attrs = ["boot_acl", "security"]

    def __init__(self, index, sysname, syspath):
        super().__init__(sysname, syspath)
        self.index = index
        self.parent = None
        self.host = None
        # Next serial number returned by gen_serial().
        self.serial = 100

    @property
    def uid(self):
        """The host's uid, or the (falsy) host value when there is none."""
        host = self.host
        return host.uid if host else host

    def gen_serial(self):
        """Return the current serial number and advance the counter."""
        current = self.serial
        self.serial = current + 1
        return current

    def find_device(self, name_or_id):
        """Look ``name_or_id`` up in the host's device tree, if there is one."""
        host = self.host
        return None if host is None else host.find_device(name_or_id)

    def show(self, prefix="", file=sys.stdout):
        """Print the domain name, host uid, sysfs location and security."""
        label = self.sysname or "domain"
        host_uid = "" if not self.host else self.host.uid
        print("%s%s %s" % (prefix, label, host_uid), file=file)

        indent = prefix + "  "
        super().dump(prefix=indent, file=file)
        level = self.security or "unknown"
        print("%ssecurity %s" % (indent, level), file=file)


@SysfsDev.bridge
class Device(SysfsDev):
    """A thunderbolt device (host or peripheral) in the mocked sysfs.

    Devices form a tree through ``children``; ``domain`` points at the
    owning domain object and ``is_host`` marks the domain's host device.
    The entries of ``known_attrs`` are available as read-only properties
    that read the sysfs attribute file of the same name.
    """

    known_attrs = [
        "authorized",
        "boot",
        "device",
        "device_name",
        "generation",
        "key",
        "unique_id",
        "vendor",
        "vendor_name",
    ]

    # Attributes an identifier is compared against in find_device().
    search_fields = ("uid", "sysname", "device_name")

    def __init__(self, uid, sysname, syspath):
        super(Device, self).__init__(sysname, syspath)
        self.uid = uid
        self.domain = None
        self.is_host = False
        self.children = []

    def find_device(self, identifier, fields=search_fields):
        """Return the first device in this subtree matching ``identifier``.

        A device matches when any attribute named in ``fields`` equals
        ``identifier``.  The device itself is checked before its children,
        which are searched depth-first with the same ``fields``.  Returns
        ``None`` when nothing matches.

        Disconnected devices are never matched and their children are not
        searched: their sysfs-backed attributes cannot be read any more.
        NOTE(review): this assumes children of a removed device are gone
        as well (they were created below the parent's sysfs path).
        """
        if not self.is_connected:
            return None

        for field in fields:
            if getattr(self, field) == identifier:
                return self

        for child in self.children:
            found = child.find_device(identifier, fields)
            if found is not None:
                return found
        return None

    def show(self, prefix="", file=sys.stdout):
        """Print uid and sysfs location, plus authorization for peripherals."""
        print("%sdevice %s" % (prefix, self.uid), file=file)
        pf = prefix + "  "
        super(Device, self).dump(prefix=pf, file=file)
        # The authorization state is only shown for connected non-host devices.
        if not self.is_connected or self.is_host:
            return
        print("%sauthorized %s" % (pf, self.authorized), file=file)


class MockSysfs:
    """In-memory model of the thunderbolt devices in a UMockdev testbed.

    ``domains`` maps sysfs names ("domain0", ...) to ``Domain`` objects;
    hosts and devices hang off the domains.  All changes are mirrored into
    the testbed, whose root directory is available as ``root``.
    """

    def __init__(self):
        self.testbed = UMockdev.Testbed.new()
        self.domains = {}

    @property
    def root(self):
        """Root directory of the testbed."""
        return self.testbed.get_root_dir()

    def _gen_domain_id(self):
        """Return the lowest domain index (below 1024) not currently in use.

        Raises ``ValueError`` when all indices are taken.
        """
        for i in range(1024):
            if ("domain%d" % i) not in self.domains:
                return i
        raise ValueError("too many domains in use")

    def find_domain(self, name_or_id):
        """Return the domain whose sysfs name or host uid is ``name_or_id``.

        Returns ``None`` when there is no such domain.
        """
        for name, dom in self.domains.items():
            if name_or_id in (name, dom.uid):
                return dom
        return None

    def find_device(self, name_or_id):
        """Return the first device matching ``name_or_id`` in any domain."""
        for dom in self.domains.values():
            found = dom.find_device(name_or_id)
            if found is not None:
                return found
        return None

    def domain_add(self, security, bootacl, iommu):
        """Create a new domain in the testbed and return its ``Domain``.

        ``security`` becomes the "security" attribute.  ``bootacl`` may be
        ``None`` (no "boot_acl" attribute), a string used verbatim, or an
        int ``n`` which is turned into ``n - 1`` commas, i.e. ``n`` empty
        comma-separated entries.  ``iommu``, unless ``None``, is written to
        "iommu_dma_protection" with a trailing newline.
        """
        index = self._gen_domain_id()
        sysname = "domain%d" % index

        props = ["security", security]

        if isinstance(bootacl, int):
            bootacl = "," * (bootacl - 1)

        if bootacl is not None:
            props += ["boot_acl", bootacl]

        if iommu is not None:
            props += ["iommu_dma_protection", str(iommu) + "\n"]

        syspath = self.testbed.add_device(
            "thunderbolt", sysname, None, props, ["DEVTYPE", "thunderbolt_domain"]
        )
        domain = Domain(index, sysname, syspath)
        self.domains[sysname] = domain
        return domain

    def host_add(self, domain, uid, name, vendor, generation):
        """Create the host device "<index>-0" below ``domain`` and return it.

        The host is always created authorized; a "generation" attribute is
        only written when ``generation`` is not ``None``.  The new device is
        linked to ``domain`` in both directions.
        """
        sysname = "%d-0" % domain.index
        attributes = [
            "device_name",
            name,
            "device",
            "0x23",
            "vendor_name",
            vendor,
            "vendor",
            "0x23",
            "authorized",
            "1",
            "unique_id",
            uid,
        ]
        if generation is not None:
            attributes += ["generation", str(generation)]

        syspath = self.testbed.add_device(
            "thunderbolt",
            sysname,
            domain.syspath,
            attributes,
            ["DEVTYPE", "thunderbolt_device"],
        )
        host = Device(uid, sysname, syspath)
        host.is_host = True
        domain.host = host
        host.domain = domain
        return host

    def device_add(self, parent, uid, name, vendor, authorized, key, boot, gen):
        """Create a device below ``parent`` and return its ``Device``.

        The sysfs name is "<domain index>-<serial>" with a serial drawn from
        the parent's domain.  ``key``, ``boot`` and ``gen`` are optional;
        their attributes are only written when the value is not ``None``.
        """
        domain = parent.domain
        serial = domain.gen_serial()
        sysname = "%d-%d" % (domain.index, serial)

        props = [
            "device_name",
            name,
            "device",
            "0x23",
            "vendor_name",
            vendor,
            "vendor",
            "0x23",
            "authorized",
            "%d\n" % authorized,
            "unique_id",
            uid,
        ]

        if key is not None:
            # The kernel always returns the key with trailing `\n`
            if not key.endswith("\n"):
                key += "\n"
            # Written whether or not the newline had to be added.
            props += ["key", key]

        if boot is not None:
            props += ["boot", str(boot)]

        if gen is not None:
            props += ["generation", str(gen)]

        syspath = self.testbed.add_device(
            "thunderbolt",
            sysname,
            parent.syspath,
            props,
            ["DEVTYPE", "thunderbolt_device"],
        )

        device = Device(uid, sysname, syspath)
        device.domain = domain
        parent.children.append(device)

        return device

    def remove(self, dev):
        """Remove ``dev`` (a domain or device) from the testbed and the model.

        Emits a "remove" uevent, deletes the device from the testbed, marks
        the object as disconnected and unlinks it from the model so later
        lookups do not find it.  Calling it for an already disconnected
        object does nothing.
        """
        if dev.syspath is None:
            return
        self.testbed.uevent(dev.syspath, "remove")
        self.testbed.remove_device(dev.syspath)
        dev.syspath = None
        dev.sysname = None
        self._forget(dev)

    def _forget(self, dev):
        """Unlink ``dev`` from ``domains`` or from the device tree."""
        for name, dom in list(self.domains.items()):
            if dom is dev:
                del self.domains[name]
                return
            if dom.host is None:
                continue
            if dom.host is dev:
                dom.host = None
                return
            # Walk the tree below the host looking for the parent of `dev`.
            pending = [dom.host]
            while pending:
                node = pending.pop()
                kept = [c for c in node.children if c is not dev]
                if len(kept) != len(node.children):
                    node.children = kept
                    return
                pending.extend(kept)


class Store:
    """Location of the daemon's state directory.

    When no ``path`` is given a temporary directory with the "devices",
    "domains" and "keys" sub-directories is created; such a mocked store is
    deleted again when the object is finalized.
    """

    def __init__(self, path=None):
        self.path = path
        self.mocked = False
        if path is None:
            self.make_mock_store()

    def __del__(self):
        # Only directories created by make_mock_store() are ours to delete.
        if not self.mocked:
            return
        shutil.rmtree(self.path)

    def make_mock_store(self):
        """Create a fresh temporary store and make it the current path."""
        root = tempfile.mkdtemp()
        for sub in ("devices", "domains", "keys"):
            os.makedirs(os.path.join(root, sub))
        self.path = root
        self.mocked = True


class Daemon:
    """Locates the boltd/boltctl binaries and runs boltd as a child process.

    The daemon is started with its state directory pointing at ``store``,
    its device tree at the ``sysfs`` testbed and its runtime directory at a
    temporary directory (``rundir``) that is deleted on finalization.
    """

    def __init__(self, store, sysfs):
        # Everything __del__ looks at is set before any step that may fail,
        # so finalizing a half-constructed object does not raise.
        self._boltd = None
        self.rundir = None
        self.log = None
        self.store = store
        self.sysfs = sysfs
        self._discover_binary()
        self.rundir = tempfile.mkdtemp()

    def __del__(self):
        if self.is_running:
            self.stop()
        if self.rundir is not None:
            shutil.rmtree(self.rundir, ignore_errors=True)

    @staticmethod
    def path_from_service_file(sf):
        """Return the value of the first "Exec=" line in service file ``sf``.

        Returns ``None`` when the file does not exist or has no such line.
        """
        try:
            with open(sf) as f:
                for line in f:
                    if not line.startswith("Exec="):
                        continue
                    return line.split("=", 1)[1].strip()
        except FileNotFoundError:
            pass
        return None

    def _discover_binary(self):
        """Find boltd and boltctl and connect to the system bus.

        Search order: ``$BOLT_BUILD_DIR``, a JHBuild prefix, a local
        "build" directory, and finally the system installation.  The
        result is stored in ``self.paths``.  Raises ``AssertionError`` when
        the daemon cannot be found or boltctl is not executable.
        """
        if "BOLT_BUILD_DIR" in os.environ:
            print("Using boltd from local build")
            build_dir = os.environ["BOLT_BUILD_DIR"]
            boltd = os.path.join(build_dir, "boltd")
            boltctl = os.path.join(build_dir, "boltctl")
        elif "UNDER_JHBUILD" in os.environ:
            print("Using boltd from JHBuild")
            jhbuild_prefix = os.environ["JHBUILD_PREFIX"]
            boltd = os.path.join(jhbuild_prefix, "libexec", "boltd")
            boltctl = os.path.join(jhbuild_prefix, "bin", "boltctl")
        elif os.path.exists(os.path.abspath("build/boltd")):
            build_dir = os.path.abspath("build")
            print("Using boltd from %s" % build_dir)
            boltd = os.path.join(build_dir, "boltd")
            boltctl = os.path.join(build_dir, "boltctl")
        else:
            print("Using boltd from system installation")
            boltd = Daemon.path_from_service_file(SERVICE_FILE)
            boltctl = shutil.which("boltctl")

        assert boltd is not None, "failed to find daemon"
        # shutil.which() yields None when boltctl is not installed; check it
        # explicitly so the assertion (not a TypeError) reports the problem.
        assert boltctl is not None and os.access(
            boltctl, os.X_OK
        ), "could not execute @ %s" % boltctl

        self.paths = {"daemon": boltd, "boltctl": boltctl}
        self.dbus = Gio.bus_get_sync(Gio.BusType.SYSTEM, None)

    # dbus helper methods
    def _get_dbus_property(self, name, interface=DBUS_IFACE_MANAGER):
        """Read property ``name`` of ``interface`` from the daemon's object.

        The proxy is created without auto-starting the service.
        """
        proxy = Gio.DBusProxy.new_sync(
            self.dbus,
            Gio.DBusProxyFlags.DO_NOT_AUTO_START,
            None,
            DBUS_NAME,
            DBUS_PATH,
            "org.freedesktop.DBus.Properties",
            None,
        )
        return proxy.Get("(ss)", interface, name)

    def start(self, args):
        """Launch boltd and wait up to ten seconds for it to answer on D-Bus.

        ``args.verbose`` adds "-v" to the daemon's command line.  When the
        daemon does not answer in time it is terminated; when it exits
        during start-up it is reported as crashed.  In both cases the
        object ends up in the "not running" state.
        """
        timeout = 10
        env = os.environ.copy()
        env["G_DEBUG"] = "fatal-criticals"
        env["UMOCKDEV_DIR"] = self.sysfs.root
        env["STATE_DIRECTORY"] = self.store.path
        env["RUNTIME_DIRECTORY"] = self.rundir
        argv = [self.paths["daemon"], "--replace"]
        if args.verbose:
            argv += ["-v"]
        self._boltd = subprocess.Popen(
            argv, env=env, stdout=self.log, stderr=subprocess.STDOUT
        )

        timeout_sleep = 0.1
        for _ in range(timeout * 10):
            time.sleep(timeout_sleep)
            # No point in polling the bus for a process that is gone.
            if self._boltd.poll() is not None:
                break
            try:
                self._get_dbus_property("Version")
            except GLib.GError:
                continue
            break
        else:
            print("daemon did not start in %d seconds" % timeout)
            # Terminate the child instead of just dropping the handle,
            # which would leave the process running unsupervised.
            self.stop()
            return

        if self._boltd.poll() is not None:
            print("daemon crashed :(")
            self._boltd = None

    def stop(self):
        """Terminate the daemon, if running, and wait for it to exit."""
        if self._boltd:
            try:
                self._boltd.terminate()
            except OSError:
                pass
            self._boltd.wait()
        self._boltd = None

    @property
    def is_running(self):
        """Whether a daemon process handle is currently held."""
        return self._boltd is not None


class Command:
    """A shell command: a handler function plus its argparse parser.

    The command name is the handler's function name and the description
    its docstring (or the name when there is none).  Every parser gets a
    ``--help`` flag stored as ``show_help``; argparse's own help handling is
    disabled.  Handlers are called as ``handler(obj, args, context)``.
    """

    class SubCommand:
        """A handler registered below a command via ``Command.subcommand``."""

        def __init__(self, name, desc, parser, handler):
            self.name = name
            self.desc = desc
            self.parser = parser
            self.handler = handler

        def __call__(self, obj, argv, parser=None):
            """Print the subcommand's help or run its handler.

            ``argv`` is the parsed namespace.  The third argument is ignored;
            the handler always receives the subcommand's own parser.
            """
            if argv.show_help:
                print(self.parser.format_help())
            else:
                self.handler(obj, argv, self.parser)

    def __init__(self, handler):
        self.handler = handler
        self.name = handler.__name__
        self.desc = handler.__doc__ or self.name
        self.parser = argparse.ArgumentParser(
            prog=self.name, description=self.desc, add_help=False
        )
        self.parser.add_argument("--help", dest="show_help", action="store_true")
        self.subparsers = None
        self.use_parse_known = False

    def __call__(self, obj, argv):
        """Parse ``argv`` and dispatch to a subcommand, help or the handler.

        The context passed on is a dict with the command's "parser" and the
        "unknown" arguments (``None`` unless ``use_parse_known`` is set).
        """
        if self.use_parse_known:
            args, rest = self.parser.parse_known_args(argv)
        else:
            args = self.parser.parse_args(argv)
            rest = None
        context = {"parser": self.parser, "unknown": rest}
        # A selected subcommand is dispatched first: it shares the
        # "show_help" destination and prints its own help, which would
        # otherwise be shadowed by the parent's help text.
        if hasattr(args, "func"):
            args.func(obj, args, context)
        elif args.show_help:
            print(self.parser.format_help())
        else:
            self.handler(obj, args, context)

    def subcommand(self, handler):
        """Register ``handler`` as a subcommand and return its SubCommand.

        The handler must be named "<command>_<subcommand>"; otherwise
        ``ValueError`` is raised.  Usable as a decorator.
        """
        if self.subparsers is None:
            self.subparsers = self.parser.add_subparsers()
        prefix = self.name + "_"
        name = handler.__name__
        if not name.startswith(prefix):
            raise ValueError("Invalid naming scheme")
        name = name[len(prefix):]
        # Without a docstring, fall back to the subcommand's own name.
        desc = handler.__doc__ or name
        p = self.subparsers.add_parser(name, help=desc, add_help=False)
        p.add_argument("--help", dest="show_help", action="store_true")
        cmd = Command.SubCommand(name, desc, p, handler)
        p.set_defaults(func=cmd)
        return cmd

    @staticmethod
    def subcommand_dispatch(obj, args):
        """Invoke the subcommand stored in ``args.func``."""
        args.func(obj, args)


def command(func):
    """Decorator wrapping ``func`` in a ``Command`` object."""
    return Command(func)


def arg(name_or_flags, *args, **kwargs):
    """Decorator factory adding one argument to the decorated command.

    All parameters are forwarded to ``add_argument`` of the ``parser``
    attribute of the decorated object, which is returned unchanged.
    """

    def add_to_parser(cmd):
        parser = cmd.parser
        parser.add_argument(name_or_flags, *args, **kwargs)
        return cmd

    return add_to_parser


def parse_known(func):
    """Decorator flagging a command so that unknown arguments are kept.

    Sets ``use_parse_known`` on the decorated object and returns it.
    """
    setattr(func, "use_parse_known", True)
    return func


def collect_args(klass):
    """Class decorator gathering all ``Command`` attributes of ``klass``.

    Every attribute whose name does not start with a double underscore and
    whose value is a ``Command`` is stored, keyed by command name, in
    ``klass.commands``.  Returns ``klass``.
    """
    found = {}
    for attr_name in dir(klass):
        if attr_name.startswith("__"):
            continue
        member = getattr(klass, attr_name)
        if isinstance(member, Command):
            found[member.name] = member
    klass.commands = found
    return klass


@collect_args
class World:
    """Interactive shell operating on the store, the mocked sysfs and boltd.

    Methods decorated with ``@command`` become shell commands; they are
    collected into ``commands`` by the class decorator.  The docstrings of
    those methods double as the help text shown to the user, so additional
    notes on them are kept in comments.
    """

    def __init__(self, store, sysfs, daemon):
        self.store = store
        self.sysfs = sysfs
        self.daemon = daemon

    def handle_line(self, line):
        """Execute one input line; return False when the shell should exit.

        Empty lines and lines starting with "#" are ignored, as is
        everything from the first token starting with "#".  Lines that
        cannot be tokenized and unknown commands are reported and skipped.
        """
        line = line.strip()
        if not line:
            return True
        if line == "exit":
            return False
        if line.startswith("#"):
            return True

        try:
            argv = shlex.split(line)
        except ValueError as err:
            # e.g. unbalanced quotes
            print("could not parse input: %s" % err)
            return True

        # Drop a trailing comment.
        for i, token in enumerate(argv):
            if token.startswith("#"):
                argv = argv[:i]
                break
        if not argv:
            return True

        cmd = self.commands.get(argv[0], None)
        if cmd is None:
            print("unknown command")
            return True
        try:
            cmd(self, argv[1:])
        except SystemExit:
            # argparse exits on invalid arguments; stay in the shell.
            return True
        return True

    def loop(self):
        """Read and execute lines until "exit" or end of input."""
        do_loop = True
        while do_loop:
            try:
                line = input("> ")
            except EOFError:
                print("Bye.")
                break

            do_loop = self.handle_line(line)

    # Lists all commands together with their descriptions.
    @command
    def help(self, args, context):
        for name, cmd in self.commands.items():
            print("%s - %s " % (name, cmd.desc))

    @arg("--verbose", action="store_true")
    @command
    def start(self, args, context):
        """start the bolt daemon"""
        if self.daemon.is_running:
            print("boltd already running")
        else:
            print("starting boltd")
            self.daemon.start(args)

    @command
    def stop(self, args, context):
        """stops the bolt daemon"""
        if not self.daemon.is_running:
            print("boltd not running")
        else:
            print("stopping boltd")
            self.daemon.stop()

    @command
    def status(self, args, context):
        """Show overall status"""
        print("sysfs: %s" % self.sysfs.root)
        print("store: %s" % self.store.path)
        print("boltd: %s" % self.daemon.paths["daemon"])
        print("rundir: %s" % self.daemon.rundir)
        print("boltd %s" % ("running" if self.daemon.is_running else "stopped"))

    @arg("file", type=argparse.FileType("r"), default=None)
    @command
    def load(self, args, context):
        """Load and execute a commands from a file"""
        script = args.file
        try:
            for line in script:
                line = line.strip()
                print("@ %s" % line)
                self.handle_line(line)
        finally:
            # argparse opened the file for us; "-" yields stdin, which must
            # stay open for the interactive loop.
            if script is not sys.stdin:
                script.close()

    @command
    def controller(self, args, context):
        """control controllers. hah!"""
        parser = context["parser"]
        print(parser.format_help())

    @arg("--security", type=str, default="secure")
    @arg("--uuid", type=str, default=None)
    @arg("--iommu", type=str, choices=[None, "0", "1"], default=None)
    @arg("--bootacl", type=str, default=None)
    @arg("--generation", type=int, default=3)
    @arg("--vendor", type=str, default="GNOME.org")
    @arg("--name", type=str, default="Laptop")
    @controller.subcommand
    def controller_new(self, args, context):
        """create a new domain+host combination"""
        security = args.security
        bootacl = args.bootacl
        iommu = args.iommu
        gen = args.generation
        # A random uuid is used unless one was given.
        uid = args.uuid or str(uuid.uuid4())
        print(uid)

        domain = self.sysfs.domain_add(security, bootacl, iommu)
        domain.show()

        name = args.name
        vendor = args.vendor

        host = self.sysfs.host_add(domain, uid, name, vendor, gen)
        host.show()

    @arg("id", type=str, help="name or uuid")
    @controller.subcommand
    def controller_rm(self, args, parser):
        """removes a new domain+host combination"""
        target = args.id
        domain = self.sysfs.find_domain(target)
        if domain is None:
            print("Could not find controller")
            return
        # The host goes first, then the domain itself.
        if domain.host is not None:
            self.sysfs.remove(domain.host)
        self.sysfs.remove(domain)

    @command
    def device(self, args, context):
        """control devices"""
        # Top-level commands receive a context dict, not the parser itself.
        parser = context["parser"]
        print(parser.format_help())

    @arg("--key", type=str, default=None)
    @arg("--boot", type=int, choices=[None, 0, 1], default=None)
    @arg("--authorized", type=int, choices=[0, 1], default=0)
    @arg("--uuid", type=str, default=None)
    @arg("--vendor", type=str, default="GNOME.org")
    @arg("--generation", type=int, default=3)
    @arg("name", type=str)
    @arg("parent", type=str)
    @device.subcommand
    def device_new(self, args, context):
        """create a new device combination"""
        parent = self.sysfs.find_device(args.parent)
        if parent is None:
            print("unknown parent")
            return

        uid = args.uuid or str(uuid.uuid4())

        device = self.sysfs.device_add(
            parent,
            uid,
            args.name,
            args.vendor,
            args.authorized,
            args.key,
            args.boot,
            args.generation,
        )

        device.show()

    @arg("id", type=str, help="name or uuid")
    @device.subcommand
    def device_rm(self, args, context):
        """removes a device"""
        device = self.sysfs.find_device(args.id)
        if device is None:
            print("unknown device")
            return
        print("disconnecting %s" % device.uid)
        self.sysfs.remove(device)

    @parse_known
    @command
    def boltctl(self, args, context):
        """Invoke boltctl"""
        # Everything this command's parser did not recognize is handed to
        # boltctl unchanged.
        boltctl = self.daemon.paths["boltctl"]
        bc_args = context["unknown"] or []
        argv = [boltctl] + bc_args
        subprocess.call(argv)


def main():
    """Set up readline history, create the mock environment, run the shell.

    The history is kept in ``~/.cache/bolt_mock`` and written back at exit.
    The daemon is stopped when the shell loop ends, however it ends.
    """
    readline.set_history_length(1000)
    histfile = os.path.join(os.path.expanduser("~"), ".cache", "bolt_mock")

    # The exit hook below can only write the history if its directory
    # exists; creating it is best effort.
    try:
        os.makedirs(os.path.dirname(histfile), exist_ok=True)
    except OSError:
        pass

    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        pass

    atexit.register(readline.write_history_file, histfile)

    store = Store()
    sysfs = MockSysfs()
    daemon = Daemon(store, sysfs)

    ctrl = World(store, sysfs, daemon)
    try:
        ctrl.loop()
    finally:
        # Do not rely on object finalization to terminate the child process;
        # that is not guaranteed when the loop is left via an exception.
        daemon.stop()


if __name__ == "__main__":
    # Re-execute the script through umockdev-wrapper unless LD_PRELOAD
    # already mentions umockdev.
    preload = os.environ.get("LD_PRELOAD", "")
    if "umockdev" not in preload:
        wrapper = "umockdev-wrapper"
        os.execvp(wrapper, [wrapper] + sys.argv)

    main()
