#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (c) 2016 Mikkel Schubert <MSchubert@snm.ku.dk>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function

import argparse
import bz2
import cStringIO
import difflib
import gzip
import io
import json
import os
import re
import subprocess
import sys
import tempfile
import types


#############################################################################
# Global switch for colored output; when False, _do_print_color never emits
# shell color-codes, even if the destination is a terminal.
_COLORS_ENABLED = True


def _do_print_color(*vargs, **kwargs):
    """Print the positional arguments, wrapped in a shell color-code.

    Keyword arguments:
      colorcode -- Required; numeric color-code inserted into the escape
                   sequence placed in front of every line.
      file      -- Destination handle; defaults to sys.stdout.

    Any remaining keyword arguments are passed on to print(). Colors are only
    applied if they are globally enabled and the destination is a terminal.
    The destination is flushed whenever the 'end' string contains a newline.
    """
    colorcode = kwargs.pop("colorcode")
    destination = kwargs.pop("file", sys.stdout)

    def _colorize(value):
        # A newline terminates the color-code in e.g. 'less', so every line
        # of a multi-line value is wrapped in its own escape sequence
        return "\n".join("\033[00;%im%s\033[00m" % (colorcode, text)
                         for text in str(value).split("\n"))

    # Plain output if redirected (e.g. to 'less' or to a file)
    if _COLORS_ENABLED and destination.isatty():
        vargs = [_colorize(varg) for varg in vargs]

    print(*vargs, file=destination, **kwargs)

    terminator = kwargs.get('end', '\n')
    if '\n' in terminator:
        destination.flush()


def print_ok(*vargs, **kwargs):
    """Print like print(), but in green (color-code 32) where supported."""
    _do_print_color(colorcode=32, *vargs, **kwargs)


def print_err(*vargs, **kwargs):
    """Print like print(), but in red (color-code 31) where supported."""
    _do_print_color(colorcode=31, *vargs, **kwargs)


#############################################################################
# Names of the compression schemes exercised by the tests; the values of GZIP
# and BZIP2 double as the extension appended to names of compressed files.
UNCOMPRESSED, GZIP, BZIP2 = "raw", "gz", "bz2"


def compress(value, compression):
    """Return `value` compressed using the scheme named by `compression`.

    Arguments:
      value       -- The (binary) data to be compressed.
      compression -- Either GZIP or BZIP2.

    Raises ValueError for any other value of `compression`, including
    UNCOMPRESSED; an explicit exception is used rather than an assertion,
    since assertions are stripped when Python is run with -O.
    """
    if compression == GZIP:
        fileobj = io.BytesIO()
        # The context manager ensures that the gzip trailer is written (and
        # the handle closed) before the buffer is read back
        with gzip.GzipFile(filename='', mode='w', compresslevel=9,
                           fileobj=fileobj) as handle:
            handle.write(value)

        return fileobj.getvalue()
    elif compression == BZIP2:
        return bz2.compress(value)

    raise ValueError("Unsupported compression: %r" % (compression,))


def decompress(filename):
    """Read `filename` and return its (decompressed) contents as lines.

    Files with the extension '.bz2' or '.gz' are decompressed after checking
    that they start with the expected magic bytes; all other files are
    returned as is. Empty files are never decompressed, and yield an empty
    list regardless of extension.

    Raises TestError if a non-empty '.bz2' or '.gz' file lacks the expected
    header.
    """
    # Binary mode; the contents may be compressed data
    with open(filename, "rb") as handle:
        value = handle.read()

    if value and filename.endswith(".bz2"):
        if not value.startswith(b"BZ"):
            raise TestError("Expected bz2 file at %r, but header is %r"
                            % (filename, value[:2]))

        value = bz2.decompress(value)
    elif value and filename.endswith(".gz"):
        if not value.startswith(b"\x1f\x8b"):
            raise TestError("Expected gzip file at %r, but header is %r"
                            % (filename, value[:2]))

        with gzip.GzipFile(filename='', mode='r',
                           fileobj=io.BytesIO(value)) as handle:
            value = handle.read()

    # Split into lines, keeping the line-endings
    return io.BytesIO(value).readlines()


#############################################################################
# Path to the executable under test; not absolute, so it is resolved against
# the current working directory
_EXEC = './build/AdapterRemoval'
# Name of the JSON file that marks a folder as a test-case and describes it
_INFO_FILE = "info.json"
# The keys permitted in the info file, and the type required of each value
_INFO_FIELDS = {
    'arguments': types.ListType,
    'return_code': types.IntType,
    'stderr': types.ListType,
}


def pretty_output(s, padding=0, max_lines=float("inf")):
    """Format text as a quoted block, for inclusion in error messages.

    Every line of `s` is prefixed with `padding` spaces followed by ">  ".
    If `s` contains more than `max_lines` lines, then only the first
    `max_lines` lines are kept and a final "..." line is added.
    """
    indent = " " * padding
    rows = s.split("\n")
    if len(rows) > max_lines:
        rows = rows[:max_lines] + ["..."]

    return "\n".join(["%s>  %s" % (indent, row) for row in rows])


def interleave(text_1, text_2):
    """Interleave two texts in chunks of four lines and return the result.

    The texts are split on newlines, and the output alternates between four
    lines from `text_1` and the corresponding four lines from `text_2`.

    Any trailing lines that do not fill a complete chunk of four are taken
    from `text_1` only. In particular, a trailing newline shared by both
    texts is therefore carried over exactly once.

    Raises ValueError if the texts do not contain the same number of lines.
    """
    lines_1 = text_1.split("\n")
    lines_2 = text_2.split("\n")
    if len(lines_1) != len(lines_2):
        # Explicit check rather than an assert, which is stripped by -O
        raise ValueError("Cannot interleave texts with differing numbers of "
                         "lines (%i vs %i)" % (len(lines_1), len(lines_2)))

    result = []
    for start in range(0, len(lines_1), 4):
        end = start + 4
        result.extend(lines_1[start:end])
        if end > len(lines_1):
            # Incomplete final chunk; the lines from text_2 are left out
            break

        result.extend(lines_2[start:end])

    return "\n".join(result)


class TestError(StandardError):
    """Raised when a test-case is malformed or when a test-run fails."""
    pass


class TestCase(object):
    """A regression test loaded from a folder containing an 'info.json' file.

    A test-case consists of optional input files ('input_1.fastq*' and
    'input_2.fastq*'), an optional 'barcodes.txt' file, the info file, and
    the files that the executable is expected to produce. Each test-case is
    run with every combination of input and output compression, and, if both
    input files are present, also with interleaved input.
    """

    def __init__(self, root, path):
        """Load the test-case found in the folder `root`.

        `path` is the sequence of folder names leading to the test-case; it
        is used to name the test and to build its working directory.

        Raises TestError if the test-case is malformed.
        """
        self.root = root
        self.path = path
        self.name = " :: ".join(path)
        self._files = self._collect_files(root)
        self._info = self._read_info(os.path.join(root, _INFO_FILE))

    def __repr__(self):
        return "TestCase(%r)" % ({'root': self.root,
                                  'name': self.name,
                                  'info': self._info,
                                  'files': self._files})

    def run(self, root):
        """Run every variant of the test in sub-folders of `root`.

        This is a generator; a label describing the variant is yielded just
        before that variant is run, allowing the caller to report which
        variant a failure belongs to.

        Raises TestError if a variant fails.
        """
        root = os.path.join(root, *self.path)

        # Interleaved input is only meaningful for paired-end data
        interleaved_tests = [False]
        if self._is_paired():
            interleaved_tests.append(True)

        for in_compression in (UNCOMPRESSED, GZIP, BZIP2):
            for out_compression in (UNCOMPRESSED, GZIP, BZIP2):
                for interleaved in interleaved_tests:
                    yield "%s>%s%s" % (in_compression,
                                       out_compression,
                                       ",intl" if interleaved else "")

                    postfix = "%s_%s%s" % (in_compression,
                                           out_compression,
                                           "_intl" if interleaved else "")

                    self._do_run(os.path.join(root, postfix),
                                 in_compression, out_compression, interleaved)

    def _do_run(self, root, in_compression=UNCOMPRESSED,
                out_compression=UNCOMPRESSED, interleaved=False):
        """Run a single variant of the test in the (new) folder `root`."""
        assert in_compression in (UNCOMPRESSED, BZIP2, GZIP)
        assert out_compression in (UNCOMPRESSED, BZIP2, GZIP)
        os.makedirs(root)

        input_1, input_2 = self._setup_input(root, in_compression, interleaved)
        self._do_call(root, input_1, input_2, out_compression, interleaved)

        self._check_file_creation(root, input_1, input_2, out_compression)
        self._check_file_contents(root, out_compression)

    def _setup_input(self, root, compression, interleaved):
        """Write the input files for a run to the folder `root`.

        The input is compressed as specified by `compression`, and merged
        into a single file if `interleaved` is set. The barcode file, if any,
        is written as is.

        Returns the names of the mate 1 and mate 2 files in `root`, using
        None for files that were not written.
        """
        input_files = {}
        for key in ("input_1", "input_2"):
            if self._files[key] is not None:
                with open(self._files[key], "rb") as handle:
                    input_files[key] = handle.read()

        if interleaved:
            input_files = {"input_1": interleave(input_files["input_1"],
                                                 input_files["input_2"])}

        final_files = {}
        for key, value in input_files.items():
            filename = key + ".fastq"
            if compression != UNCOMPRESSED:
                filename += "." + compression
                value = compress(value, compression)

            # Binary mode; the data may have been compressed
            with open(os.path.join(root, filename), "wb") as handle:
                handle.write(value)

            final_files[key] = filename

        if 'barcodes' in self._files:
            with open(os.path.join(root, 'barcodes.txt'), 'w') as handle:
                handle.writelines(self._files['barcodes'])

        return final_files.get("input_1"), final_files.get("input_2")

    def _do_call(self, root, input_1, input_2, compression, interleaved):
        """Run the executable in `root` and check its output and return-code.

        Raises TestError if anything is written to STDOUT, if any of the
        expected regular expressions fail to match STDERR, or if the
        return-code differs from the expected return-code.
        """
        command = self._build_command(root, input_1, input_2, compression, interleaved)
        # NOTE(review): the null device is opened write-only, even though it
        # is used as STDIN for the process; confirm that this is intended
        with open(os.devnull, "w") as dev_null:
            proc = subprocess.Popen(command,
                                    stdin=dev_null,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    close_fds=True,
                                    preexec_fn=os.setsid,
                                    cwd=root)

            stdout, stderr = proc.communicate()
            if stdout:
                raise TestError("Unexpected output to STDOUT: %r" % (stdout,))

            for value in self._info["stderr"]:
                if re.search(value, stderr) is None:
                    raise TestError("Expected value not found in output:\n"
                                    "  Searching for:\n%s\n  STDERR:\n%s"
                                    % (pretty_output(value, 4),
                                       pretty_output(stderr, 4, 5)))

            if proc.returncode != self._info["return_code"]:
                raise TestError("ERROR: Expected return-code %i, but "
                                "AdapterRemoval returned %i:\n%s"
                                % (self._info["return_code"],
                                   proc.returncode,
                                   pretty_output(stderr, 2, 5)))

    def _check_file_creation(self, root, input_1, input_2, compression):
        """Check that exactly the expected files were created in `root`.

        The input files are ignored. If the output is compressed, then every
        expected file other than '.settings' files and the barcode file is
        expected to carry the extension of the compression scheme.

        Raises TestError if files are missing or if extra files are found.
        """
        expected_files = set(self._files["output"])
        if 'barcodes' in self._files:
            expected_files.add('barcodes.txt')

        if compression != UNCOMPRESSED:
            expected_files_ = set()
            for value in expected_files:
                if not value.endswith(".settings") and value != 'barcodes.txt':
                    expected_files_.add(value + "." + compression)
                else:
                    expected_files_.add(value)
            expected_files = expected_files_

        observed_files = frozenset(os.listdir(root)) - frozenset((input_1,
                                                                  input_2))

        if expected_files - observed_files:
            raise TestError("ERROR: Expected output file(s) not created:\n"
                            "  Expected: %r\n  Observed: %r"
                            % (sorted(expected_files), sorted(observed_files)))
        elif observed_files - expected_files:
            raise TestError("ERROR: Unexpected output file(s) created: %r"
                            % (sorted(observed_files - expected_files)))

    def _check_file_contents(self, root, compression):
        """Compare every output file in `root` with the expected contents.

        Raises TestError on the first file that differs.
        """
        for filename, exp_data in sorted(self._files["output"].items()):
            obs_filename = os.path.join(root, filename)
            # '.settings' files are never compressed
            if compression != UNCOMPRESSED:
                if not filename.endswith(".settings"):
                    obs_filename += "." + compression

            obs_data = decompress(obs_filename)

            if filename.endswith(".settings"):
                exp_data = self._mangle_settings(exp_data)
                obs_data = self._mangle_settings(obs_data)

            # Report the file actually read, including any extension
            self._diff_file_pair_contents(os.path.join(self.root, filename),
                                          obs_filename,
                                          exp_data, obs_data)

    def _diff_file_pair_contents(self, exp_filename, obs_filename, exp_data, obs_data):
        """Raise TestError, including a unified diff, if the data differs.

        The filenames are only used in the error message.
        """
        if exp_data != obs_data:
            lines = "".join(difflib.unified_diff(exp_data, obs_data,
                                                 "expected", "observed"))

            raise TestError("ERROR: Output file(s) differ:\n"
                            "  Expected: %r\n  Observed: %r\n  Diff:\n%s"
                            % (exp_filename, obs_filename,
                               pretty_output(lines, 4)))

    def _build_command(self, root, input_1, input_2, compression, interleaved):
        """Return the command-line used to run the executable, as a list.

        The arguments listed in the info file are added last. Every field is
        used as a format string, in which '%(ROOT)s' is replaced by `root`.
        """
        command = [os.path.abspath(_EXEC)]
        if interleaved:
            command.append("--interleaved-input")

        if 'barcodes' in self._files:
            command.extend(('--barcode-list', 'barcodes.txt'))

        if input_1 is not None or input_2 is not None:
            if compression == BZIP2:
                command.append("--bzip2")
            elif compression == GZIP:
                command.append("--gzip")

        if input_1 is not None:
            command.extend(('--file1', input_1))
        if input_2 is not None:
            command.extend(('--file2', input_2))

        command.extend(self._info['arguments'])

        return [field % {"ROOT": root} for field in command]

    def _is_paired(self):
        """Return True if the test has both mate 1 and mate 2 input files."""
        return (self._files["input_1"] is not None) \
            and (self._files["input_2"] is not None)

    @classmethod
    def _mangle_settings(cls, handle):
        """Return the lines in `handle` with run-specific values masked.

        The RNG seed and the version number are replaced by placeholders,
        and the word 'interleaved' is removed from ' interleaved paired-end
        reads', so that settings from different runs may be compared.
        """
        result = []
        for line in handle:
            line = re.sub(r"RNG seed: [0-9]+", "RNG seed: NA", line)
            # Dots are escaped in order to match literal dots only
            line = re.sub(r"AdapterRemoval ver\. [0-9]+\.[0-9]+\.[0-9]+",
                          "AdapterRemoval ver. X.Y.Z", line)
            # Removed keyword to allow automatic interleaved tests
            line = re.sub(r" interleaved paired-end reads",
                          r" paired-end reads",
                          line)

            result.append(line)

        return result

    @classmethod
    def _collect_files(cls, root):
        """Return a dictionary describing the files in the folder `root`.

        The keys 'input_1' and 'input_2' map to the absolute paths of the
        input files, or to None if no such file exists. The key 'output' maps
        to a dictionary of expected output, with filenames as keys and lists
        of lines as values. The key 'barcodes' is only present if the folder
        contains a 'barcodes.txt' file, and maps to the lines in that file.
        The 'info.json' and 'README' files are ignored.

        Raises TestError if there are multiple files for the same mate.
        """
        result = {
            "input_1": None,
            "input_2": None,
            "output": {},
        }

        def read_lines(root, filename):
            with open(os.path.join(root, filename)) as handle:
                return handle.readlines()

        for filename in os.listdir(root):
            if filename.startswith("input_1.fastq"):
                if result["input_1"] is not None:
                    raise TestError('Duplicate input_1.fastq* files in %r'
                                    % (root,))
                result["input_1"] = os.path.abspath(os.path.join(root,
                                                                 filename))
            elif filename.startswith("input_2.fastq"):
                if result["input_2"] is not None:
                    raise TestError('Duplicate input_2.fastq* files in %r'
                                    % (root,))
                result["input_2"] = os.path.abspath(os.path.join(root,
                                                                 filename))
            elif filename not in ('info.json', 'README', 'barcodes.txt'):
                result["output"][filename] = read_lines(root, filename)
            elif filename == 'barcodes.txt':
                result["barcodes"] = read_lines(root, filename)

        return result

    @classmethod
    def _read_info(cls, filename):
        """Read and validate the JSON info file at `filename`.

        Returns a dictionary with the keys 'arguments', 'return_code', and
        'stderr'; keys missing from the file default to an empty list, zero,
        and an empty list, respectively.

        Raises TestError if the file does not contain a dictionary, if it
        contains unknown keys, or if a value is not of the expected type.
        """
        with open(filename) as handle:
            text = handle.read()
            raw_info = json.loads(text)

        if not isinstance(raw_info, types.DictType):
            raise TestError('\'info.json\' does not contain dictionary.')
        elif set(raw_info) - set(_INFO_FIELDS):
            raise TestError('\'info.json\' contains unexpected fields; '
                            'expected keys %r, but found unknown keys %r.'
                            % (sorted(_INFO_FIELDS),
                               sorted(set(raw_info) - set(_INFO_FIELDS))))

        info = {"arguments": [],
                "return_code": 0,
                "stderr": []}
        info.update(raw_info)

        for key, expected_type in _INFO_FIELDS.items():
            if not isinstance(info[key], expected_type):
                raise TestError('Type of %r in \'info.json\' is %s, not a %s.'
                                % (key, type(info[key]), expected_type))
            elif isinstance(info[key], types.ListType):
                for value in info[key]:
                    if not isinstance(value, types.StringTypes):
                        raise TestError('Type of value in %r in \'info.json\' '
                                        'is %s, not a string.'
                                        % (key, type(value)))

        return info

    @classmethod
    def collect(cls, root, path=()):
        """Recursively collect test-cases from the folder `root`.

        Every folder containing an 'info.json' file is loaded as a test-case;
        `path` holds the names of the folders visited so far. Test-cases that
        cannot be loaded are reported using print_err and otherwise skipped.

        Returns a list of TestCase objects.
        """
        tests = []

        for filename in sorted(os.listdir(root)):
            filepath = os.path.join(root, filename)

            if os.path.isdir(filepath):
                tests.extend(cls.collect(filepath, path + (filename,)))
            elif filename == 'info.json':
                try:
                    test = TestCase(root, path)
                except StandardError as error:
                    print_err("ERROR: %s reading test %r: %s"
                              % (error.__class__.__name__,
                                 ' :: '.join(path), error))
                    continue

                tests.append(test)

        return tests


def parse_args(argv):
    """Parse the command-line arguments in the list `argv`.

    Returns a namespace with the attributes 'work_dir' and 'source_dir',
    both of which are required positional arguments.
    """
    positionals = (
        ('work_dir', "Directory in which to run test-cases."),
        ('source_dir', "Directory containing test-cases."),
    )

    parser = argparse.ArgumentParser()
    for name, description in positionals:
        parser.add_argument(name, help=description)

    return parser.parse_args(argv)


def main(argv):
    """Run all test-cases and print a summary of the results.

    Test-cases are read from the source directory given in `argv`, and are
    run in a new temporary folder created inside the given work directory.

    Returns 1 if the executable is missing or if any test failed, and 0
    otherwise, so that the exit status of the script reflects the outcome.
    """
    args = parse_args(argv)
    if not os.path.exists(_EXEC):
        print_err("ERROR: Executable does not exist: %r" % (_EXEC,))
        return 1

    print('Reading test-cases from %r' % (args.source_dir,))
    tests = TestCase.collect(args.source_dir)

    args.work_dir = tempfile.mkdtemp(dir=args.work_dir)
    print('Writing test-cases results to %r' % (args.work_dir,))

    n_failures = 0
    print('\nRunning tests:')
    for idx, test in enumerate(tests, start=1):
        print("  %i of %i: %s " % (idx, len(tests), test.name), end='')
        # Name of the variant being run; reported if the variant fails
        label = 'unknown'

        try:
            for label in test.run(args.work_dir):
                print_ok(".", end="")
                sys.stdout.flush()
        except TestError as error:
            n_failures += 1
            message = "\n    ".join(str(error).split("\n"))
            print_err(" %s for %s:\n    %s" % (error.__class__.__name__,
                                               label, message))
        else:
            print_ok(" [OK]")

    if n_failures:
        print_err("\n%i of %i tests failed .." % (n_failures, len(tests)))
        return 1

    print_ok("\nAll %i tests succeeded .." % (len(tests),))
    return 0


# Only run the tests when executed as a script; the value returned by main is
# used as the exit status of the process
if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
