πŸ“ Problem Details

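Given an integer array `nums`, the range of a subarray is the difference between its largest and smallest element. Return the sum of the ranges of all subarrays of `nums` (a subarray is a contiguous, non-empty sequence of elements). For example, `nums = [1, 2, 3]` has subarray ranges 0, 0, 0, 1, 1 and 2, which sum to 4.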

💭 What Were My Initial Thoughts?

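- Brute force: generate all n(n+1)/2 subarrays and compute the min/max of each.
- Recomputing the min/max from scratch for every subarray costs O(n³); keeping a running min/max while extending the right end brings this down to O(n²), but that still grows quadratically, so the goal is a linear-time method.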
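A minimal sketch of the O(n²) brute force, written as a free function with the hypothetical name `subArrayRangesBrute` (not part of the original solution). Besides showing the running min/max trick, it is handy as a reference for cross-checking the stack-based solution on small inputs.

```cpp
#include <algorithm>
#include <vector>

// Brute-force baseline: O(n^2) time, O(1) extra space.
// For every left endpoint i, extend the right endpoint j one step at a
// time and keep a running min/max, so each subarray's range costs O(1)
// instead of a fresh O(n) scan.
long long subArrayRangesBrute(const std::vector<int>& nums) {
    long long total = 0;
    const int n = static_cast<int>(nums.size());
    for (int i = 0; i < n; ++i) {
        int lo = nums[i];
        int hi = nums[i];
        for (int j = i; j < n; ++j) {
            lo = std::min(lo, nums[j]);
            hi = std::max(hi, nums[j]);
            // Widen before subtracting: hi - lo can exceed INT_MAX.
            total += static_cast<long long>(hi) - lo;
        }
    }
    return total;
}
```

For example, `subArrayRangesBrute({4, -2, -3, 4, 1})` should return 59.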

💡 Explanation of Solution

- Instead of looking at all subarrays, we flip the perspective. The sum of all ranges equals the sum of every subarray's maximum minus the sum of every subarray's minimum, so for each element we count how many subarrays it is the maximum of and how many it is the minimum of.

- The contribution of an element to the final sum is:
    (Number of subarrays where it is the max) * value
  minus
    (Number of subarrays where it is the min) * value

- To compute this efficiently, we use monotonic stacks to find:
    - Previous and next smaller elements for min contribution
    - Previous and next greater elements for max contribution

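- The number of subarrays in which `nums[i]` is the max is:
    (i - prev_greater_index) * (next_greater_index - i)
  where `prev_greater_index` is the nearest index to the left with a value >= `nums[i]` (or -1 if there is none) and `next_greater_index` is the nearest index to the right with a value > `nums[i]` (or n if there is none). The first factor counts the possible left endpoints and the second counts the possible right endpoints.

- Similarly, for min:
    (i - prev_smaller_index) * (next_smaller_index - i)
  where `prev_smaller_index` is the nearest index to the left with a value <= `nums[i]` (or -1) and `next_smaller_index` is the nearest index to the right with a value < `nums[i]` (or n).

- Using a non-strict comparison on one side and a strict one on the other matters when values repeat: it credits each subarray's max and min to exactly one index, so nothing is counted twice.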

- The final result is the sum of each element's max contribution minus min contribution.
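The explanation above is phrased in terms of four boundary arrays, while the implementation further down computes each boundary at the moment an index is popped. As a sketch under that reading, here is a version that builds the arrays explicitly (the function name `subArrayRangesArrays` and the array names are hypothetical). It breaks ties with a non-strict comparison on the left and a strict one on the right, the same convention the one-stack implementation uses.

```cpp
#include <stack>
#include <vector>

// Four boundary arrays, filled with two monotonic stacks in one sweep:
//   prevGE[i]: nearest j < i with nums[j] >= nums[i], else -1
//   nextG[i]:  nearest j > i with nums[j] >  nums[i], else n
//   prevLE[i]: nearest j < i with nums[j] <= nums[i], else -1
//   nextL[i]:  nearest j > i with nums[j] <  nums[i], else n
long long subArrayRangesArrays(const std::vector<int>& nums) {
    const int n = static_cast<int>(nums.size());
    std::vector<int> prevGE(n, -1), nextG(n, n), prevLE(n, -1), nextL(n, n);
    std::stack<int> maxStk;  // indices whose values are non-increasing
    std::stack<int> minStk;  // indices whose values are non-decreasing

    for (int i = 0; i < n; ++i) {
        // Every index popped here is strictly smaller, so i is its next greater.
        while (!maxStk.empty() && nums[maxStk.top()] < nums[i]) {
            nextG[maxStk.top()] = i;
            maxStk.pop();
        }
        // Whatever remains on top is the nearest value >= nums[i].
        if (!maxStk.empty()) prevGE[i] = maxStk.top();
        maxStk.push(i);

        // Mirror image for the minimum.
        while (!minStk.empty() && nums[minStk.top()] > nums[i]) {
            nextL[minStk.top()] = i;
            minStk.pop();
        }
        if (!minStk.empty()) prevLE[i] = minStk.top();
        minStk.push(i);
    }

    long long result = 0;
    for (int i = 0; i < n; ++i) {
        // Subarray counts from the formulas: (left choices) * (right choices).
        long long asMax = static_cast<long long>(i - prevGE[i]) * (nextG[i] - i);
        long long asMin = static_cast<long long>(i - prevLE[i]) * (nextL[i] - i);
        result += (asMax - asMin) * nums[i];
    }
    return result;
}
```

For `nums = [1, 3, 2]` the arrays give max counts 1, 4, 1 and min counts 3, 1, 2, so the result is (1*1 + 3*4 + 2*1) - (1*3 + 3*1 + 2*2) = 15 - 10 = 5. The trade-off is O(n) extra memory for the four arrays in exchange for keeping the boundary lookup and the summation visibly separate.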
 

⌛ Complexity Analysis

Time Complexity: O(n)
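- There are two passes (one for the max contribution, one for the min). In each pass every index is pushed onto the stack exactly once and popped at most once, so the nested `while` loop does O(n) total work per pass (amortized O(1) per element).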
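The nested `while` inside the `for` loop can look quadratic. A small instrumented copy of the max pass (the helper name `countMaxPassPops` is hypothetical) makes the amortized argument concrete: it counts how often the inner loop pops, and for any input the total is exactly n, because the final `i == n` step flushes whatever is left.

```cpp
#include <stack>
#include <vector>

// Same loop structure as the "contribution as max" pass, but instead of
// accumulating contributions it counts pops. Indices 0..n-1 are each pushed
// once and popped once; the sentinel index n is pushed but never popped.
long long countMaxPassPops(const std::vector<int>& nums) {
    const int n = static_cast<int>(nums.size());
    std::stack<int> stk;
    long long pops = 0;
    for (int i = 0; i <= n; ++i) {
        while (!stk.empty() && (i == n || nums[stk.top()] < nums[i])) {
            stk.pop();
            ++pops;
        }
        stk.push(i);
    }
    return pops;
}
```

Increasing, decreasing and constant arrays of length 5 all report 5 pops; only the moment at which the pops happen changes.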

Space Complexity: O(n)
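- The stack can hold up to n indices (for example, when `nums` is non-increasing during the max pass). The implementation computes each boundary at the moment an index is popped, so it needs no separate previous/next arrays.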

💻 Implementation of Solution

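#include <stack>
#include <vector>
using namespace std;

class Solution {
public:
    long long subArrayRanges(vector<int>& nums) {
        int n = nums.size();
        long long result = 0;

        // Contribution as max: the stack holds indices whose values are
        // non-increasing from bottom to top. When nums[i] is strictly larger
        // than the top (or i == n, which flushes the stack), the popped index
        // `mid` is counted as the max of every subarray that starts in
        // (left, mid] and ends in [mid, right).
        stack<int> stk;
        for (int i = 0; i <= n; ++i) {
            while (!stk.empty() && (i == n || nums[stk.top()] < nums[i])) {
                int mid = stk.top(); stk.pop();
                int left = stk.empty() ? -1 : stk.top();  // previous index with value >= nums[mid]
                int right = i;                            // next index with value > nums[mid], or n
                long long count = (mid - left) * (long long)(right - mid);
                result += nums[mid] * count;
            }
            stk.push(i);
        }

        // Contribution as min: mirror image with a non-decreasing stack.
        while (!stk.empty()) stk.pop();
        for (int i = 0; i <= n; ++i) {
            while (!stk.empty() && (i == n || nums[stk.top()] > nums[i])) {
                int mid = stk.top(); stk.pop();
                int left = stk.empty() ? -1 : stk.top();  // previous index with value <= nums[mid]
                int right = i;                            // next index with value < nums[mid], or n
                long long count = (mid - left) * (long long)(right - mid);
                result -= nums[mid] * count;
            }
            stk.push(i);
        }

        return result;
    }
};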