327 Count of Range Sum
https://leetcode.com/problems/count-of-range-sum/
class Solution {
// sum(i, j) = prefixsum(0, j)-prefixsum(0, i-1)
// lower <= sum(i, j) <= upper
// => lower <= prefixsum(0, j)-prefixsum(0, i-1) <= upper
// => prefixsum(0,j)-upper <= prefixsum(0, i-1) <= prefixsum(0, j)-lower
// 因此给定j和i,j>i,只需要知道之前见过的i对应的prefix有几个在
// [prefixsum(0,j)-upper, prefixsum(0,j)-lower]的范围内就可以找到所有满足条件并且大于1的i
// 然后再检查一下当前prefisum(0, j)是否在[lower, upper]内就可以cover掉所有情况
// 考虑到prefixsum可能很稀疏,可以采取收集排序去重的方式整理
public int countRangeSum(int[] nums, int lower, int upper) {
int ans = 0;
long[] prefixsum = new long[nums.length];
prefixsum[0] = nums[0];
for (int i = 1; i < nums.length; i++) {
prefixsum[i] = prefixsum[i-1] + nums[i];
}
Arrays.sort(prefixsum);
// dedup
int n = 0;
for (int i = 0; i < nums.length; i++) {
if(prefixsum[n] != prefixsum[i]) {
prefixsum[++n] = prefixsum[i];
}
}
long sum = 0;
int[] cnts = new int[(n+2)*4];
for (int i = 0; i < nums.length; i++) {
sum += nums[i];
if (upper >= sum && sum >= lower) {
ans++;
}
// only consider the val has overlap with the largest range
if (i > 0 && !(sum-lower < prefixsum[0] || sum-upper > prefixsum[n])) {
int left = search(prefixsum, sum-upper, 0, n, true);
int right = search(prefixsum, sum-lower, 0, n, false);
int cnt = query(cnts, left, right, 0, n, 1);
ans += cnt;
}
int pos = search(prefixsum, sum, 0, n, true);
increase(pos, cnts, 0, n, 1);
}
return ans;
}
private void increase(int pos, int[] cnts, int left, int right, int idx) {
if (pos == left && pos == right) {
cnts[idx]++;
return;
}
if (pos < left || pos > right) {
return;
}
int mid = (left + right) / 2;
if (pos <= mid) {
increase(pos, cnts, left, mid, idx*2);
} else {
increase(pos, cnts, mid+1, right, idx*2+1);
}
cnts[idx] = cnts[idx*2] + cnts[idx*2+1];
}
private int query(int[] cnts, int jobl, int jobr, int left, int right, int idx) {
if (left >= jobl && right <= jobr) {
return cnts[idx];
}
if (jobl > right || jobr < left) {
return 0;
}
int mid = (left + right) / 2;
int cnt = 0;
if (jobl <= mid) {
cnt += query(cnts, jobl, jobr, left, mid, idx*2);
}
if (jobr > mid) {
cnt += query(cnts, jobl, jobr, mid+1, right, idx*2+1);
}
return cnt;
}
private int search(long[] arr, long val, int start, int end, boolean l) {
int left = start;
int right = end;
while (left + 1 < right) {
int mid = (left + right) / 2;
if (val == arr[mid]) {
return mid;
}
if (val < arr[mid]) {
right = mid;
} else {
left = mid;
}
}
if (l) {
// get the smallest val larger
if (arr[left] >= val) {
return left;
}
// 主函数里已经保证了查找的值在范围内,所以如果left不满足条件,right一定满足
return right;
} else {
// get the largest val smaller
if (arr[right] <= val) {
return right;
}
return left;
}
}
}
Last updated