"""Permutations.py Efficient recursive version of the Steinhaus-Johnson-Trotter algorithm for listing all permutations of a set of items. D. Eppstein, October 2011. NOTE: The generators in this module work by making a sequence of small changes to a permutation stored in a Python list object, and then yielding that list. That means that the objects that they generate (the list objects) must not be modified by their callers, because that would break subsequent generation steps. And it also means that each permutation yielded by the generator only has its correct value up to the time the generator is called again, after which this value will change. If you want to keep a persistent copy of a permutation, or change it, you need to copy it into a separate object. E.g. RIGHT: [list(p) for p in SteinhausJohnsonTrotter(n)] # makes a list of all the permutations of order n WRONG: list(SteinhausJohnsonTrotter(n)) [p for p in SteinhausJohnsonTrotter(n)] # either way, these make a list of length n! all of whose # elements point to the same list object as each other The Steinhaus-Johnson-Trotter implementation given here sets up a sequence of recursive simple generators, each taking constant space, for a total space of O(n), where n is the number of items being permuted. The number of recursive calls to generate a swap that moves the item originally in position n of the input permutation is O(n-i+1), so all but a 1/n fraction of the swaps take no recursion and the rest always take O(n) time, for an average time per swap of O(1) and an average time per generated permutation of O(1). The other generators are similar. """ import unittest # 2to3 compatibility try: xrange except: xrange = range def PlainChanges(n): """Generate the swaps for the Steinhaus-Johnson-Trotter algorithm.""" if n < 1: return up = xrange(n-1) down = xrange(n-2,-1,-1) recur = PlainChanges(n-1) try: while True: for x in down: yield x yield next(recur) + 1 for x in up: yield x yield next(recur) except StopIteration: pass def SteinhausJohnsonTrotter(x): """Generate all permutations of x. If x is a number rather than an iterable, we generate the permutations of range(x).""" # set up the permutation and its length try: perm = list(x) except: perm = list(range(x)) n = len(perm) # run through the sequence of swaps yield perm for x in PlainChanges(n): perm[x],perm[x+1] = perm[x+1],perm[x] yield perm def DoublePlainChanges(n): """Generate the swaps for double permutations.""" if n < 1: return up = xrange(1,2*n-1) down = xrange(2*n-2,0,-1) recur = DoublePlainChanges(n-1) try: while True: for x in up: yield x yield next(recur) + 1 for x in down: yield x yield next(recur) + 2 except StopIteration: pass def DoubleSteinhausJohnsonTrotter(n): """Generate all double permutations of the range 0 through n-1""" perm = [] for i in range(n): perm += [i,i] # run through the sequence of swaps yield perm for x in DoublePlainChanges(n): perm[x],perm[x+1] = perm[x+1],perm[x] yield perm def StirlingChanges(n): """Variant Steinhaus-Johnson-Trotter for Stirling permutations. A Stirling permutation is a double permutation in which each pair of values has only larger values between them. The algorithm is to sweep the largest pair of values through the sequence of smaller values, recursing when it reaches the ends of the sequence, exactly as in the standard Steinhaus-Johnson-Trotter algorithm. However, it differs in swapping items two positions apart instead of adjacent items.""" if n <= 1: return up = xrange(2*n-2) down = xrange(2*n-3,-1,-1) recur = StirlingChanges(n-1) try: while True: for x in down: yield x yield next(recur) + 2 for x in up: yield x yield next(recur) except StopIteration: pass def StirlingPermutations(n): """Generate all Stirling permutations of order n.""" perm = [] for i in range(n): perm += [i,i] # run through the sequence of swaps yield perm for x in StirlingChanges(n): perm[x],perm[x+2] = perm[x+2],perm[x] yield perm def InvolutionChanges(n): """Generate change sequence for involutions on n items. Uses a variation of the Steinhaus-Johnson-Trotter idea, in which we first recurse for n-1, generating involutions in which the last item is fixed, and then we the match for the last item back and forth over a recursively generated sequence for n-2.""" if n <= 3: for c in [[],[],[0],[0,1,0]][n]: yield c return for c in InvolutionChanges(n-1): yield c yield n-2 for i in range(n-4,-1,-1): yield i ic = InvolutionChanges(n-2) up = range(0,n-2) down = range(n-3,-1,-1) try: while True: yield next(ic) + 1 for i in up: yield i yield next(ic) for i in down: yield i except StopIteration: yield n-4 def Involutions(n): """Generate involutions on n items. The first involution is always the one in which all items are mapped to themselves, and the last involution is the one in which only the final two items are swapped. Each two involutions differ by a change that either adds or removes an adjacent pair of swapped items, moves a swap target by one, or swaps two adjacent swap targets.""" p = list(range(n)) yield p for c in InvolutionChanges(n): x,y = p[c],p[c+1] # current partners of c and c+1 if x == c and y != c+1: x = c+1 if x != c and y == c+1: y = c p[x],p[y],p[c],p[c+1] = c+1, c, y, x # swap partners yield p # If run standalone, perform unit tests class PermutationTest(unittest.TestCase): def testChanges(self): """Do we get the expected sequence of changes for n=3?""" self.assertEqual(list(PlainChanges(3)),[1,0,1,0,1]) def testLengths(self): """Are the lengths of the generated sequences factorial?""" f = 1 for i in range(2,7): f *= i self.assertEqual(f,len(list(SteinhausJohnsonTrotter(i)))) def testDistinct(self): """Are all permutations in the sequence different from each other?""" for i in range(2,7): s = set() n = 0 for x in SteinhausJohnsonTrotter(i): s.add(tuple(x)) n += 1 self.assertEqual(len(s),n) def testAdjacent(self): """Do consecutive permutations in the sequence differ by a swap?""" for i in range(2,7): last = None for p in SteinhausJohnsonTrotter(i): if last: diffs = [j for j in range(i) if p[j] != last[j]] self.assertEqual(len(diffs),2) self.assertEqual(p[diffs[0]],last[diffs[1]]) self.assertEqual(p[diffs[1]],last[diffs[0]]) last = list(p) def testListInput(self): """If given a list as input, is it the first output?""" for L in ([1,3,5,7], list('zyx'), [], [[]], list(range(20))): self.assertEqual(L,next(SteinhausJohnsonTrotter(L))) def testInvolutions(self): """Are these involutions and do we have the right number of them?""" telephone = [1,1,2,4,10,26,76,232,764] for n in range(len(telephone)): count = 0 sorted = list(range(n)) invs = set() for p in Involutions(n): self.assertEqual([p[i] for i in p],sorted) invs.add(tuple(p)) count += 1 self.assertEqual(len(invs),count) self.assertEqual(len(invs),telephone[n]) if __name__ == "__main__": unittest.main()