// KBest - Data structure for selecting k best objects from a set of n
// David Eppstein, UC Irvine, 5 June 1999
//
// We maintain a value k, and a list of Objects, subject to the following two operations:
//
// put(Object,key): Add the item to the list, with the given key.
//
// Object get(): If the list is empty or k=0, return null.  Otherwise, find one of the k
//    smallest keys in the list, remove it, decrement k, and return the object we found.
//
// The amortized expected time per operation is O(1), and the space used is O(k).

public class KBest {

    // Keep 3k objects in working set.
    // Higher load factors would make the calls from put() to cut() more likely to succeed
    // and have a bigger expected number of items cut.  However this would not help much,
    // because the expected fraction of items cut per call converges to 1/2, and anyway
    // the unsuccessful cuts increase nSmall, either speeding future cuts or leading to
    // a future free cut depending on whether new puts are greater or less than smallPivot.
    private static final int LOAD_FACTOR = 3;

    private static final java.util.Random rand = new java.util.Random();

    private Object[] objs;      // stored objects; released once k reaches 0
    private int[] keys;         // keys[i] is the key of objs[i]
    private int k;              // how many more objects get() may still return
    private int nObjs;          // how many entries of objs[] are valid
    private int nSmall;         // how many of them are known to be among current k best
    private int smallPivot;     // key value >= all nSmall values and <= all other values
    private int nEq;            // how many nSmall values are equal to smallPivot;
                                // they occupy the slots [nSmall-nEq, nSmall)

    // Create a selector that will hand out at most kk objects.
    // A non-positive kk yields a selector that ignores every put() and always returns null.
    // Throws IllegalArgumentException if the working set of LOAD_FACTOR*kk slots would not
    // fit in a Java array.
    KBest(int kk) {
        long aSize = (long) LOAD_FACTOR * Math.max(kk, 0);
        if (aSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("k too large for working set: " + kk);
        }
        k = kk;
        objs = new Object[(int) aSize];
        keys = new int[(int) aSize];
        nObjs = nSmall = nEq = smallPivot = 0;
    }

    // Add object o with the given key.  Null objects are ignored, as is every put made
    // after all k items have been handed out.
    public final void put(Object o, int key) {
        if (k <= 0 || o == null) {
            return;                         // stop putting if we've found all k items already
        }
        while (nObjs == objs.length) {
            cut();                          // make room! make room!
        }

        if (nSmall >= k) {                  // already have enough small values?
            if (key >= smallPivot) {
                return;                     // yes, and this one is too big
            }
            // Here nObjs == nSmall == k, so slot nSmall-nEq is either the first pivot-equal
            // entry (which the new, smaller key displaces) or the first free slot.
            int slot = nSmall - nEq;
            objs[slot] = o;
            keys[slot] = key;
            if (nEq > 0) {
                nEq--;                      // made room in nSmall by dropping one from nEq
            } else {
                nObjs++;                    // nSmall overflowed k, reset to zero
                nSmall = nEq = 0;
            }
            return;
        }

        if (nSmall == 0 || key > smallPivot) {
            objs[nObjs] = o;                // big key, add to end of object list
            keys[nObjs] = key;
            nObjs++;
            return;
        }

        // small key, add between small and eq after shifting what was there out of the way
        int slot = nSmall - nEq;
        objs[nObjs] = objs[nSmall];
        keys[nObjs] = keys[nSmall];
        objs[nSmall] = objs[slot];
        keys[nSmall] = keys[slot];
        objs[slot] = o;
        keys[slot] = key;
        nSmall++;
        if (key == smallPivot) {
            nEq++;
        }
        if (nSmall == k) {
            // If nSmall grows to k, flush unneeded big objs: exactly the k small entries
            // stay valid.  (Counting one more slot here would expose a stale entry.)
            nObjs = k;
        } else {
            nObjs++;
        }
    }

    // Remove and return one of the k smallest-keyed objects, or null if nothing is stored
    // or all k objects have already been returned.
    public final Object get() {
        if (nObjs <= 0) {
            return null;
        }
        while (nSmall <= 0) {
            cut();                          // find someone small
        }
        Object retval = objs[--nSmall];
        if (nEq > 0) {
            nEq--;
        }
        if (--k <= 0) {                     // last one?
            nObjs = nSmall = 0;             // yes, disable further gets
            objs = null;                    // and free up arrays for garbage collector
            keys = null;
        } else {
            int last = --nObjs;             // move last obj to fill gap
            objs[nSmall] = objs[last];
            keys[nSmall] = keys[last];
            objs[last] = null;              // do not keep the vacated slot reachable
        }
        return retval;
    }

    // do all the work: pivot around a random element, improving either nObjs or nSmall
    private void cut() {
        if (nObjs <= k) {                   // so few items left that we can avoid a cut?
            nSmall = nObjs;
            nEq = 0;
            smallPivot = Integer.MAX_VALUE; // yes, but force a cut later if we get another put
            return;
        }

        // choose a pivot randomly among the non-nSmall values
        // (nObjs > k >= nSmall here, so there is at least one such value)
        int nUnknowns = nObjs - nSmall;
        int pivot = keys[nSmall + rand.nextInt(nUnknowns)];

        // partition array according to pivot and analyze cases according to partition size
        int firstEqual = split(nSmall, pivot);
        if (firstEqual >= k) {
            nObjs = firstEqual;             // easy case: >=k items <pivot, just cut
            return;
        }

        // harder case: <k items <pivot.  The bound pivot+1 is formed as a long so that
        // pivot == Integer.MAX_VALUE does not wrap around.
        int firstGreater = split(firstEqual, (long) pivot + 1);
        smallPivot = pivot;
        if (firstGreater >= k) {
            nObjs = nSmall = k;             // if kth item is =pivot, both cut and adjust nSmall
            nEq = k - firstEqual;
        } else {
            nSmall = firstGreater;          // <k items <=pivot, just adjust nSmall
            nEq = firstGreater - firstEqual;
        }
    }

    // Rearrange entries [start, nObjs) so that keys < bound precede keys >= bound.
    // Return the index of the first entry whose key is >= bound.
    private int split(int start, long bound) {
        int end = nObjs;
        while (start < end) {
            while (start < end && keys[start] < bound) {
                start++;
            }
            while (start < end && keys[end - 1] >= bound) {
                end--;
            }
            if (start < end) {
                Object tempObj = objs[start];
                objs[start] = objs[end - 1];
                objs[end - 1] = tempObj;
                int tempKey = keys[start];
                keys[start] = keys[end - 1];
                keys[end - 1] = tempKey;
                start++;
                end--;
            }
        }
        return start;
    }
}