如果k/2处的值相等,则可以排除任意一边的前k/2元素,因为就算有可能在排除掉的k/2个里,那也只能是第k/2个,否则数量加起来不够,但是另一边没有排除掉的里仍然有这个元素,因此不影响。
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int len = nums1.length + nums2.length;
if (len % 2 == 0) {
return ((double)find(nums1, 0, nums2, 0, len / 2) +
find(nums1, 0, nums2, 0, len / 2 + 1)) / 2;
}
return find(nums1, 0, nums2, 0, len / 2 + 1);
}
public int find(int[] nums1, int start1, int[] nums2, int start2, int k){
// 注意这三个base case的条件,要一个array完全不剩了为止,就算只有一个也可能在这一个里
// 因此不能用start1+k-1 >= nums1.length来判断
if (start1 >= nums1.length) {
return nums2[start2 + k - 1];
}
if (start2 >= nums2.length) {
return nums1[start1 + k - 1];
}
//注意k为1的base case
if (k == 1) {
return Math.min(nums1[start1], nums2[start2]);
}
int midIndex1 = start1 + k / 2 - 1;
int midIndex2 = start2 + k / 2 - 1;
// 当array1不够k/2时设成无穷大
// 淘汰midVal小的 不是因为kth不可能在这个array里,只是kth不可能在它的前k/2个里
int mid1 = midIndex1 >= nums1.length ? Integer.MAX_VALUE : nums1[midIndex1];
int mid2 = midIndex2 >= nums2.length ? Integer.MAX_VALUE : nums2[midIndex2];
if (mid1 < mid2) {
return find(nums1, midIndex1 + 1, nums2, start2, k - k / 2);
} else {
return find(nums1, start1, nums2, midIndex2 + 1, k - k / 2);
}
}
}
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int len = nums1.length + nums2.length;
if (len % 2 == 0) {
return ((double)find(len / 2, nums1, 0, nums2, 0) + find(len / 2 + 1, nums1, 0, nums2, 0)) / 2;
}
return find(len / 2 + 1, nums1, 0, nums2, 0);
}
private double find(int k, int[] nums1, int start1, int[] nums2, int start2) {
// 二分法,每次排除k/2,主要就是判断从start1到start1 + k/2和start2到start2_k/2
// 第kth不可能在哪一段里
if(start1 >= nums1.length) {
return nums2[start2 + k - 1];
}
if (start2 >= nums2.length) {
return nums1[start1 + k - 1];
}
if (k == 1) {
// 注意是第k小的,取小的那个
return Math.min(nums1[start1], nums2[start2]);
}
int midIdx1 = start1 + k / 2 - 1;
int midIdx2 = start2 + k / 2 - 1;
// 之所以没有k/2个就设成最大值,不是因为第kth只能在这一段里,而是第kth不能在另一段的前k/2里
// 因为如果在的话加上这不到k/2个怎么也达不到k个的总数
int midVal1 = midIdx1 >= nums1.length ? Integer.MAX_VALUE : nums1[midIdx1];
int midVal2 = midIdx2 >= nums2.length ? Integer.MAX_VALUE : nums2[midIdx2];
// 第kth不可能在midVal比较小的那个array的前k/2里,不然加上另外那k/2也不够k个的
if (midVal1 > midVal2) {
return find(k - k/2, nums1, start1, nums2, midIdx2 + 1);
}
else {
return find(k - k/2, nums1, midIdx1 + 1, nums2, start2);
}
}
}