#!/usr/bin/python3

import glob
import os
import string
import subprocess
import sys
import unittest
import dataclasses
import errno
import time
import socket
import tempfile

import importlib.machinery
import importlib.util

import parent
import testlib
import testvm

# Keep Python from writing .pyc/__pycache__ files for modules imported by the runner
# (presumably so that importing the check-* scripts leaves the test directory clean).
sys.dont_write_bytecode = True
# Placed in the environment so it is inherited by the test subprocesses; Python
# children then run with unbuffered stdout/stderr.
os.environ['PYTHONUNBUFFERED'] = '1'


@dataclasses.dataclass
class Test:
    """Bookkeeping for a single test invocation handled by the runner."""

    # Sequence number used in the TAP "ok N" / "not ok N" result lines
    test_id: int
    # argv of the test process: the check-* script first, the test name last
    command: list
    # Child process of the current/most recent run; None while not started
    process: subprocess.Popen = None
    # How many times the test has been re-queued after a failure
    retries: int = 0
    # Captured stdout+stderr of the most recent run
    output: bytes = b""
    # File object the child's output is redirected into while it runs. It is
    # attached by the scheduler; declared here so every attribute of a Test is
    # visible in one place. Excluded from repr/compare as it is transient state.
    outfile: object = dataclasses.field(default=None, repr=False, compare=False)


def print_test(test, print_tap=True):
    for line in test.output.splitlines(keepends=True):
        sys.stdout.buffer.write(line)
    sys.stdout.flush()

    if not print_tap:
        return

    if test.process.returncode == 0:
        print("ok {0} {1} {2}".format(test.test_id, test.command[0], test.command[-1]))
    elif test.process.returncode == 77 or b"# SKIP " in test.output:
        # If the test was skipped, add the last line (which contains the reason
        # for the skip) to the result
        print("ok {0} {1} {2} {3}".format(test.test_id, test.command[0], test.command[-1],
                                          test.output.splitlines()[-1].strip().decode()
                                          if test.process.returncode == 77 else ""))
    else:
        print("not ok {0} {1} {2}".format(test.test_id, test.command[0], test.command[-1]))
    sys.stdout.flush()

def finish_test(opts, test):
    """Print a finished test's output and decide whether it must be retried.

    Unless ``opts.thorough`` is set, the output of a failed test (with a
    synthetic ``not ok`` TAP line appended) is piped through the external
    ``tests-policy`` command; when that command succeeds its stdout replaces
    the test output. A ``# SKIP`` marker in the resulting output turns the
    failure into a non-failure.

    Returns a ``(retry, failures)`` tuple: ``retry`` is True when the caller
    should run the test again, ``failures`` is 1 when the test finally failed
    and 0 otherwise.
    """

    # Only a clean exit (0) and the "skipped" status (77) are non-failures.
    # Every other status - not just 1, but also e.g. a crash by signal
    # (negative returncode) - is printed as "not ok" by print_test(), so it
    # has to go through the failure handling below and be counted.
    if test.process.returncode in (0, 77):
        print_test(test, not opts.list)
        return False, 0

    if not opts.thorough:
        cmd = ["tests-policy", testvm.DEFAULT_IMAGE]
        try:
            # Append the TAP failure line before handing the output to
            # tests-policy. It stays in the output even if tests-policy is
            # unavailable, which is why the prints below skip the TAP line in
            # this mode.
            test.output += "not ok {0} {1} {2}".format(test.test_id, test.command[0], test.command[-1]).encode()
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            changed = proc.communicate(test.output)[0]
            if proc.returncode == 0:
                if test.output != changed:
                    changed += b"\n"
                test.output = changed
        except OSError as ex:
            # A missing tests-policy binary is silently tolerated
            if ex.errno != errno.ENOENT:
                sys.stderr.write("Couldn't run tests-policy: {0}\n".format(str(ex)))

        if b"# SKIP" in test.output:
            print_test(test, print_tap=False)
            return False, 0

    # NOTE(review): a "# RETRY" marker in the output re-queues the test
    # regardless of how often it was already retried - confirm that an
    # unbounded retry is intended here.
    if test.retries < 2 or b"# RETRY" in test.output:
        test.retries += 1
        test.output += " # RETRY {0}\n".format(test.retries).encode()
        print_test(test, print_tap=opts.thorough)
        return True, 0

    # Final failure. In non-thorough mode the synthetic "not ok" line above was
    # appended without a newline; terminate it so that whatever is printed
    # next does not end up on the same line.
    if not opts.thorough and not test.output.endswith(b"\n"):
        test.output += b"\n"
    print_test(test, print_tap=opts.thorough)
    return False, 1

def check_valid(filename):
    """Derive a Python module name from the path of a test script.

    Only the basename is looked at. If it contains anything other than ASCII
    letters, digits, '-' or '_' the file is rejected and None is returned;
    otherwise the basename is returned with every '-' turned into '_'.
    """
    basename = os.path.basename(filename)
    permitted = set(string.ascii_letters + string.digits + '-_')
    # Any character left over after removing the permitted ones is illegal
    if set(basename) - permitted:
        return None
    return "_".join(basename.split("-"))

def build_command(filename, test, opts):
    """Assemble the argv used to run one test from a check-* script.

    The result starts with the script path, ends with the test name, and has
    one option in between for every switch enabled in ``opts`` (``--nonet``
    is added when ``opts.fetch`` is false).
    """
    switches = (
        (opts.sit, "-s"),
        (opts.trace, "-t"),
        (opts.verbosity, "-v"),
        (not opts.fetch, "--nonet"),
        (opts.list, "-l"),
    )
    enabled_flags = [flag for wanted, flag in switches if wanted]
    return [filename, *enabled_flags, test]

def run(opts):
    """Discover, schedule and run the selected tests, printing TAP output.

    Every ``check-*`` file next to this script is loaded as a module and its
    unittest test methods are collected (filtered by ``opts.tests`` if given).
    Each test is run as its own subprocess. Tests whose method carries the
    ``_testlib__non_destructive`` attribute are pointed at the global machine
    and run one at a time; all others are started as slots become free, with
    at most ``opts.jobs`` processes running (1 when ``opts.list`` is set).

    Returns the number of tests that finally failed (0 on success).
    """
    # Build the list of tests we'll parallelize and the ones we'll run serially
    test_loader = unittest.TestLoader()
    parallel_tests = []
    serial_tests = []
    test_id = 1
    result = 0
    jobs = 1 if opts.list else opts.jobs
    start_time = time.time()

    for filename in glob.glob(os.path.join(os.path.dirname(__file__), "check-*")):
        name = check_valid(filename)
        if not name or not os.path.isfile(filename):
            continue
        # The scripts have no .py suffix, so load them explicitly as source
        loader = importlib.machinery.SourceFileLoader(name, filename)
        module = importlib.util.module_from_spec(importlib.util.spec_from_loader(loader.name, loader))
        loader.exec_module(module)
        for test_suite in test_loader.loadTestsFromModule(module):
            for test_case in test_suite:
                test_method = getattr(test_case.__class__, test_case._testMethodName)
                test_str = "{0}.{1}".format(test_case.__class__.__name__, test_case._testMethodName)
                # opts.tests entries are substring filters on "Class.method"
                if opts.tests and not any(t in test_str for t in opts.tests):
                    continue
                test = Test(test_id, build_command(filename, test_str, opts))
                if getattr(test_method, "_testlib__non_destructive", False):
                    serial_tests.append(test)
                else:
                    parallel_tests.append(test)
                test_id += 1

    # TAP plan line
    print("1..{0}".format(len(parallel_tests) + len(serial_tests)))
    sys.stdout.flush()

    if serial_tests and not opts.list:
        # Presumably returns the same shared machine on every call, so fetch it
        # once and reuse its addresses for all serial tests.
        machine = testlib.MachineCase.get_global_machine()
        machine_args = [
            "--machine", "{0}:{1}".format(machine.ssh_address, machine.ssh_port),
            "--browser", "{0}:{1}".format(machine.web_address, machine.web_port),
        ]
        for test in serial_tests:
            # Splice the options in directly before the test name (the last
            # element). Inserting at index -2 would put them in front of the
            # executable whenever the command is just [script, test].
            test.command[-1:-1] = machine_args

    running_tests = []
    serial_test = None  # the serial test currently running, if any
    while serial_tests or parallel_tests or running_tests:
        made_progress = False
        if len(running_tests) < jobs:
            test = None
            # Only one serial test may be in flight at any time
            if serial_tests and serial_test is None:
                test = serial_tests.pop(0)
                serial_test = test
            elif parallel_tests:
                test = parallel_tests.pop(0)

            if test:
                made_progress = True
                # Capture stdout and stderr together in an anonymous temp file
                test.outfile = tempfile.TemporaryFile()
                test.process = subprocess.Popen(test.command, stdout=test.outfile, stderr=subprocess.STDOUT)
                running_tests.append(test)

        # Iterate over a copy: finished tests are removed from the list
        for test in running_tests.copy():
            poll_result = test.process.poll()
            if poll_result is not None:
                made_progress = True
                test.outfile.seek(0)
                test.output = test.outfile.read()
                test.outfile.close()
                running_tests.remove(test)
                retry, test_result = finish_test(opts, test)
                result += test_result

                # run again if needed
                if retry:
                    test.output = b""
                    test.process = None
                    # Put it back at the front of the queue it came from
                    if test is serial_test:
                        serial_tests.insert(0, test)
                    else:
                        parallel_tests.insert(0, test)

                if test is serial_test:
                    serial_test = None

        # Sleep if we didn't make progress
        if not made_progress and not opts.list:
            time.sleep(0.5)

    testlib.MachineCase.kill_global_machine()

    duration = int(time.time() - start_time)
    hostname = socket.gethostname().split(".")[0]
    details = "[{0}s on {1}]".format(duration, hostname)
    print()
    if result > 0:
        # Report the same "[<n>s on <host>]" details as the success message
        print("# {0} TESTS FAILED {1}".format(result, details))
    else:
        print("# TESTS PASSED {0}".format(details))
    print("Test run finished, return code: {0}".format(result))

    return result


def main():
    """Parse the command line, export the test environment and run the tests.

    Extends the parser from ``testlib.arg_parser()`` with ``--publish``
    (unused), ``-j/--jobs`` (default taken from ``$TEST_JOBS``, else 1) and
    ``--thorough``. ``TEST_REVISION``, ``TEST_BROWSER`` and ``TEST_OS`` are
    exported so that the test subprocesses see them.

    Returns the value of ``run()``; invalid option combinations exit through
    ``parser.error()``.
    """
    parser = testlib.arg_parser()
    parser.add_argument('--publish', action='store', help="Unused")
    # A string default (from the environment) is converted through type=int
    # by the parser, just like a value given on the command line.
    parser.add_argument('-j', '--jobs', dest="jobs", type=int,
                        default=os.environ.get("TEST_JOBS", 1), help="Number of concurrent jobs")
    parser.add_argument('--thorough', dest='thorough', action='store_true',
                        help='Thorough mode, no skipping known issues')
    opts = parser.parse_args()

    # With fewer than one job the scheduler would never start a test and
    # would wait forever, so reject that up front.
    if opts.jobs < 1:
        parser.error("the -j or --jobs argument must be at least 1")

    if opts.sit and opts.jobs > 1:
        parser.error("the -s or --sit argument not avalible with multiple jobs")

    image = testvm.DEFAULT_IMAGE
    revision = os.environ.get("TEST_REVISION")
    test_browser = os.environ.get("TEST_BROWSER", "chromium")
    if not revision:
        # Fall back to the commit currently checked out
        revision = subprocess.check_output(["git", "rev-parse", "HEAD"],
                                           universal_newlines=True).strip()

    # Tell any subprocesses what we are testing
    os.environ["TEST_REVISION"] = revision
    os.environ["TEST_BROWSER"] = test_browser
    testvm.DEFAULT_IMAGE = image
    os.environ["TEST_OS"] = image

    return run(opts)


if __name__ == '__main__':
    # main() returns the result of run(); use it as the process exit status
    sys.exit(main())
