#!/usr/bin/python
#
#       Test fragments of example code to ensure they behave as
#       advertised.
#
#       Functionally, this is similar to doctest; however, our input
#       is more regular than that seen by doctest, and therefore our
#       parsing is more relaxed.
#

import sys, os, traceback, types

# Primary prompt: a line starting with this begins a new example command.
PS1 = '>>>'
# Continuation prompt: a line starting with this extends the current command.
PS2 = '...'

class DummyFile:
    """
    Minimal stand-in for a writable file.

    Every string passed to write() is remembered; get() hands back all
    text written since the previous get() and then starts afresh.
    """

    def __init__(self):
        # Pending chunks of text, in the order they were written.
        self.buf = []

    def write(self, s):
        """Remember *s* for the next call to get()."""
        self.buf.append(s)

    def flush(self):
        """No-op: there is no underlying stream to push data to."""
        return None

    def get(self):
        """Return everything written so far as one string and clear it."""
        text = "".join(self.buf)
        self.buf = []
        return text


class ExampleFailure(Exception):
    """
    Signals that an example file could not be loaded or parsed; the
    exception's message describes the problem.
    """


class ExampleTest:
    def __init__(self, filename):
        self.filename = filename
        self.result = ""
        self.example = []

    def report(self):
        if self.result:
            print "=" * 70
            print "Testing \"%s\" failed," % self.filename,
            print self.result
            print

    def test(self):
        try:
            if not self.example:
                self.load_and_parse()
            self.result = ""
            self.execute_verify_output()
        except ExampleFailure, msg:
            self.result = msg

        return self.result == ""


    def load_and_parse(self):
        try:
            lines = open(self.filename).readlines()
        except IOError, (eno, estr):
            raise ExampleFailure("could not load: %s" % estr)

        src, want = "", ""
        cmd_lineno = 0
        cur_lineno = 0
        while lines:
            line = lines.pop(0)
            cur_lineno += 1

            if line.startswith(PS1):
                if src:
                    self.example.append((src, cmd_lineno, want))
                cmd_lineno = cur_lineno
                src = line[4:]
                want = ""
            elif line.startswith(PS2):
                if not src:
                    raise ExampleFailure("Line %d: %s line with no preceeding %s line" % (cur_lineno, PS2, PS1))
                src += line[4:]
            else:
                if not src:
                    raise ExampleFailure("Line %s: output with no preceeding %s line" % (cur_lineno, PS1))
                want += line
        if src:
            self.example.append((src, cmd_lineno, want))


    def execute_verify_output(self):
        saved_stdout = sys.stdout
        sys.stdout = dummyfile = DummyFile()
        namespace = {}
        for src, lineno, want in self.example:
            got = None
            try:
                code = compile(src, '<string>', 'single')
                exec code in namespace
                got = dummyfile.get()
            except:
                # Is it a wanted exception?
                if want.find("Traceback (innermost last):\n") == 0 or \
                   want.find("Traceback (most recent call last):\n") == 0:
                    # Only compare exception type and value - the rest of
                    # the traceback isn't necessary.
                    exc_type, exc_val = sys.exc_info()[:2]
                    got = traceback.format_exception_only(exc_type, exc_val)[-1]
                    want = want.split('\n')[-2].strip() + '\n'
                else:
                    exc_type, exc_val, exc_tb = sys.exc_info()
                    if type(exc_type) == types.ClassType:
                        exc_type = exc_type.__name__
                    exc_tb = exc_tb.tb_next
                    psrc = PS1 + " " + src.rstrip().replace("\n", PS2 + " ")
                    pexp = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)).rstrip()

                    self.result = "Line %d:\n%s\n%s" % (lineno, psrc, pexp)
                    del exc_tb
                    break

            if got != want:
                self.result = "Line %d: output does not match example\nExpected:\n%s\nGot:\n%s" % (lineno, want.rstrip(), got.rstrip())
                break

        sys.stdout = saved_stdout


def files_from_directory(dir):
    """
    Return the sorted full paths of the example files directly inside *dir*.

    Entries whose names start with '.', entries that are not regular files,
    and files ending in '.py' are left out.
    """
    visible = [os.path.join(dir, name)
               for name in os.listdir(dir)
               if not name.startswith('.')]
    return sorted(path for path in visible
                  if os.path.isfile(path) and not path.endswith('.py'))


def main():
    if len(sys.argv) > 1:
        filenames = sys.argv[1:]
    else:
        filenames = files_from_directory("doctest")

    fail_cnt = total_cnt = 0
    failures = []

    for filename in filenames:
        if os.path.isfile(filename):
            total_cnt += 1
            e = ExampleTest(filename)
            if not e.test():
                failures.append(e)
                fail_cnt += 1
                sys.stdout.write("!")
            else:
                sys.stdout.write(".")
            sys.stdout.flush()
    sys.stdout.write("\n")

    for e in failures:
        e.report()

    print "%d of %d tests failed" % (fail_cnt, total_cnt)

    sys.exit(fail_cnt > 0)


# Run the test driver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

