"""IntegerPartitions.py

Generate and manipulate partitions of integers into sums of integers.

D. Eppstein, August 2005.
"""

import unittest


def mckay(n):
    """
    Integer partitions of n, in reverse lexicographic order.

    Note that the generated output consists of the same list object,
    repeated the correct number of times; the caller must leave this
    list unchanged, and must make a copy of any partition that is
    intended to last longer than the next call into the generator.

    For n == 0 the single empty partition [] is generated; for n < 0
    nothing is generated.

    The algorithm follows Knuth v4 fasc3 p38 in rough outline.
    """
    if n == 0:
        yield []
    if n <= 0:
        return
    partition = [n]
    # Index of the last part greater than 1, or -1 if every part is 1.
    last_nonunit = (n > 1) - 1
    while True:
        yield partition
        if last_nonunit < 0:
            return
        if partition[last_nonunit] == 2:
            # Cheap case: split the last 2 into 1 + 1.
            partition[last_nonunit] = 1
            partition.append(1)
            last_nonunit -= 1
            continue
        # Decrease the last nonunit part by one, and redistribute its value
        # together with all trailing 1's into as many parts of size
        # `replacement` as fit, plus one smaller remainder part if needed.
        replacement = partition[last_nonunit] - 1
        total_replaced = replacement + len(partition) - last_nonunit
        reps, rest = divmod(total_replaced, replacement)
        partition[last_nonunit:] = reps * [replacement]
        if rest:
            partition.append(rest)
        last_nonunit = len(partition) - (partition[-1] == 1) - 1


def revlex_partitions(n):
    """
    Integer partitions of n, in reverse lexicographic order.

    The output and asymptotic runtime are the same as mckay(n),
    but the algorithm is different: it involves no division,
    and is simpler than mckay, but uses O(n) extra space for
    a recursive call stack (so very large n may exceed Python's
    recursion limit).

    As with mckay, the same list object is yielded each time; copy any
    partition that must outlive the next call into the generator.
    """
    if n == 0:
        yield []
    if n <= 0:
        return
    for p in revlex_partitions(n - 1):
        # Incrementing the last part keeps p nonincreasing only when that
        # part is the only one or is strictly smaller than its predecessor.
        if len(p) == 1 or (len(p) > 1 and p[-1] < p[-2]):
            p[-1] += 1
            yield p
            p[-1] -= 1  # restore p for the recursive generator
        p.append(1)
        yield p
        p.pop()  # restore p for the recursive generator


def lex_partitions(n):
    """
    Similar to revlex_partitions, but in lexicographic order.

    The same list object is yielded each time; copy any partition that
    must outlive the next call into the generator.
    """
    if n == 0:
        yield []
    if n <= 0:
        return
    for p in lex_partitions(n - 1):
        p.append(1)
        yield p
        p.pop()  # restore p for the recursive generator
        # Same nonincreasing-order test as in revlex_partitions.
        if len(p) == 1 or (len(p) > 1 and p[-1] < p[-2]):
            p[-1] += 1
            yield p
            p[-1] -= 1  # restore p for the recursive generator


partitions = revlex_partitions  # default partition generating algorithm


def binary_partitions(n):
    """
    Generate partitions of n into powers of two, in revlex order.

    Knuth exercise 7.2.1.4.64.
    The average time per output is constant.
    But this doesn't really solve the exercise, because it isn't loopless...

    As with the other generators here, the same list object is yielded
    each time; copy any partition that must outlive the next call into
    the generator. For n == 0 the single empty partition [] is generated;
    for n < 0 nothing is generated.
    """
    if n < 0:
        return

    # Generate the binary representation of n, largest power first.
    # (Local names chosen so as not to shadow the builtins pow() and sum().)
    power = 1
    total = 0
    while power <= n:
        power <<= 1
    partition = []
    while power:
        if total + power <= n:
            partition.append(power)
            total += power
        power >>= 1

    # Find all partitions of numbers up to n into powers of two > 1,
    # in revlex order, by repeatedly splitting the smallest nonunit power,
    # and replacing the following sequence of 1's by the first revlex
    # partition with maximum power less than the result of the split.

    # Time analysis:
    #
    # Each outer iteration increases len(partition) by at most one
    # (only if the power being split is a 2) and each inner iteration
    # in which some ones are replaced by x decreases len(partition),
    # so the number of those inner iterations is less than one per
    # output.
    #
    # Each time a power 2^k is split, it creates two or more 2^{k-1}'s,
    # all of which must eventually be split as well.  So, if S_k denotes
    # the number of times a 2^k is split, and X denotes the total
    # number of outputs generated, then S_k <= X/2^{k-1}.
    # On an outer iteration in which 2^k is split, there will be k
    # inner iterations in which x is halved, so the total number
    # of such inner iterations is <= sum_k k*X/2^{k-1} = O(X).
    #
    # Therefore the overall average time per output is constant.

    # Index of the last part greater than 1 (-1 if none); when n is odd the
    # binary representation ends in a 1, which is skipped.
    last_nonunit = len(partition) - 1 - (n & 1)
    while True:
        yield partition
        if last_nonunit < 0:
            return
        if partition[last_nonunit] == 2:
            # Cheap case: split the last 2 into 1 + 1.
            partition[last_nonunit] = 1
            partition.append(1)
            last_nonunit -= 1
            continue
        # Split the last nonunit power 2x into x + x.  A spare 1 is appended
        # so that the slot right after it (a 1, or the new spare) can hold
        # the second x without changing the total.
        x = partition[last_nonunit] >> 1  # make the split!
        partition.append(1)
        partition[last_nonunit] = partition[last_nonunit + 1] = x
        last_nonunit += 1
        # Greedily regroup the trailing 1's: replace x of them by one part x
        # while enough remain, otherwise try the next smaller power.
        while x > 1:
            if len(partition) - last_nonunit - 1 >= x:
                del partition[-x + 1:]
                last_nonunit += 1
                partition[last_nonunit] = x
            else:
                x >>= 1


def fixed_length_partitions(n, L):
    """
    Integer partitions of n into L parts, in colex order.

    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.

    For L >= 2 the same list object is yielded each time; copy any
    partition that must outlive the next call into the generator.
    Nothing is generated when no such partition exists, including
    when L < 0.
    """
    # guard against special cases
    if L < 0:
        # Previously fell through and produced a bogus one-part "partition"
        # followed by an IndexError.
        return
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    partition = [n - L + 1] + (L - 1) * [1]
    while True:
        yield partition
        # Easy step: move one unit from the first part to the second,
        # as long as the parts stay nonincreasing.
        if partition[0] - 1 > partition[1]:
            partition[0] -= 1
            partition[1] += 1
            continue
        # Otherwise find the smallest j >= 2 with partition[j] < partition[0] - 1.
        # s accumulates the sum of parts 0..j-1, minus the one unit that
        # will move into part j.
        j = 2
        s = partition[0] + partition[1] - 1
        while j < L and partition[j] >= partition[0] - 1:
            s += partition[j]
            j += 1
        if j >= L:
            return
        # Increase part j by one, make parts 1..j-1 equal to it, and give
        # whatever remains of s to part 0.
        partition[j] = x = partition[j] + 1
        j -= 1
        while j > 0:
            partition[j] = x
            s -= x
            j -= 1
        partition[0] = s


def conjugate(p):
    """
    Find the conjugate of a partition.
    E.g. len(p) = max(conjugate(p)) and vice versa.

    Assumes p is a list of positive integers in nonincreasing order (as
    produced by the generators in this module).  Returns a new list and
    does not modify p; the conjugate of [] is [].
    """
    result = []
    j = len(p)
    if j <= 0:
        return result
    while True:
        # j is the number of parts of p that are >= len(result) + 1, which
        # is the value of the next part of the conjugate.
        result.append(j)
        # Drop parts of p that are too small to reach the next conjugate part.
        while len(result) >= p[j - 1]:
            j -= 1
            if j == 0:
                return result


# If run standalone, perform unit tests

class PartitionTest(unittest.TestCase):
    counts = [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135]

    def testCounts(self):
        """Check that each generator has the right number of outputs."""
        for n in range(len(self.counts)):
            self.assertEqual(self.counts[n], len(list(mckay(n))))
            self.assertEqual(self.counts[n], len(list(lex_partitions(n))))
            self.assertEqual(self.counts[n], len(list(revlex_partitions(n))))

    def testSums(self):
        """Check that all outputs are partitions of the input."""
        for n in range(len(self.counts)):
            for p in mckay(n):
                self.assertEqual(n, sum(p))
            for p in revlex_partitions(n):
                self.assertEqual(n, sum(p))
            for p in lex_partitions(n):
                self.assertEqual(n, sum(p))

    def testRevLex(self):
        """Check that the revlex generators' outputs are in revlex order."""
        for n in range(len(self.counts)):
            last = [n + 1]
            for p in mckay(n):
                self.assertGreater(last, p)
                last = list(p)  # make less-mutable copy
            last = [n + 1]
            for p in revlex_partitions(n):
                self.assertGreater(last, p)
                last = list(p)  # make less-mutable copy

    def testLex(self):
        """Check that the lex generator's outputs are in lex order."""
        for n in range(1, len(self.counts)):
            last = []
            for p in lex_partitions(n):
                # assertLess reports both lists on failure.
                self.assertLess(last, p)
                last = list(p)  # make less-mutable copy

    def testRange(self):
        """Check that all numbers in output partitions are in range."""
        for n in range(len(self.counts)):
            for p in mckay(n):
                for x in p:
                    self.assertTrue(0 < x <= n)
            for p in lex_partitions(n):
                for x in p:
                    self.assertTrue(0 < x <= n)
            for p in revlex_partitions(n):
                for x in p:
                    self.assertTrue(0 < x <= n)

    def testFixedLength(self):
        """Check that the fixed length partition outputs are correct."""
        for n in range(len(self.counts)):
            pn = [list(p) for p in revlex_partitions(n)]
            pn.sort()
            np = 0
            for L in range(n + 1):
                pnL = [list(p) for p in fixed_length_partitions(n, L)]
                pnL.sort()
                np += len(pnL)
                self.assertEqual(pnL, [p for p in pn if len(p) == L])
            self.assertEqual(np, len(pn))

    def testFixedLengthNegative(self):
        """Check that a negative number of parts produces no partitions."""
        for n in range(-2, len(self.counts)):
            self.assertEqual(list(fixed_length_partitions(n, -1)), [])

    def testConjugatePartition(self):
        """Check that conjugating a partition forms another partition."""
        for n in range(len(self.counts)):
            for p in partitions(n):
                c = conjugate(p)
                for x in c:
                    self.assertTrue(0 < x <= n)
                self.assertEqual(sum(c), n)

    def testConjugateInvolution(self):
        """Check that double conjugation returns the same partition."""
        for n in range(len(self.counts)):
            for p in partitions(n):
                self.assertEqual(p, conjugate(conjugate(p)))

    def testConjugateMaxLen(self):
        """Check the max-length reversing property of conjugation."""
        for n in range(1, len(self.counts)):
            for p in partitions(n):
                self.assertEqual(len(p), max(conjugate(p)))

    def testBinary(self):
        """Test that the binary partitions are generated correctly."""
        for n in range(len(self.counts)):
            binaries = []
            for p in partitions(n):
                for x in p:
                    if x & (x - 1):
                        break
                else:
                    binaries.append(list(p))
            self.assertEqual(binaries, [list(p) for p in binary_partitions(n)])


if __name__ == "__main__":
    unittest.main()