summaryrefslogtreecommitdiff
path: root/gitMergeCommon.py
diff options
context:
space:
mode:
Diffstat (limited to 'gitMergeCommon.py')
-rw-r--r--gitMergeCommon.py275
1 files changed, 275 insertions, 0 deletions
diff --git a/gitMergeCommon.py b/gitMergeCommon.py
new file mode 100644
index 0000000000..fdbf9e4778
--- /dev/null
+++ b/gitMergeCommon.py
@@ -0,0 +1,275 @@
+#
+# Copyright (C) 2005 Fredrik Kuivinen
+#
+
+import sys, re, os, traceback
+from sets import Set
+
+def die(*args):
+ printList(args, sys.stderr)
+ sys.exit(2)
+
+def printList(list, file=sys.stdout):
+ for x in list:
+ file.write(str(x))
+ file.write(' ')
+ file.write('\n')
+
+import subprocess
+
+# Debugging machinery
+# -------------------
+
+DEBUG = 0
+functionsToDebug = Set()
+
+def addDebug(func):
+ if type(func) == str:
+ functionsToDebug.add(func)
+ else:
+ functionsToDebug.add(func.func_name)
+
+def debug(*args):
+ if DEBUG:
+ funcName = traceback.extract_stack()[-2][2]
+ if funcName in functionsToDebug:
+ printList(args)
+
+# Program execution
+# -----------------
+
+class ProgramError(Exception):
+ def __init__(self, progStr, error):
+ self.progStr = progStr
+ self.error = error
+
+ def __str__(self):
+ return self.progStr + ': ' + self.error
+
+addDebug('runProgram')
+def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
+ debug('runProgram prog:', str(prog), 'input:', str(input))
+ if type(prog) is str:
+ progStr = prog
+ else:
+ progStr = ' '.join(prog)
+
+ try:
+ if pipeOutput:
+ stderr = subprocess.STDOUT
+ stdout = subprocess.PIPE
+ else:
+ stderr = None
+ stdout = None
+ pop = subprocess.Popen(prog,
+ shell = type(prog) is str,
+ stderr=stderr,
+ stdout=stdout,
+ stdin=subprocess.PIPE,
+ env=env)
+ except OSError, e:
+ debug('strerror:', e.strerror)
+ raise ProgramError(progStr, e.strerror)
+
+ if input != None:
+ pop.stdin.write(input)
+ pop.stdin.close()
+
+ if pipeOutput:
+ out = pop.stdout.read()
+ else:
+ out = ''
+
+ code = pop.wait()
+ if returnCode:
+ ret = [out, code]
+ else:
+ ret = out
+ if code != 0 and not returnCode:
+ debug('error output:', out)
+ debug('prog:', prog)
+ raise ProgramError(progStr, out)
+# debug('output:', out.replace('\0', '\n'))
+ return ret
+
+# Code for computing common ancestors
+# -----------------------------------
+
+currentId = 0
+def getUniqueId():
+ global currentId
+ currentId += 1
+ return currentId
+
+# The 'virtual' commit objects have SHAs which are integers
+shaRE = re.compile('^[0-9a-f]{40}$')
+def isSha(obj):
+ return (type(obj) is str and bool(shaRE.match(obj))) or \
+ (type(obj) is int and obj >= 1)
+
+class Commit(object):
+ __slots__ = ['parents', 'firstLineMsg', 'children', '_tree', 'sha',
+ 'virtual']
+
+ def __init__(self, sha, parents, tree=None):
+ self.parents = parents
+ self.firstLineMsg = None
+ self.children = []
+
+ if tree:
+ tree = tree.rstrip()
+ assert(isSha(tree))
+ self._tree = tree
+
+ if not sha:
+ self.sha = getUniqueId()
+ self.virtual = True
+ self.firstLineMsg = 'virtual commit'
+ assert(isSha(tree))
+ else:
+ self.virtual = False
+ self.sha = sha.rstrip()
+ assert(isSha(self.sha))
+
+ def tree(self):
+ self.getInfo()
+ assert(self._tree != None)
+ return self._tree
+
+ def shortInfo(self):
+ self.getInfo()
+ return str(self.sha) + ' ' + self.firstLineMsg
+
+ def __str__(self):
+ return self.shortInfo()
+
+ def getInfo(self):
+ if self.virtual or self.firstLineMsg != None:
+ return
+ else:
+ info = runProgram(['git-cat-file', 'commit', self.sha])
+ info = info.split('\n')
+ msg = False
+ for l in info:
+ if msg:
+ self.firstLineMsg = l
+ break
+ else:
+ if l.startswith('tree'):
+ self._tree = l[5:].rstrip()
+ elif l == '':
+ msg = True
+
+class Graph:
+ def __init__(self):
+ self.commits = []
+ self.shaMap = {}
+
+ def addNode(self, node):
+ assert(isinstance(node, Commit))
+ self.shaMap[node.sha] = node
+ self.commits.append(node)
+ for p in node.parents:
+ p.children.append(node)
+ return node
+
+ def reachableNodes(self, n1, n2):
+ res = {}
+ def traverse(n):
+ res[n] = True
+ for p in n.parents:
+ traverse(p)
+
+ traverse(n1)
+ traverse(n2)
+ return res
+
+ def fixParents(self, node):
+ for x in range(0, len(node.parents)):
+ node.parents[x] = self.shaMap[node.parents[x]]
+
+# addDebug('buildGraph')
+def buildGraph(heads):
+ debug('buildGraph heads:', heads)
+ for h in heads:
+ assert(isSha(h))
+
+ g = Graph()
+
+ out = runProgram(['git-rev-list', '--parents'] + heads)
+ for l in out.split('\n'):
+ if l == '':
+ continue
+ shas = l.split(' ')
+
+ # This is a hack, we temporarily use the 'parents' attribute
+ # to contain a list of SHA1:s. They are later replaced by proper
+ # Commit objects.
+ c = Commit(shas[0], shas[1:])
+
+ g.commits.append(c)
+ g.shaMap[c.sha] = c
+
+ for c in g.commits:
+ g.fixParents(c)
+
+ for c in g.commits:
+ for p in c.parents:
+ p.children.append(c)
+ return g
+
+# Write the empty tree to the object database and return its SHA1
+def writeEmptyTree():
+ tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index'
+ def delTmpIndex():
+ try:
+ os.unlink(tmpIndex)
+ except OSError:
+ pass
+ delTmpIndex()
+ newEnv = os.environ.copy()
+ newEnv['GIT_INDEX_FILE'] = tmpIndex
+ res = runProgram(['git-write-tree'], env=newEnv).rstrip()
+ delTmpIndex()
+ return res
+
+def addCommonRoot(graph):
+ roots = []
+ for c in graph.commits:
+ if len(c.parents) == 0:
+ roots.append(c)
+
+ superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
+ graph.addNode(superRoot)
+ for r in roots:
+ r.parents = [superRoot]
+ superRoot.children = roots
+ return superRoot
+
+def getCommonAncestors(graph, commit1, commit2):
+ '''Find the common ancestors for commit1 and commit2'''
+ assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
+
+ def traverse(start, set):
+ stack = [start]
+ while len(stack) > 0:
+ el = stack.pop()
+ set.add(el)
+ for p in el.parents:
+ if p not in set:
+ stack.append(p)
+ h1Set = Set()
+ h2Set = Set()
+ traverse(commit1, h1Set)
+ traverse(commit2, h2Set)
+ shared = h1Set.intersection(h2Set)
+
+ if len(shared) == 0:
+ shared = [addCommonRoot(graph)]
+
+ res = Set()
+
+ for s in shared:
+ if len([c for c in s.children if c in shared]) == 0:
+ res.add(s)
+ return list(res)