beautifulremi / 3318. 计算子数组的 x-sum

Created Tue, 04 Nov 2025 20:40:34 +0800 Modified Mon, 23 Mar 2026 05:26:54 +0000
3739 Words

题目

给你一个由 n 个整数组成的数组 nums,以及两个整数 k 和 x

数组的 x-sum 计算按照以下步骤进行:

  • 统计数组中所有元素的出现次数。
  • 仅保留出现次数最多的前 x 个元素的每次出现。如果两个元素的出现次数相同,则数值 较大 的元素被认为出现次数更多。
  • 计算结果数组的和。

注意,如果数组中的不同元素少于 x 个,则其 x-sum 是数组的元素总和。

返回一个长度为 n - k + 1 的整数数组 answer,其中 answer[i] 是 子数组 nums[i..i + k - 1] 的 x-sum

子数组 是数组内的一个连续 非空 的元素序列。

示例 1:

输入:nums = [1,1,2,2,3,4,2,3], k = 6, x = 2

输出:[6,10,12]

解释:

  • 对于子数组 [1, 1, 2, 2, 3, 4],只保留元素 1 和 2。因此,answer[0] = 1 + 1 + 2 + 2
  • 对于子数组 [1, 2, 2, 3, 4, 2],只保留元素 2 和 4。因此,answer[1] = 2 + 2 + 2 + 4。注意 4 被保留是因为其数值大于出现其他出现次数相同的元素(3 和 1)。
  • 对于子数组 [2, 2, 3, 4, 2, 3],只保留元素 2 和 3。因此,answer[2] = 2 + 2 + 2 + 3 + 3

示例 2:

输入:nums = [3,8,7,8,7,5], k = 2, x = 2

输出:[11,15,15,15,12]

解释:

由于 k == xanswer[i] 等于子数组 nums[i..i + k - 1] 的总和。

提示:

  • 1 <= n == nums.length <= 10^5
  • 1 <= nums[i] <= 10^5
  • 1 <= x <= k <= nums.length

解题思路

这是一个经典的“滑动窗口”问题,但其核心计算(x-sum)比较复杂,导致朴素的解法(对每个窗口都重新计算)时间复杂度过高。

一个朴素的解法是:

  1. 遍历 n - k + 1 个窗口。

  2. 对于每个窗口(长度为 k),用一个哈希表统计频率。

  3. 将哈希表转为一个列表,并根据“x-sum”规则(频率降序,值降序)排序。

  4. 取出前 x 个元素,计算它们的总和。 这种方法的时间复杂度大约是 O(N * k log k)(其中 Nn-k+1),在 Nk 都很大时(例如 N=10^5, k=10^5)会超时。

我们需要一个更高效的解法,能够在窗口滑动时,动态更新 x-sum,而不是从头计算。

核心解题思路:滑动窗口 + 动态维护

这个问题的关键在于,当窗口从 [i...i+k-1] 滑动到 [i+1...i+k] 时,我们只移除了一个元素 nums[i] 并增加了一个元素 nums[i+k]。我们希望用 O(log k)O(1) 的代价来更新 x-sum。

我们将使用两个数据结构来动态维护“被保留的”和“被移除的”元素,以及一个哈希表来跟踪频率。

优化的关键点

  1. 特殊情况:k <= x

    • 根据规则:“如果数组中的不同元素少于 x 个,则其 x-sum 是数组的元素总和。”

    • 一个长度为 k 的子数组,最多有 k 个不同的元素。

    • 如果 k <= x,那么不同元素的数量 d 必然满足 d <= k <= x

    • 因此,在这种情况下,x-sum 始终等于子数组的总和

    • 这退化成了一个简单的滑动窗口求和问题,可以在 O(N) 时间内解决。

  2. 一般情况:k > x

    • 这种情况才需要复杂的处理。

    • 我们需要维护窗口内所有元素的频率

    • 我们还需要一种方法来快速知道哪些是“Top x”元素。

数据结构

我们将使用以下数据结构来跟踪窗口状态:

  1. freq_map (unordered_map<int, int>):一个哈希表,用于存储当前窗口内每个元素 value 及其出现次数 frequency

  2. total_sum (long long):当前窗口所有元素的总和。

  3. x_sum (long long):当前窗口“Top x”元素的总和(即我们要求的 x-sum)。

  4. kept_set (set<pair<int, int>, Comparator>):一个有序集合(如 C++ 的 std::set),用于存放**“Top x”**的元素。它存储 (frequency, value) 对。

  5. removed_set (set<pair<int, int>, Comparator>):一个有序集合,用于存放**“非 Top x”**的元素。它也存储 (frequency, value) 对。

Comparator 是一个自定义比较器,它严格按照题目的排序规则:

  • 优先比较 frequency(降序)。

  • 如果 frequency 相同,则比较 value(降序)。

算法步骤

  1. 处理特殊情况

    • 如果 k <= x,则初始化一个 total_sum,然后用 O(N) 的滑动窗口计算每个窗口的总和,并存入 answer 数组。直接返回。
  2. 初始化 (k > x)

    • 处理第一个窗口 nums[0...k-1]

    • 遍历这 k 个元素:

      • 更新 freq_maptotal_sum
    • 遍历 freq_map,将所有 (frequency, value) 对插入到一个临时的有序列表或 removed_set 中。

    • x_sum 初始化为 0。

    • 调用一个 balance() 辅助函数,将 removed_set 中“最好”的 x 个元素移动到 kept_set 中,并同步更新 x_sum

  3. 计算 answer[0]

    • 获取当前窗口的不同元素总数 d = freq_map.size()

    • 如果 d <= xanswer[0] = total_sum

    • 否则,answer[0] = x_sum

  4. 滑动窗口

    • i = 1 循环到 n - k

    • 在每一步,我们处理 remove_val = nums[i-1]add_val = nums[i+k-1]

    • 更新 total_sum = total_sum - remove_val + add_val

    • 处理 remove(remove_val)

      1. 获取 old_freq = freq_map[remove_val]

      2. new_freq = old_freq - 1

      3. kept_setremoved_set移除 (old_freq, remove_val)

      4. 如果它在 kept_set 中被移除,则 x_sum -= remove_val * old_freq

      5. 如果 new_freq > 0,将 (new_freq, remove_val) 插入removed_set(暂时)。

      6. 如果 new_freq == 0,从 freq_map 中擦除 remove_val。否则更新 freq_map[remove_val] = new_freq

    • 处理 add(add_val)

      1. 获取 old_freq = freq_map[add_val] (如果不存在则为 0)。

      2. new_freq = old_freq + 1

      3. 如果 old_freq > 0,从 kept_setremoved_set移除 (old_freq, add_val)

      4. 如果它在 kept_set 中被移除,则 x_sum -= add_val * old_freq

      5. (new_freq, add_val) 插入removed_set(暂时)。

      6. 更新 freq_map[add_val] = new_freq

    • 重新平衡 (balance())

      • addremove 操作之后,kept_setremoved_set 的状态可能不平衡。

      • 情况 Akept_set.size() > x。将 kept_set最差的元素(排序最后的元素)移到 removed_set,并从 x_sum 中减去它的贡献 (freq * val)。

      • 情况 Bkept_set.size() < xremoved_set 不为空。将 removed_set最好的元素(排序最前的元素)移到 kept_set,并向 x_sum 添加它的贡献。

      • 重复此过程直到 kept_set.size() == x (或者 removed_set 变空)。

    • 记录结果 answer[i]

      • 获取 d = freq_map.size()

      • 如果 d <= xanswer[i] = total_sum

      • 否则,answer[i] = x_sum

  5. 返回 answer

复杂度分析

  • 特殊情况 (k <= x):时间 O(N),空间 O(1)

  • 一般情况 (k > x)

    • std::set(平衡二叉树)的插入和删除操作都是 O(log d),其中 d 是不同元素的数量(d <= k)。

    • balance() 函数每次移动一个元素,也是 O(log k)

    • 初始化第一个窗口:O(k log k)

    • 滑动窗口循环:N - k 次。

    • 每次滑动(add, remove, balance):每个操作都是 O(log k)

    • 总时间复杂度O(k log k + N log k),在 k 接近 N 时,最坏为 O(N log N)

    • 空间复杂度O(k),用于存储 freq_map, kept_setremoved_set

具体代码

class Solution {
private:
    // 定义我们要存储的元素类型:(frequency, value)
    using Elem = pair<int, int>;

    // 自定义比较器,严格按照题目的排序规则
    // 1. 频率(frequency)降序
    // 2. 如果频率相同,则数值(value)降序
    struct Comparator {
        bool operator()(const Elem& a, const Elem& b) const {
            if (a.first != b.first) {
                return a.first > b.first;
            }
            return a.second > b.second;
        }
    };

    // `kept_set` 存储 Top x 元素
    set<Elem, Comparator> kept_set;
    // `removed_set` 存储所有其他元素
    set<Elem, Comparator> removed_set;
    // `freq_map` 跟踪窗口内每个元素的当前频率
    unordered_map<int, int> freq_map;
    
    // `x_sum` 是 kept_set 中元素的总和
    long long x_sum = 0;
    // `total_sum` 是整个窗口所有元素的总和
    long long total_sum = 0;
    
    // 存储 x 的值
    int x_val;
    // 实例化比较器
    Comparator comp;

    /**
     * @brief 更新元素的频率和其在两个 set 中的位置。
     * @param val 要更新的元素值。
     * @param old_freq 更新前的频率。
     * @param new_freq 更新后的频率。
     */
    void update_sets(int val, int old_freq, int new_freq) {
        // 1. 如果旧频率 > 0,说明该元素已存在,需要先移除旧条目
        if (old_freq > 0) {
            Elem old_elem = {old_freq, val};
            // 检查它在哪个 set 中
            if (kept_set.count(old_elem)) {
                kept_set.erase(old_elem);
                // 如果它在 kept_set 中,需要从 x_sum 中减去它的贡献
                x_sum -= (long long)old_freq * val;
            } else {
                removed_set.erase(old_elem);
            }
        }

        // 2. 如果新频率 > 0,说明元素仍然存在,需要添加新条目
        if (new_freq > 0) {
            Elem new_elem = {new_freq, val};
            // *重要*:我们总是先将其添加到 removed_set。
            // 稍后的 balance() 函数会决定它是否应该被移到 kept_set。
            removed_set.insert(new_elem);
        }
    }

    /**
     * @brief 重新平衡 kept_set 和 removed_set。
     * 确保 kept_set 包含的
     * 始终是当前窗口中“最好”的 x 个元素。
     */
    void balance() {
        // 1. 如果 kept_set 不足 x 个,从 removed_set 移动最好的元素过来
        while (kept_set.size() < x_val && !removed_set.empty()) {
            Elem to_move = *removed_set.begin();
            removed_set.erase(removed_set.begin());
            kept_set.insert(to_move);
            // 增加它对 x_sum 的贡献
            x_sum += (long long)to_move.first * to_move.second;
        }

        // 2. 如果 kept_set 超过 x 个,移动最差的元素到 removed_set
        while (kept_set.size() > x_val) {
            // .rbegin() 指向最后一个(即最差的)元素
            Elem to_move = *kept_set.rbegin(); 
            // C++ set 擦除末尾元素的标准方法
            kept_set.erase(std::prev(kept_set.end())); 
            removed_set.insert(to_move);
            // 减去它对 x_sum 的贡献
            x_sum -= (long long)to_move.first * to_move.second;
        }

        // 3. 交换:如果 removed_set 中最好的元素 > kept_set 中最差的元素
        //    *comp(a, b)* 返回 true 意味着 a "优于" b
        while (!kept_set.empty() && !removed_set.empty() && 
               comp(*removed_set.begin(), *kept_set.rbegin())) {
            
            Elem to_add = *removed_set.begin();
            Elem to_remove = *kept_set.rbegin();

            // 执行交换
            removed_set.erase(removed_set.begin());
            kept_set.erase(std::prev(kept_set.end()));
            kept_set.insert(to_add);
            removed_set.insert(to_remove);

            // 更新 x_sum
            x_sum += (long long)to_add.first * to_add.second;
            x_sum -= (long long)to_remove.first * to_remove.second;
        }
    }

public:
    vector<int> findXSum(vector<int>& nums, int k, int x) {
        int n = nums.size();
        this->x_val = x; // 存储 x
        vector<int> answer;

        // --- 优化:特殊情况 k <= x ---
        // 根据规则,如果不同元素少于 x 个,x-sum 等于总和。
        // 窗口长度为 k,不同元素 <= k。
        // 如果 k <= x,那么不同元素 <= x,因此 x-sum 总是等于窗口总和。
        if (k <= x) {
            long long current_sum = 0;
            for (int i = 0; i < n; ++i) {
                current_sum += nums[i];
                // 窗口开始滑动后,移除左侧元素
                if (i >= k) {
                    current_sum -= nums[i - k];
                }
                // 当窗口形成(即 i >= k-1)时,记录答案
                if (i >= k - 1) {
                    // 结果要求是 int,但计算过程用 long long 防止溢出
                    answer.push_back((int)current_sum);
                }
            }
            return answer;
        }

        // --- 常规情况:k > x ---

        // 1. 初始化第一个窗口 (nums[0...k-1])
        for (int i = 0; i < k; ++i) {
            freq_map[nums[i]]++;
            total_sum += nums[i];
        }

        // 将第一个窗口的频率统计放入 removed_set
        for (auto const& [val, freq] : freq_map) {
            removed_set.insert({freq, val});
        }

        // 第一次平衡,填满 kept_set
        balance();

        // 记录第一个窗口的答案
        // 规则:如果不同元素 <= x,使用 total_sum
        if (freq_map.size() <= x) {
            answer.push_back((int)total_sum);
        } else {
            answer.push_back((int)x_sum);
        }

        // 2. 开始滑动窗口
        // i 是新加入窗口的元素索引
        for (int i = k; i < n; ++i) {
            int val_to_add = nums[i];
            int val_to_remove = nums[i - k];

            // 更新总和
            total_sum = total_sum - val_to_remove + val_to_add;

            // --- 处理移除 ---
            int old_freq_rem = freq_map[val_to_remove];
            int new_freq_rem = old_freq_rem - 1;
            freq_map[val_to_remove] = new_freq_rem;
            if (new_freq_rem == 0) {
                freq_map.erase(val_to_remove); // 移除map中的键
            }
            // 更新 set
            update_sets(val_to_remove, old_freq_rem, new_freq_rem);

            // --- 处理添加 ---
            int old_freq_add = freq_map.count(val_to_add) ? freq_map[val_to_add] : 0;
            int new_freq_add = old_freq_add + 1;
            freq_map[val_to_add] = new_freq_add;
            // 更新 set
            update_sets(val_to_add, old_freq_add, new_freq_add);

            // --- 重新平衡 ---
            balance();

            // 记录当前窗口的答案
            if (freq_map.size() <= x) {
                answer.push_back((int)total_sum);
            } else {
                answer.push_back((int)x_sum);
            }
        }

        return answer;
    }
};