😯327 Count of Range Sum

https://leetcode.com/problems/count-of-range-sum/

class Solution {
    /**
     * Counts the contiguous sub-arrays of {@code nums} whose sum lies in the
     * inclusive range {@code [lower, upper]}.
     *
     * <p>Derivation, with {@code prefix(j)} denoting {@code nums[0] + ... + nums[j]}:
     * <pre>
     *   sum(i, j) = prefix(j) - prefix(i - 1)
     *   lower &lt;= sum(i, j) &lt;= upper
     *     =&gt; prefix(j) - upper &lt;= prefix(i - 1) &lt;= prefix(j) - lower
     * </pre>
     * So for every end index {@code j} it is enough to know how many of the
     * prefix sums seen so far fall inside
     * {@code [prefix(j) - upper, prefix(j) - lower]}; that covers every start
     * index {@code i >= 1}. The remaining case {@code i == 0} is handled by
     * checking {@code prefix(j)} itself against {@code [lower, upper]}.
     *
     * <p>Prefix sums can be sparse, so they are collected, sorted and
     * de-duplicated; a counting segment tree is then built over the resulting
     * ranks rather than over the raw values.
     *
     * @param nums  input values; {@code null} or an empty array yields 0
     * @param lower inclusive lower bound of the accepted sums
     * @param upper inclusive upper bound of the accepted sums
     * @return the number of matching sub-arrays (accumulated in an {@code int})
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        // No elements means no sub-arrays. Without this guard the prefix-sum
        // initialisation below would read nums[0] of an empty array and throw.
        if (nums == null || nums.length == 0) {
            return 0;
        }

        int ans = 0;

        // Prefix sums are kept as long: adding ints may exceed the int range.
        long[] prefixsum = new long[nums.length];
        prefixsum[0] = nums[0];
        for (int i = 1; i < nums.length; i++) {
            prefixsum[i] = prefixsum[i - 1] + nums[i];
        }
        Arrays.sort(prefixsum);

        // De-duplicate in place. Afterwards prefixsum[0..n] holds the distinct
        // values in ascending order (n is the LAST valid index, not a length).
        int n = 0;
        for (int i = 1; i < nums.length; i++) {
            if (prefixsum[n] != prefixsum[i]) {
                prefixsum[++n] = prefixsum[i];
            }
        }

        // Segment tree over ranks 0..n, root at index 1; cnts[node] is the
        // number of already-inserted prefix sums whose rank is in that node.
        int[] cnts = new int[(n + 2) * 4];

        long sum = 0;
        for (int i = 0; i < nums.length; i++) {
            sum += nums[i];

            // Sub-array starting at index 0.
            if (upper >= sum && sum >= lower) {
                ans++;
            }

            // Sub-arrays starting at index >= 1. Only query when the wanted
            // value window overlaps the span of known prefix sums; this is
            // also what makes the preconditions of search() hold.
            long windowLow = sum - upper;
            long windowHigh = sum - lower;
            if (i > 0 && !(windowHigh < prefixsum[0] || windowLow > prefixsum[n])) {
                int left = search(prefixsum, windowLow, 0, n, true);
                int right = search(prefixsum, windowHigh, 0, n, false);
                // If the window falls into a gap between two distinct values,
                // left > right and query() returns 0.
                ans += query(cnts, left, right, 0, n, 1);
            }

            // Record the current prefix sum for later end indices. It is always
            // present in prefixsum, so this search is an exact lookup.
            int pos = search(prefixsum, sum, 0, n, true);
            increase(pos, cnts, 0, n, 1);
        }

        return ans;
    }

    /**
     * Adds one to the leaf for rank {@code pos} and refreshes the counts of
     * the nodes on the path back up to {@code idx}.
     *
     * @param pos   rank to increment
     * @param cnts  segment-tree storage
     * @param left  first rank covered by node {@code idx}
     * @param right last rank covered by node {@code idx}
     * @param idx   index of the current node in {@code cnts}
     */
    private void increase(int pos, int[] cnts, int left, int right, int idx) {
        if (pos == left && pos == right) {
            cnts[idx]++;
            return;
        }
        if (pos < left || pos > right) {
            return;
        }
        int mid = left + (right - left) / 2;
        if (pos <= mid) {
            increase(pos, cnts, left, mid, idx * 2);
        } else {
            increase(pos, cnts, mid + 1, right, idx * 2 + 1);
        }
        cnts[idx] = cnts[idx * 2] + cnts[idx * 2 + 1];
    }

    /**
     * Returns how many inserted values have a rank in {@code [jobl, jobr]}.
     * An inverted range ({@code jobl > jobr}) yields 0.
     *
     * @param cnts  segment-tree storage
     * @param jobl  first rank of the queried range (inclusive)
     * @param jobr  last rank of the queried range (inclusive)
     * @param left  first rank covered by node {@code idx}
     * @param right last rank covered by node {@code idx}
     * @param idx   index of the current node in {@code cnts}
     * @return the number of inserted values within the queried ranks
     */
    private int query(int[] cnts, int jobl, int jobr, int left, int right, int idx) {
        if (left >= jobl && right <= jobr) {
            return cnts[idx];
        }
        if (jobl > right || jobr < left) {
            return 0;
        }
        int mid = left + (right - left) / 2;
        int cnt = 0;
        if (jobl <= mid) {
            cnt += query(cnts, jobl, jobr, left, mid, idx * 2);
        }
        if (jobr > mid) {
            cnt += query(cnts, jobl, jobr, mid + 1, right, idx * 2 + 1);
        }
        return cnt;
    }

    /**
     * Binary search over the sorted, duplicate-free slice {@code arr[start..end]}.
     *
     * <p>An exact match returns its index. Otherwise the result depends on
     * {@code ceiling}:
     * <ul>
     *   <li>{@code true}: index of the smallest element {@code >= val};
     *       the caller must ensure {@code val <= arr[end]}.</li>
     *   <li>{@code false}: index of the largest element {@code <= val};
     *       the caller must ensure {@code val >= arr[start]}.</li>
     * </ul>
     *
     * @param arr     sorted values without duplicates
     * @param val     value to locate
     * @param start   first index of the slice (inclusive)
     * @param end     last index of the slice (inclusive)
     * @param ceiling selects ceiling ({@code true}) or floor ({@code false}) semantics
     * @return the index described above
     */
    private int search(long[] arr, long val, int start, int end, boolean ceiling) {
        int left = start;
        int right = end;
        // Narrow down to at most two neighbouring candidates.
        while (left + 1 < right) {
            int mid = left + (right - left) / 2;
            if (val == arr[mid]) {
                return mid;
            }
            if (val < arr[mid]) {
                right = mid;
            } else {
                left = mid;
            }
        }
        if (ceiling) {
            // Smallest element that is >= val.
            if (arr[left] >= val) {
                return left;
            }
            // The caller guarantees val is within range, so when left does not
            // qualify, right must.
            return right;
        } else {
            // Largest element that is <= val.
            if (arr[right] <= val) {
                return right;
            }
            return left;
        }
    }
}

Last updated