這題真的蠻難的 , 要去leetcode討論區看解答, 有說明很清楚
我們先探討提議, 再來逐行慢慢描述
首先,會給你兩個以排序的陣列, 要你找到這兩個陣列的中位數
因為題目要求log(m+n)的關係, 你不能一個一個慢慢找, 看到要求log我們自動就猜是二分搜尋了
根據discuss
我們可以先定義, 何謂中位數
我們可以將一個陣列分成兩個等長的陣列, 其中一個陣列的元素必定大於另一個陣列的
我們有一個陣列A, 長度為m , 可以分成left, 與right
left_A | right_A
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
從上面可以知道m 個元素存在m+1種切法, 怎麼切就是我們最大課題, 我們的目標就是找到i
好現在我們有另一個陣列B, 依樣畫葫蘆
left_B | right_B
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
一樣我們要找j, 現在兩個陣列一起看
left_part | right_part
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
我們知道,
len(left_part) == len(right_part)
min(left_part) <= max(right_part)
另外這邊我們統一讓m>n, 以減少我們要面對的許多條件,(兩邊看哪邊大即可對調)
從上面, 我們可以有兩個條件
i + j == ((m - i) + (n - j)) / 2 (len(left_part) == len(right_part))
B[j - 1] <= A[i], A[i-1] <= B[j]
有了上述第一式, 我們就可以推算i j的關係,
第二式, 就是我們二分搜尋的條件
另外上述第一次會改成
i + j == ((m - i) + (n - j) + 1) / 2
為什麼加1
當有一組case為
[1, 3] [2] => [1, 2, 3] => [1, 2 | 2, 3]
奇數總數會有一個問題, 他的右半邊會少一個數, 所以我們先在這邊加一, 好讓兩邊的總數先一致
我們後面會在處理這邊做的事情
偶數怎麼辦呢, 統一加一除2其實對偶數沒有影響, 他的i + j 值依然一樣
接著我們開始切A, 切A的動作會連帶影響到B, 我們專心切A就好
我們會從A的中間開始切, 最小就是Amin:0, 最大就是Amax:m
當i 決定後, j也會知道, 我們就要看第二式來做決定
B[j - 1] <= A[i], A[i-1] <= B[j]
B[j - 1] <= A[i]:
Amax--, 因為A[i]太大了, 相對切太右邊, 往左切一點, 且 i < m 且 j > 0 j >0的原因是因為 若j ==0,沒跳過這判別式會出錯(B[j-1]), i亦同(A[i])A[i-1] <= B[j]:
Amin++ 因為A[i-1]太小了, 相對切太左邊, 往右切一點, 且 i >0且 j < n
最後, 當有極端的狀況出現,(i =0, i = m, j =0, j = m),
我們要另外幫忙解決掉,
為什麼, 因為我們在找的時候, 都是用B[j-1], A[i-1]A[i]B[i]等等的式子, 若是遇到上面的四個特殊狀況, 程式會out of memory
所以 當
i == 0, left_max = B[j -1]
j ==0, left_max = A[i-1]
i = m, right_min = B[j]
j = n, right_min = A[i]
最後如果是奇數, 找完left_max即可回傳, 因為我們剛剛有有加一個數, 這邊如果不即時回傳, 後面right_min的值會是錯的
code:
class Solution {
public:
double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) {
int m = nums1.size();
int n = nums2.size();
int iMin = 0;
int iMax = m;
int left_max;
int right_min;
if (n > m) {
return findMedianSortedArrays(nums2, nums1);
}
if (n == 0) {
return (nums1[(m - 1) / 2] + nums1[m / 2]) / 2.0;
}
while (iMin <= iMax) {
int i = (iMin + iMax) / 2;
int j = (m + n + 1) / 2 - i;
if (j > 0 && i < m && nums1[i] < nums2[j - 1]) {
iMin++;
} else if (i > 0 && j < n && nums1[i - 1] > nums2[j]) {
iMax--;
} else {
if (i == 0) {
left_max = nums2[j - 1];
} else if (j == 0) {
left_max = nums1[i - 1];
} else {
left_max = max(nums1[i - 1], nums2[j - 1]);
}
if ((m + n) % 2 == 1) {
return left_max;
}
if (i == m) {
right_min = nums2[j];
} else if (j == n) {
right_min = nums1[i];
} else {
right_min = min(nums1[i], nums2[j]);
}
return (left_max + right_min) / 2.0;
}
}return -1;
}
};