Kth smallest element
Problem Statement:
Find the kth smallest element (k is 1-based) of an unsorted integer array.
Algorithm
1. We modify quick sort algorithm to get the kth smallest element.
2. Once quick_partition() has placed the pivot at its final index, compute the pivot's rank within the current subarray a[start...end] (rank = pivot - start + 1). If the rank equals k, the kth smallest element is a[pivot].
3. If k is less than the pivot's rank, continue on the left subarray a[start...pivot-1], since the kth smallest element lies there.
4. If k is greater than the pivot's rank, continue on the right subarray a[pivot+1...end] with k = k - rank (measuring k relative to the pivot), since the kth smallest element lies to the right of the pivot.
Code :
// using quick select
#include <bits/stdc++.h>
using namespace std;
/* Lomuto partition of a[start..end] using a[end] as the pivot.
   Every element that is <= the pivot is swapped toward the front, and the
   pivot is then dropped into the slot right after them.
   Afterwards a[start..p-1] <= a[p] < a[p+1..end], where p is the returned
   index (the pivot's final position). */
int quick_partition(int a[], int start, int end)
{
    const int pivot_value = a[end];
    // a[start..boundary-1] holds the elements seen so far that are <= pivot.
    int boundary = start;
    for (int scan = start; scan < end; ++scan)
    {
        if (a[scan] <= pivot_value)
        {
            std::swap(a[boundary], a[scan]);
            ++boundary;
        }
    }
    // a[boundary] is the first element greater than the pivot (or the pivot
    // itself), so the pivot belongs exactly here.
    std::swap(a[boundary], a[end]);
    return boundary;
}
/* Returns the kth smallest (1-based) element of a[start..end] using
   iterative quickselect. The array is reordered in place as a side effect.
   If k is outside 1..(end - start + 1) the search range empties and
   INT_MAX is returned as a "not found" sentinel. */
int quick_select_kth(int a[], int start, int end, int k)
{
    int lo = start;
    int hi = end;
    int rank = k; // 1-based rank of the target within a[lo..hi]
    while (lo <= hi)
    {
        const int p = quick_partition(a, lo, hi);
        const int left_count = p - lo + 1; // size of a[lo..p], pivot included
        if (rank < left_count)
        {
            // Target lies strictly left of the pivot.
            hi = p - 1;
            continue;
        }
        if (rank > left_count)
        {
            // Target lies right of the pivot; discount everything up to it.
            rank -= left_count;
            lo = p + 1;
            continue;
        }
        return a[p];
    }
    return INT_MAX;
}
/* Reads n, then n integers, then k, and prints the kth smallest element.
   Invalid input (non-numeric, n <= 0, or k outside 1..n) is reported on
   stderr with exit status 1. */
int main()
{
    int n = 0;
    int k = 0;
    cout << "Enter n: ";
    if (!(cin >> n) || n <= 0)
    {
        cerr << "n must be a positive integer\n";
        return 1;
    }
    // std::vector instead of `int a[n]`: variable-length arrays are a compiler
    // extension rather than standard C++, and they live on the stack.
    vector<int> a(n);
    for (int i = 0; i < n; i++)
    {
        if (!(cin >> a[i]))
        {
            cerr << "expected " << n << " integers\n";
            return 1;
        }
    }
    cout << "Enter k: ";
    // quick_select_kth reports an out-of-range k only by returning INT_MAX,
    // which cannot be told apart from a genuine INT_MAX element; reject it here.
    if (!(cin >> k) || k < 1 || k > n)
    {
        cerr << "k must be between 1 and n\n";
        return 1;
    }
    cout << "\n";
    int element = quick_select_kth(a.data(), 0, n - 1, k);
    cout << k << "th smallest = " << element << "\n";
    return 0;
}
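Complexity Analysis :
best / average case : \(O(n)\)
worst case : \(O(n^2)\), e.g. on already-sorted input. The last element is always chosen as the pivot, so each partition removes only one element from the search range.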
Algorithm
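1. This algorithm uses the median of medians as the pivot in quick_partition(). The range is split into groups of 5, and each group is sorted to find its median. quick_select is then called recursively on those group medians to pick their median, which becomes the pivot. Everything else is the same as above.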
Code :
// kth smallest using median of medians as pivot
#include <bits/stdc++.h>
using namespace std;
int find_median_of_medians(int a[], int start, int end);
int quick_partition(int a[], int start, int end);
int quick_select_kth(int a[], int start, int end, int k);
/* Reads n, then n integers, then k, and prints the kth smallest element.
   Invalid input (non-numeric, n <= 0, or k outside 1..n) is reported on
   stderr with exit status 1. */
int main()
{
    int n = 0;
    int k = 0;
    cout << "Enter n: ";
    if (!(cin >> n) || n <= 0)
    {
        cerr << "n must be a positive integer\n";
        return 1;
    }
    // std::vector instead of `int a[n]`: variable-length arrays are a compiler
    // extension rather than standard C++, and they live on the stack.
    vector<int> a(n);
    for (int i = 0; i < n; i++)
    {
        if (!(cin >> a[i]))
        {
            cerr << "expected " << n << " integers\n";
            return 1;
        }
    }
    cout << "Enter k: ";
    // quick_select_kth reports an out-of-range k only by returning INT_MAX,
    // which cannot be told apart from a genuine INT_MAX element; reject it here.
    if (!(cin >> k) || k < 1 || k > n)
    {
        cerr << "k must be between 1 and n\n";
        return 1;
    }
    cout << "\n";
    int element = quick_select_kth(a.data(), 0, n - 1, k);
    cout << k << "th smallest = " << element << "\n";
    return 0;
}
/* Returns the median of the group medians of a[start..end], where the range
   is split into consecutive groups of 5 (the last group may be smaller).
   This value is used as the partition pivot.
   Side effect: each group inside a[start..end] is sorted in place.
   Expects start <= end. */
int find_median_of_medians(int a[], int start, int end)
{
    const int n = end - start + 1;
    // One median per group: integer ceil-division sizes the buffer exactly,
    // and a vector avoids a non-standard, stack-allocated VLA.
    vector<int> medians;
    medians.reserve((n + 4) / 5);
    for (int group_start = start; group_start <= end; group_start += 5)
    {
        const int group_size = min(5, end - group_start + 1);
        sort(a + group_start, a + group_start + group_size);
        // Middle element of the sorted group: index 2 for a full group of 5;
        // for even-sized tail groups this is the upper of the two middles.
        medians.push_back(a[group_start + group_size / 2]);
    }

    const int count = static_cast<int>(medians.size());
    if (count == 1)
    {
        return medians[0];
    }
    // quick_select_kth takes a 1-based rank, so (count + 1) / 2 selects the
    // (lower) median of the group medians rather than a value near the minimum.
    return quick_select_kth(medians.data(), 0, count - 1, (count + 1) / 2);
}
/* Partitions a[start..end] around the median of medians (Lomuto scheme).
   The chosen pivot value is first moved to a[end]. Elements <= pivot are
   then gathered at the front, and the pivot is placed right after them.
   Afterwards a[start..p-1] <= a[p] < a[p+1..end], where p is the returned
   index. */
int quick_partition(int a[], int start, int end)
{
    // Move the median of medians to the end so it serves as the pivot.
    const int median_of_medians = find_median_of_medians(a, start, end);
    int *const range_end = a + end + 1;
    int *const pivot_slot = find(a + start, range_end, median_of_medians);
    // The value is taken from a[start..end], so it should always be found.
    // If that invariant is ever broken, keep a[end] as the pivot instead of
    // swapping with the out-of-range a[end + 1].
    if (pivot_slot != range_end)
    {
        swap(*pivot_slot, a[end]);
    }

    const int pivot_value = a[end];
    // a[start..i-1] holds the elements seen so far that are <= pivot.
    int i = start;
    for (int j = start; j < end; j++)
    {
        if (a[j] <= pivot_value)
        {
            swap(a[i], a[j]);
            i++;
        }
    }
    // a[i] is the first element greater than the pivot (or the pivot itself).
    swap(a[i], a[end]);
    return i;
}
/* Returns the kth smallest (1-based) element of a[start..end] using
   iterative quickselect. The array is reordered in place as a side effect.
   If k is outside 1..(end - start + 1) the search range empties and
   INT_MAX is returned as a "not found" sentinel. */
int quick_select_kth(int a[], int start, int end, int k)
{
    int lo = start;
    int hi = end;
    int rank = k; // 1-based rank of the target within a[lo..hi]
    while (lo <= hi)
    {
        const int p = quick_partition(a, lo, hi);
        const int left_count = p - lo + 1; // size of a[lo..p], pivot included
        if (rank < left_count)
        {
            // Target lies strictly left of the pivot.
            hi = p - 1;
            continue;
        }
        if (rank > left_count)
        {
            // Target lies right of the pivot; discount everything up to it.
            rank -= left_count;
            lo = p + 1;
            continue;
        }
        return a[p];
    }
    return INT_MAX;
}
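Complexity Analysis :
Time complexity : \(O(n)\) in the worst case when the elements are distinct. The partition sends every element equal to the pivot to its left side. As a result, inputs with many repeated values (e.g. all elements equal) can still degrade toward \(O(n^2)\).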