/********************************************************************
created: 2014/04/29 11:35
filename: nth_element.cpp
author: Justme0 (http://blog.****.net/justme0)
purpose: simplified re-implementation of the STL nth_element algorithm
*********************************************************************/ #include <cstdio>
#include <cstdlib>
#include <cstring> typedef int Type; template <class T>
inline T * copy_backward(const T *first, const T *last, T *result) {
const ptrdiff_t num = last - first;
memmove(result - num, first, sizeof(T) * num);
return result - num;
} /*
** 将 value 插到 last 前面(不包括 last)的区间
** 此函数保证不会越界(主调函数已判断),因此以 unguarded_ 开头
*/
template <class RandomAccessIterator, class T>
void unguarded_linear_insert(RandomAccessIterator last, T value) {
RandomAccessIterator next = last;
--next;
while(value < *next) {
*last = *next;
last = next;
--next;
}
*last = value;
} /*
** 将 last 处的元素插到[first, last)的有序区间
*/
template <class RandomAccessIterator>
void linear_insert(RandomAccessIterator first, RandomAccessIterator last) {
Type value = *last;
if (value < *first) { // 若尾比头小,就将整个区间一次性向后移动一个位置
copy_backward(first, last, last + );
*first = value;
} else {
unguarded_linear_insert(last, value);
}
} template <class RandomAccessIterator>
void insertion_sort(RandomAccessIterator first, RandomAccessIterator last) {
if (first == last) {
return ;
} for (RandomAccessIterator ite = first + ; ite != last; ++ite) {
linear_insert(first, ite);
}
} template <class T>
inline const T & median(const T &a, const T &b, const T&c) {
if (a < b) {
if (b < c) {
return b;
} else if (a < c) {
return c;
} else {
return a;
}
} else if (a < c) {
return a;
} else if (b < c) {
return c;
} else {
return b;
}
} template <class ForwardIterator1, class ForwardIterator2>
inline void iter_swap(ForwardIterator1 a, ForwardIterator2 b) {
Type tmp = *a; // 源码中的 T 由迭代器的 traits 得来,这里简化了
*a = *b;
*b = tmp;
} /*
** 设返回值为 mid,则[first, mid)中迭代器指向的值小于等于 pivot;
** [mid, last)中迭代器指向的值大于等于 pivot
** 这是 STL 内置的算法,会用于 nth_element, sort 中
** 笔者很困惑为什么不用 partition
*/
template <class RandomAccessIterator, class T>
RandomAccessIterator unguarded_partition(RandomAccessIterator first, RandomAccessIterator last, T pivot) {
while(true) {
while (*first < pivot) {
++first;
}
--last;
while (pivot < *last) { // 若 std::partition 的 pred 是 IsLess(pivot),这里将是小于等于
--last;
}
if (!(first < last)) { // 小于操作只适用于 random access iterator
return first;
}
iter_swap(first, last);
++first;
}
} template <class RandomAccessIterator>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last) {
while (last - first > ) {
RandomAccessIterator cut = unguarded_partition(first, last, Type(median(
*first,
*(first + (last - first) / ),
*(last - ))));
if (cut <= nth) {
first = cut;
} else {
last = cut;
}
}
insertion_sort(first, last);
} int main(int argc, char **argv) {
int arr[] = {, , , , , , , , , , };
int size = sizeof arr / sizeof *arr; nth_element(arr, arr + , arr + size); for (int i = ; i < size; ++i) {
printf("%d ", arr[i]); // 20 12 22 17 17 22 23 30 30 33 40
}
printf("\n"); system("PAUSE");
return ;
}