#!/usr/bin/python
#
# Copyright (C) 2013 by Red Hat, Inc.  All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Author(s): Brian C. Lane <bcl@brianlane.com>
#
"""
Driver Update Disk UI

/tmp/dd_modules is a copy of /proc/modules at startup time
/tmp/dd_args is a parsed list of the inst.dd= cmdline args, and may include
             'dd' or 'inst.dd' if it was specified without arguments

Pass a path and it will install the driver rpms from the path before checking
for new OEMDRV devices.

Repositories for installed drivers are copied into /run/install/DD-X where X
starts at 1 and increments for each repository.

Selected driver package names are saved in /run/install/dd_packages

Anaconda uses the repository and package list to install the same set of drivers
to the target system.
"""
import logging
from logging.handlers import SysLogHandler
import sys
import os
import subprocess
import time

# Logger shared by every function in this file; handlers are only attached
# in the __main__ section at the bottom of the file.
log = logging.getLogger("DD")


class RunCmdError(Exception):
    """ Signals that a command started by run_cmd exited with a non-zero
        returncode.
    """


def run_cmd(cmd):
    """ Execute a command, capturing its stdout and discarding its stderr.

        :param cmd: command and arguments to run
        :type cmd:  list of strings
        :returns:   exit code (always 0 when this returns) and the stdout
                    of the command
        :rtype:     (int, string)
        :raises:    OSError if the cmd cannot be started, RunCmdError if it
                    exits with a non-zero returncode

        The command line and every line of its output are logged at debug
        level.
    """
    try:
        with open("/dev/null", "w") as devnull:
            log.debug(" ".join(cmd))
            child = subprocess.Popen(cmd,
                                     stdout=subprocess.PIPE,
                                     stderr=devnull)
            captured, _ = child.communicate()
            if captured:
                for text in captured.splitlines():
                    log.debug(text)
    except OSError as err:
        # The failure is logged here and handed on to the caller unchanged
        log.error("Error running %s: %s" % (cmd[0], err.strerror))
        raise

    status = child.returncode
    if status != 0:
        log.debug("%s returned %s" % (cmd[0], status))
        raise RunCmdError()
    return (status, captured)


def oemdrv_list():
    """ Ask blkid for every block device whose label is OEMDRV

        :returns: list of device paths, empty if blkid fails or finds none
        :rtype:   list
    """
    blkid_cmd = ["blkid", "-t", "LABEL=OEMDRV", "-o", "device"]
    try:
        _, found = run_cmd(blkid_cmd)
    except (OSError, RunCmdError):
        # blkid could not be run or reported nothing with that label
        return []
    return found.splitlines()


def is_dd(dd_path):
    """ Determine if a path is a valid DD for the current arch

        :param dd_path: Path to the DD directory to test
        :type dd_path:  string
        :returns:       True if it is a valid DD
        :rtype:         bool

        A valid DD has a 'rhdd3' marker at its top level and a
        'rpms/<arch>' directory, where <arch> is the machine name reported
        by os.uname().
    """
    arch = os.uname()[4]
    has_marker = os.path.exists(dd_path+"/rhdd3")
    # The arch directory must be present; a disk without it carries
    # nothing that can be used on this machine.
    has_arch_rpms = os.path.isdir(dd_path+"/rpms/%s" % arch)
    return has_marker and has_arch_rpms


def get_dd_args(dd_args_file="/tmp/dd_args"):
    """ Get the dd arguments from /tmp/dd_args

        :param dd_args_file: File holding the parsed arguments, defaults
                             to /tmp/dd_args
        :type dd_args_file:  string
        :returns:            List of arguments, empty if the file cannot
                             be read
        :rtype:              list of strings

        Only the first line of the file is used; it is split on
        whitespace. Arguments whose scheme (the text before the first
        ':') is a network protocol are dropped.
    """
    try:
        # 'with' makes sure the file is closed again
        with open(dd_args_file, "r") as f:
            dd_args = f.readline().split()
    except IOError:
        return []

    # skip dd args that need networking
    net_protocols = ("http", "https", "ftp", "nfs", "nfs4")
    # A comprehension always yields a real list, whatever filter() returns
    return [arg for arg in dd_args
            if arg.split(":", 1)[0].lower() not in net_protocols]


def is_interactive():
    """ Report whether interactive driver selection was requested

        :returns: True when get_dd_args() contains a bare 'dd' or 'inst.dd',
                  False otherwise
        :rtype:   bool
    """
    requested = get_dd_args()
    return "dd" in requested or "inst.dd" in requested


def umount(device):
    """ Run umount on a device, ignoring any failure

        :param device: Device to unmount
        :type device:  string
        :returns:      None
    """
    umount_cmd = ["umount", device]
    try:
        run_cmd(umount_cmd)
    except (OSError, RunCmdError):
        # Best effort only, a failed umount is not reported to the caller
        return


def mount_device(device, mnt="/media/DD/"):
    """ Mount a device on a mount point

        :param device: path to device to mount
        :type device:  string
        :param mnt:    path to mount the device on
        :type mnt:     string
        :returns:      True if the mount command succeeded, False if not
        :rtype:        bool

        No check of the mounted contents is made here and a successfully
        mounted device is left mounted.
    """
    try:
        run_cmd(["mount", device, mnt])
    except (OSError, RunCmdError):
        mounted = False
    else:
        mounted = True
    return mounted


def copy_repo(dd_path, dest_prefix):
    """ Copy a driver repository into a new, numbered destination directory

        :param dd_path:     Path to the driver repo directory
        :type dd_path:      string
        :param dest_prefix: Destination directory prefix, a number is added
        :type dest_prefix:  string
        :returns:           None

        The number is the lowest one, counting up from 1, for which
        dest_prefix+number does not exist yet, so the directories are
        numbered in the order the drivers were loaded. A failing copy is
        ignored.
    """
    index = 1
    dest = dest_prefix + str(index)
    while os.path.exists(dest):
        index += 1
        dest = dest_prefix + str(index)

    os.makedirs(dest)
    cp_cmd = ["cp", "-ar", dd_path, dest]
    try:
        run_cmd(cp_cmd)
    except (OSError, RunCmdError):
        pass


def copy_file(src, dest):
    """ Copy a file with 'cp -a', ignoring any failure

        :param src:  Source file
        :type src:   string
        :param dest: Destination file
        :type dest:  string
        :returns:    None
    """
    cp_cmd = ["cp", "-a", src, dest]
    try:
        run_cmd(cp_cmd)
    except (OSError, RunCmdError):
        # Best effort only, the caller is not told about a failed copy
        return


def move_file(src, dest):
    """ Move a file with 'mv -f', ignoring any failure

        :param src:  Source file
        :type src:   string
        :param dest: Destination file
        :type dest:  string
        :returns:    None
    """
    mv_cmd = ["mv", "-f", src, dest]
    try:
        run_cmd(mv_cmd)
    except (OSError, RunCmdError):
        # Best effort only, the caller is not told about a failed move
        return


def find_dd(mnt="/media/DD"):
    """ Walk a directory tree and collect the DD repositories for this arch

        :param mnt: Top of the directory tree to search
        :type mnt:  string
        :returns:   list of DD repositories
        :rtype:     list

        A directory qualifies when it holds a 'rhdd3' file and a 'rpms'
        subdirectory that contains an entry named after the current arch.
        The path that is returned is that rpms/<arch> entry. Symlinked
        directories are followed.
    """
    arch = os.uname()[4]
    repos = []
    for top, subdirs, names in os.walk(mnt, followlinks=True):
        if "rhdd3" not in names or "rpms" not in subdirs:
            continue
        arch_repo = top+"/rpms/"+arch
        if os.path.exists(arch_repo):
            repos.append(arch_repo)
    log.debug("Found repos - %s" % " ".join(repos))
    return repos


def get_module_set(fname):
    """ Collect the module names listed in a /proc/modules style file

        :param fname: Full path to filename
        :type fname:  string
        :returns:     set of the module names, empty if the file is missing

        The name is the first whitespace separated field of each line,
        blank lines are skipped.
    """
    if not os.path.exists(fname):
        return set()

    with open(fname, "r") as modfile:
        split_lines = (entry.split() for entry in modfile)
        return set(fields[0] for fields in split_lines if fields)


def reload_modules():
    """ Unload the modules added since startup and let udev load them again,
        so that new versions from /lib/modules/<kernel>/updates/ get used.

        Every step is best effort, a failing command is ignored.
    """
    try:
        run_cmd(["depmod", "-a"])
    except (OSError, RunCmdError):
        pass

    # Work out which modules were loaded after /tmp/dd_modules was written.
    # virtio_blk and virtio_net are counted as startup modules, which keeps
    # them from being unloaded.
    keep = get_module_set("/tmp/dd_modules")
    keep.update(["virtio_blk", "virtio_net"])
    loaded_now = get_module_set("/proc/modules")
    new_modules = loaded_now - keep
    log.debug("new_modules = %s" % " ".join(new_modules))

    # I think we can just iterate once using modprobe -r to remove unused deps
    for name in new_modules:
        try:
            run_cmd(["modprobe", "-r", name])
        except (OSError, RunCmdError):
            pass

    time.sleep(2)

    # Reload the modules, using the new versions from /lib/modules/<kernel>/updates/
    try:
        run_cmd(["udevadm", "trigger"])
    except (OSError, RunCmdError):
        pass


class Driver(object):
    """ One driver package as reported by dd_list
    """
    def __init__(self):
        # Source rpm of the driver
        self.source = ""
        # Package name
        self.name = ""
        # Whitespace separated flag words, see the args property
        self.flags = ""
        # Description, one list entry per line
        self.description = []
        # True once the driver has been picked for installation
        self.selected = False

    @property
    def args(self):
        """ The flags as a list of '--<flag>' command line options
        """
        return ["--" + flag for flag in self.flags.split()]

    @property
    def rpm(self):
        """ The rpm to extract, this is the source attribute
        """
        return self.source


def fake_drivers(num):
    """ Build a list of placeholder Driver objects for testing

        Driver number i gets the source "driver-<i>", counting from 0, and
        the flags "modules".
    """
    fakes = []
    for index in xrange(num):
        fake = Driver()
        fake.source = "driver-%d" % index
        fake.flags = "modules"
        fakes.append(fake)
    return fakes


def dd_list(dd_path, kernel_ver=None, anaconda_ver=None):
    """ Run the dd_list utility on a directory and parse its output

        :param dd_path:      Path to the driver repo
        :type dd_path:       string
        :param kernel_ver:   Kernel version passed to dd_list, defaults to
                             the release reported by os.uname()
        :type kernel_ver:    string
        :param anaconda_ver: Anaconda version passed to dd_list, defaults
                             to "19.0"
        :type anaconda_ver:  string
        :returns:            list of drivers, empty if dd_list fails
        :rtype:              list of Driver objects

        By default none of the drivers are selected
    """
    kernel_ver = kernel_ver or os.uname()[2]
    anaconda_ver = anaconda_ver or "19.0"

    list_cmd = ["dd_list", "-k", kernel_ver, "-a", anaconda_ver, "-d", dd_path]
    try:
        _, outlines = run_cmd(list_cmd)
    except (OSError, RunCmdError):
        return []

    # Output format is:
    #   source rpm\n
    #   name\n
    #   flags\n
    #   description (multi-line)\n
    #   ---\n
    header_fields = ("source", "name", "flags")
    drivers = []
    current = Driver()
    field_no = 0
    for entry in outlines.splitlines():
        log.debug(entry)
        if entry == "---":
            # End of a record, store it and start on the next one
            drivers.append(current)
            current = Driver()
            field_no = 0
        elif field_no < len(header_fields):
            setattr(current, header_fields[field_no], entry)
            field_no += 1
        else:
            # Everything after the three header lines is description
            current.description.append(entry)

    return drivers


def dd_extract(driver, dest_path="/updates/", kernel_ver=None):
    """ Unpack a driver rpm below a destination path

        :param driver:     Driver to extract
        :type driver:      Driver object
        :param dest_path:  Top directory of the destination path
        :type dest_path:   string
        :param kernel_ver: Kernel version passed to dd_extract, defaults to
                           the release reported by os.uname()
        :type kernel_ver:  string
        :returns:          None

        The driver's files are extracted into 'dest_path' (which defaults
        to /updates/ so that the normal live updates handling will overlay
        any binary or library updates onto the initrd automatically).

        After that every extracted *.ko file and every firmware file is
        copied to the matching updates directory below dest_path and moved
        to the same directory below /. If dd_extract fails the driver is
        skipped.
    """
    kernel_ver = kernel_ver or os.uname()[2]

    cmd = ["dd_extract", "-k", kernel_ver]
    cmd.extend(driver.args)
    cmd.extend(["--rpm", driver.rpm, "--directory", dest_path])
    log.info("Extracting files from %s" % driver.rpm)
    try:
        run_cmd(cmd)
    except (OSError, RunCmdError):
        log.error("dd_extract failed, skipped %s" % driver.rpm)
        return

    # Create the destination directories
    # NOTE(review): the module path uses the running kernel's release, not
    # the kernel_ver argument - confirm that this is intended
    initrd_updates = "/lib/modules/" + os.uname()[2] + "/updates/"
    initrd_firmware = "/lib/firmware/updates/"
    ko_updates = dest_path + initrd_updates
    firmware_updates = dest_path + initrd_firmware
    for needed in (initrd_updates, ko_updates, initrd_firmware, firmware_updates):
        if not os.path.exists(needed):
            os.makedirs(needed)

    # Copy *.ko files over to /updates/lib/modules/<kernel>/updates/
    for root, _dirs, files in os.walk(dest_path+"/lib/modules/"):
        # Files that are already in an updates directory stay where they are
        if root.endswith("/updates") and os.path.isdir(root):
            continue
        for ko_name in files:
            if not ko_name.endswith(".ko"):
                continue
            src = "%s/%s" % (root, ko_name)
            copy_file(src, ko_updates)
            move_file(src, initrd_updates)

    # Copy the firmware updates
    for root, _dirs, files in os.walk(dest_path+"/lib/firmware/"):
        if root.endswith("/updates") and os.path.isdir(root):
            continue
        for fw_name in files:
            src = "%s/%s" % (root, fw_name)
            copy_file(src, firmware_updates)
            move_file(src, initrd_firmware)


def select_drivers(drivers):
    """ Display pages of drivers to be loaded.

        :param drivers: Drivers to be selected by the user
        :type drivers:  list of Driver objects
        :returns:       None

        Each page shows up to 20 drivers, numbered from 1 on every page.
        Entering a number toggles the 'selected' attribute of that driver,
        'n' and 'p' change the page and 'c' ends the selection. Any other
        input shows the current page again.
    """
    # raw_input only exists on Python 2, fall back to input elsewhere
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    page_length = 20
    # Round up, so that an exact multiple of page_length does not get an
    # empty extra page. An empty list still shows one (empty) page.
    num_pages = max(1, (len(drivers) + page_length - 1) // page_length)
    page = 1
    while True:
        # Only the drivers of this page; the last page may be shorter
        first = (page-1) * page_length
        page_drivers = drivers[first:first+page_length]

        # show a page of drivers
        print("\nPage %d of %d" % (page, num_pages))
        print("Select drivers to install")
        for i, driver in enumerate(page_drivers):
            if driver.selected:
                selected = "x"
            else:
                selected = " "
            print("%3d) [%s] %s" % (i+1, selected, driver.source))

        # Select driver to toggle, continue or change pages
        idx = read_line("\n# to toggle selection, 'n'-next page, 'p'-previous page or 'c'-continue: ")
        if idx.isdigit():
            choice = int(idx)
            # Valid numbers are the ones shown on the current page
            if choice < 1 or choice > len(page_drivers):
                print("Invalid selection")
                continue
            driver = page_drivers[choice-1]
            driver.selected = not driver.selected
        elif idx.lower() == 'n':
            if page < num_pages:
                page += 1
            else:
                print("Last page")
        elif idx.lower() == 'p':
            if page > 1:
                page -= 1
            else:
                print("First page")
        elif idx.lower() == 'c':
            return


def process_dd(dd_path):
    """ Handle installing modules, firmware, enhancements from the dd repo

        :param dd_path: Path to the driver repository
        :type dd_path:  string
        :returns:       None

        The drivers are chosen by the user when the repo has a rhdd3.rules
        file or interactive mode was requested, otherwise all of them are
        selected. Nothing is done if the user selects no driver.

        If the repo has a repodata directory it is copied for Anaconda and
        the names of the selected module and firmware packages are appended
        to /run/install/dd_packages. The modules are reloaded at the end.
    """
    drivers = dd_list(dd_path)
    log.debug("drivers = %s" % " ".join([d.rpm for d in drivers]))

    # If interactive mode or rhdd3.rules pass flag to deselect by default?
    if os.path.exists(dd_path+"/rhdd3.rules") or is_interactive():
        select_drivers(drivers)
        if not any(d.selected for d in drivers):
            return
    else:
        # A plain loop, map() is not meant for side effects and would do
        # nothing at all where it returns a lazy iterator
        for driver in drivers:
            driver.selected = True

    # Copy the repository for Anaconda to use during install
    has_repodata = os.path.isdir(dd_path+"/repodata")
    if has_repodata:
        copy_repo(dd_path, "/updates/run/install/DD-")

    for driver in [d for d in drivers if d.selected]:
        dd_extract(driver, "/updates/")

        # Write the package names for all modules and firmware for Anaconda
        if has_repodata \
           and ("modules" in driver.flags or "firmwares" in driver.flags):
            with open("/run/install/dd_packages", "a") as f:
                f.write("%s\n" % driver.name)

    reload_modules()


def select_dd(device):
    """ Mount a device on /media/DD/ and process every Driver Update repo
        found on it

        :param device: Path to the device to mount and check
        :type device:  string
        :returns:      None

        The mount point is created when it is missing. Nothing else happens
        if the device cannot be mounted, otherwise it is unmounted again
        after the repos have been processed.
    """
    mnt = "/media/DD/"
    if not os.path.isdir(mnt):
        os.makedirs(mnt)

    if mount_device(device, mnt):
        for repo in find_dd(mnt):
            log.info("Processing DD repo %s on %s" % (repo, device))
            process_dd(repo)

        # TODO - does this need to be done before module reload?
        umount(device)


def network_driver(dd_path):
    """ Install the drivers from a downloaded path, then look for OEMDRV
        devices that were not there before.

        :param dd_path: Path to the downloaded driver rpms
        :type dd_path:  string
        :returns:       None
    """
    # Remember what is present now, dd_scan is told to skip these devices
    already_present = set(oemdrv_list())

    log.info("Processing Network Drivers from %s" % dd_path)
    process_dd(dd_path)

    # TODO: May need to add new drivers to /tmp/dd_modules to prevent them from being unloaded

    # Scan for new OEMDRV devices
    dd_scan(already_present)


def dd_scan(skip_dds=None):
    """ Scan the system for OEMDRV devices and those specified by
        dd=/dev/<device>

        :param skip_dds: devices to skip when checking for OEMDRV label
        :type skip_dds:  set()
        :returns:        None

        Without skip_dds every OEMDRV device and every device from the dd
        arguments is processed. With skip_dds, even an empty one, only
        OEMDRV devices that are not in it are processed and the dd
        arguments are not looked at again.

        The OEMDRV list is read again after each device, so disks that
        show up while drivers are being loaded are processed too.
    """
    # Compare with None: an empty set still means "only look for new
    # OEMDRV devices", it must not restart the full scan.
    if skip_dds is not None:
        # Work on a copy so that the caller's set is left alone
        dd_finished = set(skip_dds)
        dd_todo = set(oemdrv_list()).difference(dd_finished)
        if dd_todo:
            log.info("Found new OEMDRV device(s) - %s" % ", ".join(dd_todo))
    else:
        dd_finished = set()
        dd_todo = set(oemdrv_list())

        # Add the user specified devices
        dd_devs = get_dd_args()
        dd_devs = [dev for dev in dd_devs if dev not in ("dd", "inst.dd")]
        dd_todo.update(dd_devs)
        log.info("Checking devices %s" % ", ".join(dd_todo))

    # Process each Driver Disk, checking for new disks after each one
    while dd_todo:
        device = dd_todo.pop()
        log.info("Checking device %s" % device)
        select_dd(device)
        dd_finished.add(device)
        new_oemdrv = set(oemdrv_list()).difference(dd_finished, dd_todo)
        if new_oemdrv:
            log.info("Found new OEMDRV device(s) - %s" % ", ".join(new_oemdrv))
        dd_todo.update(new_oemdrv)


if __name__ == '__main__':
    log.setLevel(logging.DEBUG)

    # Everything the logger lets through goes to syslog
    syslog_handler = SysLogHandler(address="/dev/log")
    log.addHandler(syslog_handler)

    # INFO and above are also written to the stream handler, with a DD: prefix
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter("DD: %(message)s"))
    log.addHandler(console_handler)

    # With a path argument the drivers in it are installed first, without
    # one only the devices are scanned
    if len(sys.argv) > 1:
        network_driver(sys.argv[1])
    else:
        dd_scan()

