LeetCode_378: Kth Smallest Element in a Sorted Matrix

        经典的数据结构问题,在算法导论堆排序的思考题中出现过这种结构的排序求解问题,该问题解法多样,很适合学习。

题目

Given a n x n matrix where each of the rows and columns are sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

matrix = [ [ 1, 5, 9], [10, 11, 13], [12, 13, 15] ], k = 8, return 13.

Note:
You may assume k is always valid, 1 ≤ k ≤ n2.

二叉树解法

代码中的Solution类,运行时间是39ms,解法里面的代码之所以比较快,还是得益于使用了STL中的upper_bound函数。

此外,算法本身也是二分查找的思路,使用upper_bound计算每行区间中间值的位置(利用了原先数列已经是排好序的,可以直接送入),来计算小于中间值min的元素个数,如果k小于此值,说明k在上半段,缩减搜索范围继续进行。

STL中的upper_bound和lower_bound

upper_bound与lower_bound通过二分查找法,找出一个非递减序列中第一个大于等于(lower)或者等于(upper)val的值。这部分二分查找的细节参见这篇博客

优先队列解法

代码中的 SolutionPQ 类,算法复杂度为 O(klg(n)),52ms

因为原先矩阵每一行列都有固定的大小关系,通过一个优先队列按列进行搜索,优先队列就是搜索边界,不断提取出边界上的最小值。来k此就得到了第k小的值。这也算是一个图搜索问题了,以最上面为小边,逐步往下拓展边界的搜索策略。

堆排序解法

SolutionHS,算法复杂度为O(n^2),博主最开始的解法,速度比较慢,400+ms,只是记录一下,不提供任何参考。

思路就是将这个矩阵构成一个特殊的堆,每回堆顶最小,提取出来就可以了,此外还设计了这个堆的heapify操作。

代码

#include <iostream>
#include <cstdio>
#include <vector>
#include <queue>
#include <functional>

using namespace std;

class Solution // BST 39ms
{
public:
	int kthSmallest(vector<vector<int>>& matrix, int k){
		int n = matrix.size();
		int le = matrix[0][0], ri = matrix[n - 1][n - 1];
		int mid = 0;
		while (le < ri) {
			mid = le + (ri - le) / 2;
			int num = 0;
			for (int i = 0; i < n; i++){ // 统计小于min的元素个数
				int pos = upper_bound(matrix[i].begin(), matrix[i].end(), mid) - matrix[i].begin();
				num += pos;
			}
			if (num < k){
				le = mid + 1; // left search boundry changed
			} else{
				ri = mid;
			}
		}
		return le;
	}
};

class SolutionPQ{ // 52ms O(klog(n))
public:

	int kthSmallest(vector<vector<int>>& arr, int k){
		int n = arr.size(), m = arr[0].size();
		priority_queue < pair<int, pair<int, int>>,
			vector<pair<int, pair<int, int>>>,
			greater < pair<int, pair<int, int>> >> pq;

		for (int i = 0; i < n; ++i){
			pq.push({ arr[i][0], {i, 0} });
		}

		int x = k, ans;
		while (x--){
			const auto t = pq.top(); pq.pop();
			ans = t.first;
			const int i = t.second.first;
			const int j = t.second.second;
			if (j != m - 1){
				pq.push({ arr[i][j + 1], {i, j + 1} });
			}
		}

		return ans;
	}
};

class SolutionHS { // 400+ms O(n^2)
public:
	int heapify(vector<vector<int>>& matrix, vector<vector<bool>>& mask){
		const int n = matrix.size();
		int i = 0, j = 0;
		while (true){
			int vp = matrix[i][j];
			
			bool havelc = i < n - 1 && mask[i + 1][j];
			bool haverc = j < n - 1 && mask[i][j + 1];
			if ((!havelc) && (!haverc)){	// no lc no rc
				return 0;
			} else if(havelc && (!haverc)){ // only have lc
				if (matrix[i+1][j] < matrix[i][j]){
					swap(matrix[i+1][j], matrix[i][j]);
					i = i + 1;
				} else{
					return 0;
				}
			} else if((!havelc) && haverc){ // only have rc
				if (matrix[i][j + 1] < matrix[i][j]){
					swap(matrix[i][j+1], matrix[i][j]);
					j = j + 1;
				} else{
					return 0;
				}
			} else{			// choose largest
				int minst = matrix[i][j];
				pair<int, int> minid{ i, j };
				if (matrix[i][j + 1] < matrix[i][j]){
					minst = matrix[i][j + 1];
					minid = { i, j + 1 };
				} 
				if (matrix[i + 1][j] < minst){
					minst = matrix[i + 1][j];
					minid = { i + 1, j };
				}
				if (minid != pair<int, int>{i, j}){
					swap(matrix[i][j], matrix[minid.first][minid.second]);
					i = minid.first;
					j = minid.second;
				} else{
					return 0;
				}
			}
		}
	}

	int kthSmallest(vector<vector<int>>& matrix, int k) {
		const int n = matrix.size();
		if (n == 0 || k <= 0 || k > n*n){
			return 0;
		}
		
		vector<vector<bool>> mask(n, vector<bool>(n, true)); // true in avalible
		int i = n-1, j = n-1, res = 0, ii = 1;
		while (ii < k){
			ii++;
			swap(matrix[i][j], matrix[0][0]);
			mask[i][j] = false;
			heapify(matrix, mask);
			if (j == 0){
				j = n - 1;
				i -= 1;
				continue;
			}
			j -= 1;
		}

		return matrix[0][0];
	}
};

int main(int argc, char *argv[]){
	Solution s;
	vector<vector<int>> v{ { 1, 5, 9 }, { 10, 11, 13 }, { 12, 13, 15 } };
	for (int k = 1; k < 10; ++k)
		cout << s.kthSmallest(v, k) << endl;
	system("pause");
	return 0;
}

 

发表评论