题目
给你一个整数数组 nums。
特殊三元组 定义为满足以下条件的下标三元组 (i, j, k):
0 <= i < j < k < n,其中n = nums.lengthnums[i] == nums[j] * 2nums[k] == nums[j] * 2
返回数组中 特殊三元组 的总数。
由于答案可能非常大,请返回结果对 109 + 7 取余数后的值。
示例 1:
输入: nums = [6,3,6]
输出: 1
解释:
唯一的特殊三元组是 (i, j, k) = (0, 1, 2),其中:
nums[0] = 6,nums[1] = 3,nums[2] = 6nums[0] = nums[1] * 2 = 3 * 2 = 6nums[2] = nums[1] * 2 = 3 * 2 = 6
示例 2:
输入: nums = [0,1,0,0]
输出: 1
解释:
唯一的特殊三元组是 (i, j, k) = (0, 2, 3),其中:
nums[0] = 0,nums[2] = 0,nums[3] = 0nums[0] = nums[2] * 2 = 0 * 2 = 0nums[3] = nums[2] * 2 = 0 * 2 = 0
示例 3:
输入: nums = [8,4,2,8,4]
输出: 2
解释:
共有两个特殊三元组:
(i, j, k) = (0, 1, 3)nums[0] = 8,nums[1] = 4,nums[3] = 8nums[0] = nums[1] * 2 = 4 * 2 = 8nums[3] = nums[1] * 2 = 4 * 2 = 8
(i, j, k) = (1, 2, 4)nums[1] = 4,nums[2] = 2,nums[4] = 4nums[1] = nums[2] * 2 = 2 * 2 = 4nums[4] = nums[2] * 2 = 2 * 2 = 4
提示:
3 <= n == nums.length <= 10^50 <= nums[i] <= 10^5
解题思路
题目要求的条件是:
下标 $i < j < k$
$nums[i] = nums[j] \times 2$
$nums[k] = nums[j] \times 2$
如果我们遍历每一个元素作为中间的元素 $nums[j]$,问题就简化为:
在 $j$ 的左边有多少个元素等于 $nums[j] \times 2$?记为 $L$。
在 $j$ 的右边有多少个元素等于 $nums[j] \times 2$?记为 $R$。
根据乘法原理,以 $j$ 为中心组成的特殊三元组数量就是 $L \times R$。
最终答案就是所有位置 $j$ 的 $L \times R$ 之和。
为了在 $O(N)$ 时间内完成,我们需要快速知道当前元素左边和右边的数字频率。我们可以使用哈希表(HashMap)或频率数组。
初始化右侧频率表 (right_cnt):
首先遍历一遍整个数组,统计所有数字出现的次数,存入 right_cnt。此时,它代表了“所有元素都在当前指针右侧(包括自身)”的状态。
初始化左侧频率表 (left_cnt):
创建一个空的哈希表或数组,用于在遍历过程中动态记录当前元素左侧的数字频率。
遍历数组(枚举 $j$):
从左到右遍历数组中的每一个元素 $x$(即 $nums[j]$):
更新右侧表:因为当前的 $x$ 已经遍历到了,所以它不再属于“右侧”,将
right_cnt[x]减 1。计算目标值:我们需要找的值是 $target = x \times 2$。
计算贡献:
从
left_cnt中获取 $target$ 的数量(即 $L$)。从
right_cnt中获取 $target$ 的数量(即 $R$)。如果是特殊三元组,则 $count = L \times R$。
将 $count$ 加入总结果(注意取余 $10^9 + 7$)。
更新左侧表:将当前 $x$ 放入
left_cnt中(left_cnt[x]加 1),因为它将成为后续元素的“左侧”。
时间复杂度:$O(N)$。我们只需要遍历数组两次(一次初始化,一次计算)。哈希表或数组的查找是 $O(1)$ 的。
空间复杂度:$O(M)$,其中 $M$ 是数组中数值的范围(本题中 $nums[i] \le 10^5$)。我们需要两个哈希表或数组来存储频率。
具体代码
func specialTriplets(nums []int) int {
const mod = 1e9 + 7
// 初始化两个 map,分别代表左侧和右侧的计数器
rightCnt := make(map[int]int)
leftCnt := make(map[int]int)
// 1. 预处理:先将所有元素都统计到右侧 map 中
for _, x := range nums {
rightCnt[x]++
}
ans := 0
// 2. 遍历数组,枚举中间元素 x
for _, x := range nums {
// Step A: 当前元素 x 正在作为中间节点处理,所以从右侧计数中减去
rightCnt[x]--
target := x * 2
// Step B: 计算贡献
// 如果 target 在 leftCnt 或 rightCnt 中不存在,Go 会返回 0
// 所以直接相乘即可,0 * n = 0,不会影响结果
if c := leftCnt[target] * rightCnt[target]; c > 0 {
ans = (ans + c) % mod
}
// Step C: 当前元素 x 处理完毕,加入左侧计数,作为后续元素的“左边”
leftCnt[x]++
}
return ans
}