#!/usr/libexec/platform-python

import re
import subprocess
import sys

class Change(object):
    """One commit, plus the metadata used to match it against its ports.

    Instances also act as nodes of a disjoint-set forest (``_parent`` and
    ``_rank``), so that changes considered equivalent can be merged with
    ``_union()`` and resolved to a single representative with ``_find()``.
    """

    def __init__(self):
        # Values taken directly from the commit.
        self.commit = ''
        self.author_name = ''
        self.author_email = ''
        self.author_date = 0
        self.subject = ''
        self.body = ''
        # Values derived from the commit message; empty/zero when absent.
        self.number = 0
        self.change_id = ''
        self.reviewed_on = ''
        self.lustre_commit = ''
        self.lustre_change = ''
        self.lustre_change_number = 0
        self.cray_bug_id = ''
        self.hpe_bug_id = ''
        # Disjoint-set bookkeeping: every change starts as its own root.
        self._parent = self
        self._rank = 0

    def _find(self):
        """Return the representative of this change's equivalence class.

        Every node visited on the way is re-pointed directly at the root.
        """
        # First pass: walk up to the root.
        root = self
        while root._parent is not root:
            root = root._parent

        # Second pass: compress the path that was just walked.
        node = self
        while node is not root:
            above = node._parent
            node._parent = root
            node = above

        return root

    def _union(self, c2):
        """Merge the equivalence classes of this change and ``c2``."""
        root1 = self._find()
        root2 = c2._find()
        if root1 is root2:
            # Already in the same class.
            return

        if root1._rank < root2._rank:
            root1._parent = root2
            return

        # root1 is at least as tall: it becomes the root, and grows by one
        # only when both trees had the same rank.
        root2._parent = root1
        if root1._rank == root2._rank:
            root1._rank += 1


# Change attributes filled in from each `git log` record, in output order.
GIT_LOG_FIELDS = [
    'commit',
    'author_name',
    'author_email',
    'author_date',
    'subject',
    'body',
]
# The git pretty-format placeholder that produces each field above.
GIT_LOG_KEYS = [
    '%H',
    '%an',
    '%ae',
    '%at',
    '%s',
    '%b',
]
# Fields are separated by a 0x1f byte; each record is ended by a 0x1e byte.
GIT_LOG_FORMAT = '{0}%x1e'.format('%x1f'.join(GIT_LOG_KEYS))

def _gerrit_change_number(url):
    """Return the change number that ends a review URL, or 0 if there is none.

    Both the short form ``scheme://host/NNN`` and longer forms such as
    ``scheme://host/c/project/+/NNN`` are accepted; a single trailing slash
    is ignored.
    """
    obj = re.match(r'[A-Za-z]+://[^\s/]+(?:/\S*)?/(\d+)/?$', url)
    if obj:
        return int(obj.group(1))
    return 0


def _change_from_record(rec):
    """Build a Change from one `git log` record.

    ``rec`` is a single record whose fields are separated by 0x1f bytes and
    appear in GIT_LOG_FIELDS order. Recognized ``Key: value`` lines in the
    commit body are copied onto the change (``Change-Id`` becomes
    ``change_id`` and so on), and the numeric ids are extracted from the
    ``Reviewed-on`` and ``Lustre-change`` URLs.

    Raises ValueError if the author date field is not an integer.
    """
    trailer_keys = ('Change-Id', 'Reviewed-on', 'Lustre-commit',
                    'Lustre-change', 'Cray-bug-id', 'HPE-bug-id')

    change = Change()
    # zip() stops at the shorter sequence, so missing fields keep the
    # defaults set by Change() and surplus fields are ignored.
    change.__dict__.update(zip(GIT_LOG_FIELDS, rec.split('\x1f')))
    change.author_date = int(change.author_date)

    for line in change.body.splitlines():
        # Sometimes we have 'key : value' so we strip both sides.
        key, sep, val = line.partition(':')
        if not sep:
            continue
        key = key.strip()
        if key in trailer_keys:
            # If a key is repeated, the last occurrence wins.
            setattr(change, key.replace('-', '_').lower(), val.strip())

    change.number = _gerrit_change_number(change.reviewed_on)
    change.lustre_change_number = _gerrit_change_number(change.lustre_change)

    return change


def _head(lis):
    if lis:
        return lis[0]
    else:
        return None


class Branch(object):
    """The commit log of one git revision, indexed for finding ports.

    ``name`` is the revision handed to `git log`; ``paths`` is a list of
    extra arguments appended after it (the script passes the PATH
    arguments from its command line). Call load() to populate the log
    and the lookup tables.
    """

    def __init__(self, name, paths):
        self.name = name
        self.paths = paths
        self.log = []  # Oldest commit is first.
        self.by_commit = {}  # commit hash (str) -> change
        self.by_subject = {}  # subject (str) -> list of changes
        self.by_change_id = {}  # Change-Id (str) -> list of changes
        self.by_number = {}  # review number (int) -> list of changes

    def _add_change_from_record(self, rec):
        """Parse one `git log` record and add it to the log and all indexes."""
        # TODO Handle reverted commits.
        change = _change_from_record(rec)
        self.log.append(change)
        assert change.commit
        assert change.commit not in self.by_commit
        self.by_commit[change.commit] = change

        assert change.subject
        # XXX Do we want to _union() changes that share a subject?
        self.by_subject.setdefault(change.subject, []).append(change)

        for bug_id in (change.cray_bug_id, change.hpe_bug_id):
            if bug_id and (' ' in change.subject):
                # Split subject in to issue and rest.
                parts = change.subject.split(None, 1)
                if len(parts) != 2:
                    # Only one token (e.g. stray leading or trailing
                    # whitespace), so there is no issue prefix to replace.
                    continue
                # Make new subject using external bug id
                subject = ' '.join((bug_id, parts[1]))
                self.by_subject.setdefault(subject, []).append(change)

        # Equivalate by change_id.
        if change.change_id:
            lis = self.by_change_id.setdefault(change.change_id, [])
            if lis:
                lis[0]._union(change)
            lis.append(change)

        # Equivalate by number (from reviewed_on).
        if change.number:
            lis = self.by_number.setdefault(change.number, [])
            if lis:
                lis[0]._union(change)
            lis.append(change)

    def load(self):
        """Run `git log` for this branch and rebuild the log and indexes.

        Any previously loaded state is discarded first.

        Raises subprocess.CalledProcessError if git exits with a non-zero
        status.
        """
        self.log = []
        self.by_commit = {}
        self.by_subject = {}
        self.by_change_id = {}
        self.by_number = {}

        cmd = ['git', 'log',
               '--format=' + GIT_LOG_FORMAT,
               '--reverse',
               self.name]
        # NOTE(review): paths follow the revision without a '--' separator,
        # so git decides by itself whether each one is a path or a revision.
        cmd.extend(self.paths)

        # Decode explicitly instead of relying on the locale: author names
        # and subjects are frequently non-ASCII. The encoding= keyword is
        # also accepted by Python 3.6, whereas text= needs 3.7 or later.
        # check=True replaces an assert that would vanish under `python -O`.
        proc = subprocess.run(cmd,
                              stdout=subprocess.PIPE,
                              encoding='utf-8',
                              errors='replace',
                              check=True)

        # git writes a newline after each record's 0x1e terminator.
        for rec in proc.stdout.split('\x1e\n'):
            if rec:
                self._add_change_from_record(rec)

    def find_port(self, change):
        """Return the representative of a port of ``change`` in this branch.

        ``change`` may or may not belong to this branch. The lookups are
        tried in order and the first hit wins. Returns None when nothing
        matches.
        """
        # TODO Return oldest member of equivalence class.
        port = (self.by_commit.get(change.commit) or
                self.by_commit.get(change.lustre_commit) or
                self.by_commit.get(change.lustre_change) or # Handle misuse.
                _head(self.by_change_id.get(change.change_id)) or
                _head(self.by_change_id.get(change.lustre_commit)) or # ...
                _head(self.by_change_id.get(change.lustre_change)) or
                _head(self.by_number.get(change.number)) or # Do we need this?
                _head(self.by_number.get(change.lustre_change_number)) or
                _head(self.by_subject.get(change.subject))) # Do we want this?
        if port:
            return port._find()
        return None


def branch_comm(b1, b2):
    """Print a comm(1)-style comparison of two loaded branches.

    Each output line has four tab-separated columns: a commit unique to
    ``b1``, a commit unique to ``b2``, a commit common to both, and the
    commit subject. Exactly one of the first three columns is filled in.
    The leading run of commits with identical hashes is not printed.
    """
    log1 = b1.log
    log2 = b2.log
    n1 = len(log1)
    n2 = len(log2)
    seen = set() # commits

    def is_done(change):
        return (change.commit in seen) or (change.lustre_commit in seen)

    def mark_done(change):
        seen.add(change.commit)
        if change.lustre_commit:
            seen.add(change.lustre_commit)

    def show_only_b1(change):
        print('%s\t\t\t%s' % (change.commit, change.subject))

    def show_only_b2(change):
        print('\t%s\t\t%s' % (change.commit, change.subject))

    def show_common(change):
        # TODO Add a '*' if subjects different...
        print('\t\t%s\t%s' % (change.commit, change.subject))

    # Suppress initial common commits.
    # XXX Should we use _find() on the changes here, or find_port()?
    i1 = 0
    i2 = 0
    while i1 < n1 and i2 < n2 and log1[i1].commit == log2[i2].commit:
        i1 += 1
        i2 += 1

    while i1 < n1 and i2 < n2:
        c1 = log1[i1]
        if is_done(c1):
            i1 += 1
            continue

        c2 = log2[i2]
        if is_done(c2):
            i2 += 1
            continue

        p1 = b1.find_port(c2)
        if p1 and is_done(p1):
            mark_done(c2)
            i2 += 1
            continue

        p2 = b2.find_port(c1)
        if p2 and is_done(p2):
            mark_done(c1)
            i1 += 1
            continue

        # Neither of c1 and c2 has been printed, nor has any port of either.

        # XXX Do we need c1._find() here?
        if c1 == p1 or c2 == p2:
            # c1 and c2 are ports of the same change: c1 is common to both
            # branches.
            for change in (c1, c2, p1, p2):
                if change:
                    mark_done(change)
            i1 += 1
            i2 += 1
            show_common(c1)
        elif p2 and not p1:
            # b2 has c1, b1 does not have c2, (port of c1 must be after c2).
            # c2 is unique to b2.
            mark_done(c2)
            i2 += 1
            show_only_b2(c2)
        else:
            # Remaining cases all consume c1:
            #  - p1 only: b1 has c2, b2 does not have c1 (port of c2 must be
            #    after c1), so c1 is unique to b1;
            #  - both ported (and the order is weird): c1 is common;
            #  - neither ported: c1 is unique to b1.
            mark_done(c1)
            i1 += 1
            if p2:
                mark_done(p2)
                show_common(c1)
            else:
                show_only_b1(c1)

    for c1 in log1[i1:]:
        if not is_done(c1):
            assert i2 == n2
            # All commits from b2 have been printed. Therefore if c1 has
            # been ported to b2 then the port has already been printed. So
            # c1 is unique to b1 and must be printed.
            mark_done(c1)
            show_only_b1(c1)

    for c2 in log2[i2:]:
        if not is_done(c2):
            assert i1 == n1
            # Same reasoning with the branches swapped: c2 is unique to b2.
            mark_done(c2)
            show_only_b2(c2)


# Help text printed when too few arguments are given. Every occurrence of
# '_PROGNAME_' is replaced with the program name before printing.
USAGE = """usage: '_PROGNAME_ BRANCH1 BRANCH2 [PATH]...'

Compare commits to Lustre branches.

Prints commits unique to BRANCH1 in column 1.
Prints commits unique to BRANCH2 in column 2.
Prints commits common to both branches in column 3.
Prints commit subject in column 4.
Skips initial common commits.

The output format is inspired by comm(1). To filter commits by branch,
pipe the output to awk. For example:
  $ ... | awk -F'\\t' '$1 != ""' # only commits unique to BRANCH1
  $ ... | awk -F'\\t' '$2 != ""' # only commits unique to BRANCH2
  $ ... | awk -F'\\t' '$3 != ""' # only common commits
  $ ... | awk -F'\\t' '$3 == ""' # exclude common commits

This assumes that both branches are in the repository that contains
the current directory. To compare branches from different upstream
repositories (for example 'origin/master' and 'other/b_post_cmd3') do:

  $ cd fs/lustre-release
  $ git fetch origin
  $ git remote add other ...
  $ git fetch other
  $ _PROGNAME_ origin/master other/b_post_cmd3"""


def main():
    """Command-line entry point: load both branches and print the comparison.

    Prints the usage text to stderr and exits with status 1 when fewer than
    two branch names are given.
    """
    argv = sys.argv
    if len(argv) < 3:
        print(USAGE.replace('_PROGNAME_', argv[0]), file=sys.stderr)
        sys.exit(1)

    # Everything after the two branch names is shared by both branches.
    paths = argv[3:]

    branches = []
    for name in (argv[1], argv[2]):
        branch = Branch(name, paths)
        branch.load()
        branches.append(branch)

    branch_comm(branches[0], branches[1])


# Run the comparison only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

