#!/usr/local/bin/python2.7
#
# Compare git repositories or fast-import streams for differences.
#
# Requires git and tar(1).
#
# SPDX-License-Identifier: BSD-2-Clause

# This code runs under both Python 2 and Python 3: preserve this property!
from __future__ import print_function

import sys, os, getopt, subprocess, tempfile, hashlib, shutil, time, calendar, re

DEBUG_GENERAL  = 1
DEBUG_COMMANDS = 2
DEBUG_LOWLEVEL = 3

class Fatal(Exception):
    """Unrecoverable error.

    The human-readable description is available both as the ``msg``
    attribute (which the top-level handler prints) and through str().
    """
    def __init__(self, msg):
        # Hand the message to Exception as well, so that str(err) and any
        # traceback show it instead of an empty string.
        Exception.__init__(self, msg)
        self.msg = msg

def do_or_die(dcmd, legend=""):
    """Either execute a command or raise a fatal exception.

    dcmd is a shell command line (it is run with shell=True); legend is an
    optional annotation included in debug and error messages.  Raises
    Fatal if the command cannot be started, is killed by a signal, or
    exits with a nonzero status.
    """
    if legend:
        legend = " " + legend
    if verbose >= DEBUG_COMMANDS:
        sys.stdout.write("repodiffer: executing '%s'%s\n" % (dcmd, legend))
    try:
        retcode = subprocess.call(dcmd, shell=True)
    except (OSError, IOError) as e:
        raise Fatal("repodiffer: execution of %s%s failed: %s" % (dcmd, legend, e))
    # Messages carry no trailing newline: the top-level handler adds one.
    if retcode < 0:
        raise Fatal("repodiffer: child was terminated by signal %d." % -retcode)
    elif retcode != 0:
        raise Fatal("repodiffer: child returned %d." % retcode)

def capture_or_die(dcmd, legend=""):
    """Either execute a command and capture its output or die.

    dcmd is a shell command line (it is run with shell=True); legend is an
    optional annotation included in debug and error messages.  Returns the
    command's standard output as text.  Raises Fatal if the command cannot
    be started, is killed by a signal, or exits with a nonzero status.
    """
    if legend:
        legend = " " + legend
    if verbose >= DEBUG_COMMANDS:
        sys.stdout.write("repodiffer: executing '%s'%s\n" % (dcmd, legend))
    try:
        output = subprocess.check_output(dcmd, shell=True)
    except subprocess.CalledProcessError as e:
        if e.returncode < 0:
            raise Fatal("repodiffer: child was terminated by signal %d." % -e.returncode)
        # Raise rather than exit directly so every failure takes the same
        # path through the top-level Fatal handler (which exits with 1).
        raise Fatal("repodiffer: child returned %d." % e.returncode)
    except (OSError, IOError) as e:
        raise Fatal("repodiffer: execution of %s%s failed: %s" % (dcmd, legend, e))
    # Commit metadata and diff text are not guaranteed to be ASCII; decode
    # leniently so a non-ASCII author name or comment cannot abort the run.
    return output.decode('utf-8', 'replace')

class directory_context:
    """Context manager that runs its body with a different working directory.

    On entry the current directory is remembered and the process changes
    to target if it is a directory, otherwise to the directory containing
    target (if that path has a directory part).  On exit the remembered
    directory is restored; exceptions from the body are not suppressed.
    """
    def __init__(self, target):
        self.target = target
        self.source = None
    def __enter__(self):
        if verbose >= DEBUG_COMMANDS:
            sys.stdout.write("repodiffer: in %s...\n" % self.target)
        self.source = os.getcwd()
        if os.path.isdir(self.target):
            os.chdir(self.target)
        else:
            enclosing = os.path.dirname(self.target)
            if enclosing:
                # Target is not itself a directory: work beside it.
                os.chdir(enclosing)
    def __exit__(self, extype, value_unused, traceback_unused):
        os.chdir(self.source)

class Baton:
    """Ship progress indications to stdout.

    Used as a context manager: entering prints the prompt and starts a
    timer, exiting prints the elapsed time and a closing message.  In
    between, the caller may show either a spinning twirl or a formatted
    counter.  Nothing is written unless the baton was created with
    enable=True, and the animation itself is only drawn when the stream
    is a terminal.
    """
    def __init__(self, prompt, endmsg='done', enable=False):
        self.prompt = prompt
        self.endmsg = endmsg
        self.countfmt = None
        self.counter = 0
        if enable:
            self.stream = sys.stdout
        else:
            self.stream = None
        self.count = 0
        self.time = 0
    def __enter__(self):
        if self.stream:
            self.stream.write(self.prompt + "...")
            if os.isatty(self.stream.fileno()):
                self.stream.write(" \b")
            self.stream.flush()
        self.count = 0
        self.time = time.time()
        return self
    def startcounter(self, countfmt, initial=1):
        "Switch to counter display; countfmt is a %-format taking the count."
        self.countfmt = countfmt
        self.counter = initial
    def bumpcounter(self):
        "Display the current counter value (or twirl), then advance it."
        if self.stream is None:
            return
        if os.isatty(self.stream.fileno()):
            if self.countfmt:
                update = self.countfmt % self.counter
                # Backspace over what we wrote so the next update overwrites it
                self.stream.write(update + ("\b" * len(update)))
                self.stream.flush()
            else:
                self.twirl()
        self.counter = self.counter + 1
    def endcounter(self):
        "Blank out the counter display and leave counter mode."
        # Only erase when a counter is active and could have been drawn:
        # bumpcounter() draws solely on a tty.  The width is taken from the
        # counter value (not the twirl count), which is at least as wide as
        # the last value displayed.
        if self.stream and self.countfmt and os.isatty(self.stream.fileno()):
            w = len(self.countfmt % self.counter)
            self.stream.write((" " * w) + ("\b" * w))
            self.stream.flush()
        self.countfmt = None
    def twirl(self, ch=None):
        "One twirl of the baton."
        if self.stream is None:
            return
        if os.isatty(self.stream.fileno()):
            if ch:
                # An explicit character is emitted as-is and does not
                # advance the twirl position.
                self.stream.write(ch)
                self.stream.flush()
                return
            else:
                update = "-/|\\"[self.count % 4]
                self.stream.write(update + ("\b" * len(update)))
                self.stream.flush()
        self.count = self.count + 1
    def __exit__(self, extype, value_unused, traceback_unused):
        if extype == KeyboardInterrupt:
            self.endmsg = "interrupted"
        if extype == Fatal:
            self.endmsg = "aborted by error"
        if self.stream:
            self.stream.write("...(%2.2f sec) %s.\n" \
                              % (time.time() - self.time, self.endmsg))
        # Never suppress the exception, if any
        return False

class Commit(object):
    """Metadata of one commit, parsed from a "git log --format=raw" stanza.

    Besides the parsed fields, each instance carries a back-reference to
    its repository, a "matched" slot that is False until the commit has
    been paired with its counterpart in the other repository, and a
    "mark" slot that is None until a fast-import mark is assigned.
    """
    __slots__ = ("repo", "text", "commit", "tree", "author", "committer",
                 "comment", "authordate", "commitdate", "matched",
                 "parents", "mark")
    def __init__(self, repo, text):
        """Parse the stanza in text; repo is the owning repository.

        Raises Fatal if the stanza contains no "commit" line.
        """
        self.repo = repo
        self.text = text
        self.commit = None
        self.tree = self.author = self.committer = self.comment = None
        self.authordate = self.commitdate = None
        self.matched = False
        self.parents = []
        self.mark = None
        for line in text.split("\n"):
            if line.startswith("commit "):
                self.commit = line.strip().split()[1]
            elif line.startswith("parent "):
                self.parents.append(line.strip().split()[1])
            elif line.startswith("tree "):
                self.tree = line.strip().split()[1]
            elif line.startswith("author "):
                (self.author, self.authordate) = Commit.date_extract(line[7:])
            elif line.startswith("committer "):
                (self.committer, self.commitdate) = Commit.date_extract(line[10:])
            elif not line:
                # The first empty line ends the headers and opens the comment
                if self.comment is None:
                    self.comment = ""
            else:
                # Unrecognized header lines before the comment are dropped
                if self.comment is not None:
                    self.comment += line + "\n"
        # Validate once, after parsing, and with a real exception so the
        # check survives python -O.
        if not self.commit:
            raise Fatal("repodiffer: commit stanza has no commit line.")
    def timestamp(self):
        "UTC timestamp of commit."
        return int(self.commitdate.split()[0])
    @staticmethod
    def idname(fld):
        "Human part of user name."
        if "email" in ignores:
            return fld.split("<")[0].strip()
        else:
            return fld
    @staticmethod
    def email(fld):
        "Email part of user name."
        return fld.split("<")[1].split(">")[0]
    def name_stamp(self):
        "A timestamp/committer pair using full name.  Should be unique!"
        return(Commit.idname(self.committer), self.timestamp())
    def action_stamp(self):
        "A timestamp/committer pair using email name.  Should be unique!"
        return(Commit.email(self.committer), str(self.timestamp()))
    @staticmethod
    def date_extract(text):
        "Split 'Name <email> date' into ('Name <email>', 'date')."
        ee = text.rindex(">")
        return (text[:ee+1], text[ee+2:])
    @staticmethod
    def time_not_equal(a, b):
        "Time equality check, which sometimes wants to ignore timezone skew."
        if "timezone" in ignores:
            # NOTE(review): the modulus is 360 seconds; presumably chosen so
            # whole- and half-hour zone offsets cancel -- confirm intent.
            return int(a.split()[0]) % 360 != int(b.split()[0]) % 360
        else:
            return a != b
    def __str__(self):
        "Generate an ID for this commit that is useful to humans."
        if self.repo.streamed:
            return self.mark
        elif fullhash:
            return self.commit
        else:
            return self.commit[:7]
    def __repr__(self):
        "Multi-line dump of the commit's metadata."
        r = "commit %s\n" % str(self)
        r += "committer %s %s\n" % (self.committer, self.commitdate)
        r += "author %s %s\n" % (self.author, self.authordate)
        for parent in self.parents:
            r += "parent %s\n" % parent
        r += "tree %s\n" % self.tree
        # The comment stays None when the stanza had no body
        r += self.comment or ""
        return r
    def parts(self):
        "The fields that identify this commit's content, as a tuple."
        return (self.tree,
                self.author, self.authordate,
                self.committer, self.commitdate,
                self.comment)
    def difference(self, other):
        """Return the list of aspects in which this commit differs from other.

        Aspects named in the global ignores setting are skipped.  Note that
        ignores is tested by substring, so e.g. "authordate" also matches
        the "author" test.
        """
        res = []
        # Alas, there's no better way to do this than with an ignores global
        if "author" not in ignores and \
               Commit.idname(self.author) != Commit.idname(other.author):
            res.append('author')
        if "authordate" not in ignores and \
               Commit.time_not_equal(self.authordate, other.authordate):
            res.append('authordate')
        if "committer" not in ignores and \
               Commit.idname(self.committer) != Commit.idname(other.committer):
            res.append('committer')
        if "commitdate" not in ignores \
               and Commit.time_not_equal(self.commitdate, other.commitdate):
            res.append('commitdate')
        if "comment" not in ignores and self.comment != other.comment:
            res.append('comment')
        if "tree" not in ignores and self.tree != other.tree:
            res.append('tree')
        if "parent" not in ignores:
            if len(self.parents) != len(other.parents):
                res.append("parents")
            else:
                for (a_id, b_id) in zip(self.parents, other.parents):
                    a = self.repo.fastlookup[a_id]
                    b = other.repo.fastlookup[b_id]
                    # Corresponding parents must have been paired with
                    # each other; an unmatched parent fails this too.
                    if a.matched is not b or b.matched is not a:
                        res.append("parents")
                        break
        return res
    def file_tree(self):
        """Return a temporary directory holding a checkout of this commit.

        The caller owns the directory and must remove it.  If extraction
        fails the directory is removed before the error propagates.
        """
        scratchdir = tempfile.mkdtemp(prefix="rpds", suffix=str(os.getpid()))
        try:
            with directory_context(self.repo.directory):
                do_or_die("git archive --format=tar %s | (cd %s && tar xf -)" \
                          % (self.commit, scratchdir))
        except Exception:
            shutil.rmtree(scratchdir, ignore_errors=True)
            raise
        return scratchdir
    def manifest(self):
        """Return a file manifest for this revision.

        The result maps each path to the link target (for symlinks) or
        the MD5 digest of the content (for regular files).
        """
        with directory_context(self.repo.directory):
            paths = capture_or_die("git ls-tree -r --full-name --name-only %s" % self.commit)
        # Because older versions of git-archive barf on empty trees.
        if not paths.strip():
            return {}
        filemap = {}
        # file_tree() already extracts the archive; obtain the directory
        # before the try so the cleanup never sees an unbound name.
        scratchdir = self.file_tree()
        try:
            # One path per output line; splitting on arbitrary whitespace
            # would break up file names containing spaces.
            for path in paths.split("\n"):
                if not path:
                    continue
                if "gitignore" in ignores and \
                       path.endswith((".gitignore", ".cvsignore")):
                    continue
                target = os.path.join(scratchdir, path)
                if os.path.islink(target):
                    filemap[path] = os.readlink(target)
                elif os.path.isfile(target):
                    with open(target, "rb") as fp:
                        filemap[path] = hashlib.md5(fp.read()).digest()
                else:
                    sys.stderr.write("repodiffer: ignoring non-file non-symlink path %s\n" % path)
        finally:
            shutil.rmtree(scratchdir)
        return filemap
    def __hash__(self):
        return hash(self.parts())
    def __eq__(self, other):
        "Commits are equal when difference() finds nothing to report."
        return not self.difference(other)

class Tag(object):
    """A named reference to a commit: a tag, or a branch head.

    ref is the identifier of the commit referred to, resolved through
    the owning repository's fastlookup table.  comment and tagger are
    None unless filled in from an annotated tag.
    """
    __slots__ = ("repo", "name", "ref", "comment", "tagger")
    def __init__(self, repo, name, ref, comment=None):
        self.repo = repo
        self.name = name
        self.ref = ref
        self.comment = comment
        self.tagger = None
    def metadata(self):
        "The tag's own attributes, as a tuple suitable for comparison."
        return (self.name, self.comment, self.tagger)
    def difference(self, other):
        "Return the list of aspects in which this tag differs from other."
        diff = []
        if "comment" not in ignores and self.comment != other.comment:
            diff.append("comment")
        # Implement zone-skew logic here
        if "tagger" not in ignores and self.tagger != other.tagger:
            diff.append("tagger")
        sc = self.referent()
        oc = other.referent()
        # The referents agree only if the two commits were paired with each
        # other.  Compare by identity: the matched slots hold commits, so
        # comparing them against the tags themselves can never succeed.
        if sc.matched is not oc or oc.matched is not sc:
            diff.append("referent")
        return diff
    def referent(self):
        "The commit object this tag points at."
        return self.repo.fastlookup[self.ref]
    def __str__(self):
        # A class with __slots__ has no instance __dict__, so build the
        # attribute dictionary from the slot names instead.
        return str(dict((slot, getattr(self, slot)) for slot in self.__slots__))
    def __eq__(self, other):
        "Tags are equal when difference() finds nothing to report."
        return not self.difference(other)

class Repository:
    """A repository, for comparison purposes.

    Built either from a repository directory or from a fast-import
    stream file; a stream is first imported into a scratch repository
    that is removed again when the object is destroyed.  After
    construction the instance holds:

    signature  -- list of Commit objects, in "git log --reverse" order
    fastlookup -- dictionary from commit hash and from mark to Commit
    tags       -- dictionary from tag name or refs/heads/* name to Tag
    """
    def __init__(self, sourcepath):
        self.path = sourcepath
        # Accept either a repository directory or an import stream
        if os.path.isdir(self.path):
            self.streamed = False
            self.directory = self.path
        elif os.path.isfile(self.path):
            self.streamed = True
            self.directory = tempfile.mkdtemp(prefix="rpd", suffix=str(os.getpid()))
            source = os.path.abspath(self.path)
            with directory_context(self.directory):
                do_or_die("git init --quiet; git fast-import --quiet --export-marks=.git/marks <%s" % source)
        else:
            raise Fatal("repodiffer: I don't know how to turn %s into a repo." % self.path)
        # Get a list of metadata and tree signatures for this repo
        with directory_context(self.directory):
            log = capture_or_die("git log --all --reverse --date-order --format=raw")
            if log.strip():
                # Every stanza but the first lost its "commit " prefix to
                # the split; put it back.
                chunks = log.split("\ncommit ")
                stanzas = [chunks[0]] + ["commit " + chunk for chunk in chunks[1:]]
                self.signature = [Commit(self, stanza) for stanza in stanzas]
            else:
                # No commits at all
                self.signature = []
            if verbose >= DEBUG_LOWLEVEL:
                print("Commits in %s" % self.path)
                for commit in self.signature:
                    print(commit.commit)
            self.mark_to_sha1 = {}
            self.sha1_to_mark = {}
            # A marks file is written when the repository was imported from
            # a stream; if it cannot be read, commits keep their hashes.
            try:
                with open(".git/marks") as marks:
                    for line in marks:
                        (mark, sha1) = line.strip().split()
                        self.sha1_to_mark[sha1] = mark
                        self.mark_to_sha1[mark] = sha1
            except (IOError, OSError):
                pass
            else:
                for item in self.signature:
                    item.mark = self.sha1_to_mark[item.commit]
                    item.parents = [self.sha1_to_mark[h] for h in item.parents]
            self.fastlookup = {}
            for item in self.signature:
                self.fastlookup[item.commit] = item
                self.fastlookup[item.mark] = item
            # Get information about tags
            self.tags = {}
            for tagname in capture_or_die("git tag -l").split():
                tagdata = capture_or_die("git show --format=raw %s" % tagname)
                if tagdata.startswith("commit"):
                    # Output starts with the commit itself: no tag header
                    refers_to = tagdata.split("\n")[0].split()[1]
                    self.tags[tagname] = Tag(self, tagname, refers_to)
                elif tagdata.startswith("tag"):
                    for line in tagdata.split('\n'):
                        if line.startswith("commit"):
                            refers_to = line.strip().split()[1]
                            break
                    else:
                        raise Fatal("no commit line in annotated tag")
                    tag = Tag(self, tagname, refers_to, comment="")
                    stanza = capture_or_die("git cat-file -p " + tagname)
                    # Headers run up to the first empty line, then the body
                    body_latch = False
                    for line in stanza.split("\n")[:-1]:
                        if not body_latch:
                            if line == "":
                                body_latch = True
                            elif line.startswith("tagger"):
                                tag.tagger = line[7:]
                        else:
                            tag.comment += line + "\n"
                    self.tags[tagname] = tag
            # And about heads too
            for line in capture_or_die("git show-ref --head").split("\n"):
                if line:
                    (hashid, ref) = line.strip().split()
                    if ref.startswith("refs/heads"):
                        self.tags[ref] = Tag(self, ref, hashid)
    def lookup(self, name):
        "Look up an object by name."
        for commit in self.signature:
            if commit.commit.startswith(name) or commit.mark == name:
                return commit
        return None
    def sort(self):
        "Sort by Committer-Date in case commits were generated in topo order."
        self.signature.sort(key=lambda c: c.timestamp())
    def tagnames(self):
        "The set of tag and branch-head names in this repository."
        return set(self.tags.keys())
    def __str__(self):
        return "<%s: %s>" % (self.path, self.signature)
    def __del__(self):
        # The constructor may have failed before "streamed" was set, so
        # do not assume the attribute exists.  Remove the scratch
        # repository directly rather than through a shell.
        if getattr(self, "streamed", False):
            shutil.rmtree(self.directory, ignore_errors=True)

class Difference:
    """The set of differences between two paired commits.

    difference_types lists the differing aspects as reported by
    a.difference(b).  When the trees differ, the file manifests are
    compared: left_only, right_only and different then hold sorted tuples
    of the paths present only in a, only in b, and in both with differing
    content; otherwise all three are None.  If the manifests turn out to
    be identical, 'tree' is removed from difference_types again.
    """
    def __init__(self, a, b):
        self.difference_types = a.difference(b)
        self.left_only = self.right_only = self.different = None
        if 'tree' in self.difference_types and 'tree' not in ignores:
            a_files = a.manifest()
            b_files = b.manifest()
            a_fileset = set(a_files)
            b_fileset = set(b_files)
            # Sort so that reports are reproducible and so that equality and
            # hashing do not depend on set iteration order.
            self.left_only = tuple(sorted(a_fileset - b_fileset))
            self.right_only = tuple(sorted(b_fileset - a_fileset))
            self.different = tuple(sorted(fn for fn in (a_fileset & b_fileset)
                                          if a_files[fn] != b_files[fn]))
            if not (self.left_only or self.right_only or self.different):
                self.difference_types.remove('tree')
    def __hash__(self):
        return hash((tuple(self.difference_types), self.different))
    def __eq__(self, other):
        "Equal when every recorded attribute agrees."
        return isinstance(other, Difference) and self.__dict__ == other.__dict__

def read_legacymap(fn):
    """Read a legacy-ID map file.

    Each non-blank line holds a legacy ID and a stamp of the form
    "YYYY-MM-DDTHH:MM:SSZ!person" with an optional ":N" suffix, separated
    by whitespace.  Returns a dictionary mapping
    (person, seconds-since-epoch-as-string) pairs to legacy IDs.  Entries
    whose legacy ID begins with "CVS" are omitted.
    """
    legacydict = {}
    with open(fn) as fp:
        for line in fp:
            # Tolerate blank lines, e.g. a trailing one
            if not line.strip():
                continue
            (legacy, stamp) = line.split()
            (timefield, person) = stamp.split('!')
            if ':' in person:
                # The sequence suffix is checked for being an integer but
                # does not take part in the key.
                (person, seq) = person.split(':')
                int(seq)
            assert legacy and timefield and person
            rfc3339date = time.strptime(timefield, "%Y-%m-%dT%H:%M:%SZ")
            timestamp = calendar.timegm(rfc3339date)
            # Presently, ignore CVS cookies as they can't be mapped to changesets
            if not legacy.startswith("CVS"):
                legacydict[(person, str(timestamp))] = legacy
    return legacydict

# Differences already reported in full, mapped to the "a -> b" description
# of the commit pair that first showed them; report_pair() consults this to
# abbreviate later pairs that differ in exactly the same way.
diffdict = {}

def report_pair(a, b, legacies):
    """Return the report text for one aligned pair and update the tallies.

    a and b are commits from the left and right repository; either may
    be None to indicate that the other one has no counterpart.  legacies
    is a mapping from action stamps to legacy IDs, or a false value.
    The global counters n_equal, n_insertions, n_deletions, n_changed and
    spancount are updated as a side effect.
    """
    global n_equal, n_insertions, n_deletions, n_changed, spancount
    tell = ""
    if a is None:
        n_insertions += 1
        if not changes_only:
            if spancount and not showequal:
                tell += "%d equal pairs.\n" % spancount
                spancount = 0
            tell += "R only: %s\n" % b
        return tell
    if b is None:
        n_deletions += 1
        if not changes_only:
            if spancount and not showequal:
                tell += "%d equal pairs.\n" % spancount
                spancount = 0
            tell += "L only: %s\n" % a
        return tell
    diff = Difference(a, b)
    if not diff.difference_types:
        n_equal += 1
        # Runs of equal pairs are summarized when the run is broken
        spancount += 1
        if showequal:
            tell += "equal: %s == %s" % (a, b)
            legacy = legacies and legacies.get(a.action_stamp())
            if legacy:
                tell += " (" + legacy + ")"
            tell += "\n"
        return tell
    if spancount and not showequal:
        tell += "%d equal pairs.\n" % spancount
        spancount = 0
    n_changed += 1
    legacy = legacies and legacies.get(a.action_stamp())
    if legacy:
        legacyid = "(" + legacy + ") "
    else:
        legacyid = ""
    if diff in diffdict:
        tell += "%s%s -> %s: same differences as for %s.\n" % (legacyid, a, b, diffdict[diff])
    else:
        tell += "changed: %s%s -> %s in %s.\n" % (legacyid, a, b, ",".join(diff.difference_types))
        if "parents" in diff.difference_types:
            def parent_show(c, p):
                "Describe parent p of commit c, with its legacy ID if known."
                commit = c.repo.fastlookup[p]
                legacy = legacies and legacies.get(commit.action_stamp())
                if legacy:
                    legacyid = "(" + legacy + ") "
                else:
                    legacyid = ""
                return legacyid + str(commit)
            tell += " parents: "
            tell += ", ".join([parent_show(a, parent) for parent in a.parents])
            tell += " -> "
            tell += ", ".join([parent_show(b, parent) for parent in b.parents])
            tell += "\n"
        else:
            # Only differences not involving parents are remembered for
            # later "same differences as" abbreviation.
            diffdict[diff] = "%s -> %s" % (a, b)
        if diff.left_only:
            tell += "L only:\n"
            for fn in diff.left_only:
                tell += ("  " + fn + "\n")
        if diff.right_only:
            tell += "R only:\n"
            for fn in diff.right_only:
                tell += ("  " + fn + "\n")
        if diff.different:
            tell += "Differing files:\n"
            if not treediff:
                for fn in diff.different:
                    tell += ("  " + fn + "\n")
            else:
                # Pre-bind both names so the cleanup below is safe even if
                # one of the checkouts fails.
                a_tree = b_tree = None
                try:
                    a_tree = a.file_tree()
                    b_tree = b.file_tree()
                    # diff exits nonzero when it finds differences; mask that
                    command = "diff -r  -u %s %s || :" % (a_tree, b_tree)
                    difftext = capture_or_die(command)
                    # Strip the scratch directory names and diff command lines
                    difftext = difftext.replace("--- " + a_tree + "/", "--- ")
                    difftext = difftext.replace("+++ " + b_tree + "/", "+++ ")
                    difftext = re.sub("\ndiff .*\n", "\n", "\n" + difftext)[1:]
                    tell += difftext
                finally:
                    for scratch in (a_tree, b_tree):
                        if scratch is not None:
                            shutil.rmtree(scratch)
    return tell

def diff_single(single, alpha, beta, legacies):
    """Examine the difference between two specified commits.

    single has the form "L=R", where each side is a hash prefix or mark
    to be looked up in alpha and beta respectively.  The report is
    written to standard output.  Returns 1 if the two commits differ in
    the content of some file they both contain, 0 otherwise; exits with
    status 1 if single is malformed or either side cannot be found.
    """
    global ignores
    try:
        (a, b) = single.split("=")
    except ValueError:
        sys.stderr.write("repodiffer: need hashes or marks separated by =\n")
        sys.exit(1)
    ac = alpha.lookup(a)
    if ac is None:
        sys.stderr.write("repodiffer: no match for %s\n" % a)
        sys.exit(1)
    bc = beta.lookup(b)
    if bc is None:
        sys.stderr.write("repodiffer: no match for %s\n" % b)
        sys.exit(1)
    # Parent comparison depends on commit pairing, which is not done here
    ignores += "parents"
    sys.stdout.write(report_pair(ac, bc, legacies))
    # Compare the commits that were looked up, not the selector strings
    return 1 if Difference(ac, bc).different else 0

def diff_all(alpha, beta, legacies):
    """Walk through time-ordered commit pairs, finding differences.

    Commits of the two repositories are first paired by name stamp
    (committer and commit time), then both commit lists are walked in
    parallel to build an alignment in which unpaired commits appear as
    one-sided entries.  A report on every entry, summary statistics and a
    tag comparison are written to standard output.  Returns 1 if any
    paired commits differ in the content of a common file, else 0; exits
    with status 1 if the pairing is not order-preserving.
    """
    alpha.sort()
    beta.sort()
    content_differences = 0
    # Pair each L commit with the earliest still-unpaired R commit that
    # has the same name stamp.  Indexing R by stamp avoids comparing every
    # L commit against every R commit.
    by_stamp = {}
    for b in beta.signature:
        by_stamp.setdefault(b.name_stamp(), []).append(b)
    for a in alpha.signature:
        if a.matched:
            continue
        for b in by_stamp.get(a.name_stamp(), ()):
            if not b.matched:
                a.matched = b
                b.matched = a
                break
    i = j = 0
    # Build the pairing table.
    pairs = []
    while i < len(alpha.signature) and j < len(beta.signature):
        a = alpha.signature[i]
        b = beta.signature[j]
        if a.matched and b.matched:
            # Both paired: they must be paired with each other
            if a.matched is not b:
                sys.stderr.write("repodiffer: matches are out of order at %d:%d\n" \
                                 % (i+1, j+1))
                sys.exit(1)
            pairs.append((a, b))
            i += 1
            j += 1
        elif a.matched:
            # R has an extra commit before a's partner
            pairs.append((None, b))
            j += 1
        elif b.matched:
            # L has an extra commit before b's partner
            pairs.append((a, None))
            i += 1
        else:
            # NOTE(review): the original emitted L before R regardless of
            # which timestamp was earlier; that order is preserved here.
            pairs.append((a, None))
            pairs.append((None, b))
            i += 1
            j += 1
    # Whatever remains on one side once the other is exhausted can only be
    # unpaired commits; report them instead of dropping them.
    pairs.extend([(a, None) for a in alpha.signature[i:]])
    pairs.extend([(None, b) for b in beta.signature[j:]])
    with Baton("repodiffer:", enable=(not quiet and verbose==0)) as baton:
        baton.startcounter("%%d of %d" % len(pairs))
        # Report on the pair table
        chunks = []
        for (a, b) in pairs:
            chunks.append(report_pair(a, b, legacies))
            if a is not None and b is not None and Difference(a, b).different:
                content_differences += 1
            baton.bumpcounter()
    sys.stdout.write("".join(chunks))
    if n_changed or n_deletions or n_insertions:
        sys.stdout.write("statistics: %d equal, %d of %d inserted, %d of %d deleted, %d changed.\n"
                         % (n_equal,
                            n_insertions,
                            len(beta.signature),
                            n_deletions,
                            len(alpha.signature),
                            n_changed))
    left_tags = alpha.tagnames()
    right_tags = beta.tagnames()
    if not changes_only:
        # Sorted so that the output is reproducible from run to run
        if left_tags - right_tags:
            sys.stdout.write("tags only in L: %s\n" % ",".join(sorted(left_tags - right_tags)))
        if right_tags - left_tags:
            sys.stdout.write("tags only in R: %s\n" % ",".join(sorted(right_tags - left_tags)))
    for tagname in sorted(left_tags & right_tags):
        a = alpha.tags[tagname]
        b = beta.tags[tagname]
        if a.metadata() != b.metadata() or \
               not alpha.fastlookup[a.ref].matched or \
               not beta.fastlookup[b.ref].matched:
            sys.stdout.write("Tag '%s': different in %s.\n" % (tagname, ",".join(a.difference(b))))
        elif showequal:
            sys.stdout.write("Tag '%s': same.\n" % tagname)
    return 1 if content_differences > 0 else 0

# Command-line entry point.  The option values below are deliberately
# module-level globals: the functions and classes above read them.
if __name__ == '__main__':
    if sys.hexversion < 0x02070000:
        sys.stderr.write("repodiffer: requires Python 2.7 or later.\n")
        sys.exit(1)
    try:
        (options, arguments) = getopt.getopt(sys.argv[1:], "cefi:m:qs:tv:", ["changes-only", "show-equal", "full-hash", "ignore=", "legacy-map=", "single=", "quiet", "tree-diff", "verbose="])
    except getopt.GetoptError as err:
        # Unknown option or missing option argument
        sys.stderr.write("repodiffer: %s\n" % err)
        sys.exit(1)
    alpha = None
    beta = None
    changes_only = False
    showequal = False
    fullhash = False
    ignores = ""
    legacymap = None
    single = None
    quiet = False
    treediff = False
    verbose = 0
    # Tallies maintained by report_pair()
    n_equal = n_insertions = n_deletions = n_changed = 0
    spancount = 0
    for (opt, val) in options:
        if opt in ('-c', '--changes-only'):
            changes_only = True
        elif opt in ('-e', '--show-equal'):
            showequal = True
        elif opt in ('-f', '--full-hash'):
            fullhash = True
        elif opt in ('-i', '--ignore'):
            ignores = val
        elif opt in ('-m', '--legacy-map'):
            legacymap = val
        elif opt in ('-s', '--single'):
            single = val
        elif opt in ('-q', '--quiet'):
            quiet = True
        elif opt in ('-t', '--tree-diff'):
            treediff = True
        elif opt in ('-v', '--verbose'):
            try:
                verbose = int(val)
            except ValueError:
                sys.stderr.write("repodiffer: verbosity level must be an integer.\n")
                sys.exit(1)
    if len(arguments) < 2:
        sys.stderr.write("repodiffer: input and output directory or import-stream arguments are required.\n")
        sys.exit(1)
    else:
        (alpha, beta) = arguments[:2]
    if not os.path.exists(alpha):
        sys.stderr.write("repodiffer: %s must exist.\n" % alpha)
        sys.exit(1)
    if not os.path.exists(beta):
        sys.stderr.write("repodiffer: %s must exist.\n" % beta)
        sys.exit(1)
    try:
        alpha = Repository(alpha)
        beta  = Repository(beta)
        legacies = legacymap and read_legacymap(legacymap)
        # The exit status is the value returned by the comparison
        if single:
            sys.exit(diff_single(single, alpha, beta, legacies))
        else:
            sys.exit(diff_all(alpha, beta, legacies))
    except Fatal as err:
        sys.stderr.write(err.msg + "\n")
        # Drop our references to the repository objects before exiting
        del alpha
        del beta
        sys.exit(1)
    except KeyboardInterrupt:
        pass

# The following sets edit modes for GNU EMACS
# Local Variables:
# mode:python
# End:
# end
