// KBest - Data structure for selecting k best objects from a set of n
// David Eppstein, UC Irvine, 5 June 1999
//
// We maintain a value k, and a list of Objects, subject to the following two operations:
//
// put(Object,key): Add the item to the list, with the given key
//
// Object get(): If the list is empty or k=0, return null. Otherwise, find one of the k
//    smallest keys in the list, remove it, decrement k, and return the object we found.
//
// The amortized expected time per operation is O(1), and the space used is O(k).

/**
 * Selects up to k objects with the smallest integer keys from a stream of
 * {@link #put} calls, handing them out one at a time through {@link #get}.
 *
 * Storage layout of the parallel arrays objs[] / keys[]:
 *   [0, nSmall - nEq)       entries known to be among the current k best
 *   [nSmall - nEq, nSmall)  entries known to be among the k best whose key equals smallPivot
 *   [nSmall, nObjs)         entries not yet classified, all with key >= smallPivot
 * Whenever nSmall == k, the unclassified region is empty (nObjs == k).
 */
public class KBest {
	// Keep 3k objects in working set.
	// Higher load factors would make the calls from put() to cut() more likely to succeed
	// and have a bigger expected number of items cut.  However this would not help much,
	// because the expected fraction of items cut per call converges to 1/2, and anyway
	// the unsuccessful cuts increase nSmall, either speeding future cuts or leading to
	// a future free cut depending on whether new puts are greater or less than smallPivot.
	private static final int LOAD_FACTOR = 3;

	// source of random pivots for cut()
	private static final java.util.Random rand = new java.util.Random();

	private Object[] objs;	// stored objects; set to null once all k items have been returned
	private int[] keys;		// keys[i] is the key of objs[i]
	private int k;			// how many more objects get() may still return
	private int nObjs;		// how many entries of objs[] are valid
	private int nSmall;		// how many of them are known to be among current k best
	private int smallPivot;	// key value >= all nSmall values and <= all other values
	private int nEq;		// how many nSmall values are counted as equal to smallPivot

	/**
	 * Creates a selector that will return at most kk objects.
	 *
	 * @param kk number of best objects wanted; the working arrays hold LOAD_FACTOR * kk
	 *           entries, so a negative value makes the array allocation throw
	 *           NegativeArraySizeException
	 */
	public KBest(int kk) {
		int aSize = LOAD_FACTOR * kk;
		k = kk;
		objs = new Object[aSize];
		keys = new int[aSize];
		nObjs = nSmall = nEq = smallPivot = 0;
	}

	/**
	 * Offers an object with the given key.  The call is ignored when o is null or when
	 * all k objects have already been returned by get().  Objects that can no longer be
	 * among the k best are dropped.
	 *
	 * @param o   the object to store
	 * @param key its key; smaller keys are better
	 */
	public final void put(Object o, int key) {
		if (k <= 0 || o == null) return;		// stop putting if we've found all k items already
		while (nObjs == objs.length) cut();	// make room! make room!

		if (nSmall >= k) {						// already have enough small values?
			if (key >= smallPivot) return;	// yes, and this one is too big
			objs[nSmall - nEq] = o;
			keys[nSmall - nEq] = key;
			if (nEq > 0) nEq--;					// new key displaces one entry equal to smallPivot
			else {
				nObjs++;							// nSmall overflowed k, reset to zero so that
				nSmall = nEq = 0;					// the next cut() reclassifies all k+1 entries
			}
			return;
		}

		if (nSmall == 0 || key > smallPivot) {
			objs[nObjs] = o;						// big key, add to end of object list
			keys[nObjs] = key;
			nObjs++;
		} else {
			objs[nObjs] = objs[nSmall];		// small key, add between small and eq
			keys[nObjs] = keys[nSmall];			// after shifting what was there out of the way
			objs[nSmall] = objs[nSmall - nEq];
			keys[nSmall] = keys[nSmall - nEq];
			objs[nSmall - nEq] = o;
			keys[nSmall - nEq] = key;
			nSmall++;
			if (key == smallPivot) nEq++;
			// If nSmall grew to k, flush the unneeded big objs: exactly k entries remain
			// valid.  Otherwise the list simply grew by one.  (Incrementing nObjs after
			// the flush would expose a stale or never-written slot as a valid entry.)
			if (nSmall == k) nObjs = k;
			else nObjs++;
		}
	}

	/**
	 * Removes and returns one of the k best objects currently stored, and decrements k.
	 *
	 * @return one of the stored objects with the k smallest keys, or null if nothing is
	 *         stored or all k objects have already been returned
	 */
	public final Object get() {
		if (nObjs <= 0) return null;
		while (nSmall <= 0) cut();	// find someone small
		Object retval = objs[--nSmall];
		if (nEq > 0) nEq--;			// entries equal to smallPivot sit at the end of the small region
		if (--k <= 0) {				// last one?
			nObjs = nSmall = 0;		// yes, disable further gets
			objs = null;				// and free up array for garbage collector
			keys = null;
		} else {
			objs[nSmall] = objs[--nObjs]; // move last obj to fill gap
			keys[nSmall] = keys[nObjs];
		}
		return retval;
	}

	/**
	 * Does all the work: pivots around a random unclassified element, which either
	 * reduces nObjs (discarding entries too big to be among the k best) or increases
	 * nSmall (classifying more entries as being among the k best).
	 */
	private void cut() {
		if (nObjs <= k) {		// so few items left that we can avoid a cut?
			nSmall = nObjs;
			nEq = 0;
			smallPivot = Integer.MAX_VALUE;	// yes, but force a cut later if we get another put
			return;
		}

		// choose a pivot randomly among the non-nSmall values
		// (nObjs > k >= nSmall here, so there is at least one such value)
		int nUnknowns = nObjs - nSmall;
		int pivot = keys[nSmall + rand.nextInt(nUnknowns)];

		// partition array according to pivot and analyze cases according to partition size
		int firstEqual = split(nSmall, pivot);
		if (firstEqual >= k) nObjs = firstEqual; // easy case: >=k items <pivot, cut all >=pivot
		else {									// harder case: <k items <pivot, need to adjust nSmall
			// Continue to partition, now separating keys == pivot from keys > pivot.
			// pivot+1 would wrap around to Integer.MIN_VALUE when pivot is the largest
			// int; in that case no key can be greater than the pivot, so every
			// remaining entry is <= pivot.
			int firstGreater = (pivot == Integer.MAX_VALUE) ? nObjs : split(firstEqual, pivot + 1);
			if (firstGreater >= k) {
				nObjs = nSmall = k;			// if kth item is =pivot, both cut and adjust nSmall
				nEq = k - firstEqual;
			}
			else {
				nSmall = firstGreater;		// <k items <=pivot, only adjust nSmall
				nEq = firstGreater - firstEqual;
			}
			smallPivot = pivot;
		}
	}

	/**
	 * Does a quicksort-like partition of the entries in [start, nObjs) so that all keys
	 * less than pivot come before all keys greater than or equal to pivot.
	 *
	 * @param start index of the first entry to partition
	 * @param pivot key value to partition around
	 * @return index of the first entry whose key is >= pivot (nObjs if there is none)
	 */
	private int split(int start, int pivot) {
		int end = nObjs;
		while (start < end) {
			while (start < end && keys[start] < pivot) start++;
			while (start < end && keys[end-1] >= pivot) end--;
			if (start < end) {				// out-of-place pair found: swap objects and keys
				Object tempObj = objs[start];
				objs[start] = objs[end-1];
				objs[end-1] = tempObj;
				int tempKey = keys[start];
				keys[start] = keys[end-1];
				keys[end-1] = tempKey;
				start++;
				end--;
			}
		}
		return start;
	}
}
