题目
给你一个由正整数组成的 m x n 矩阵 grid。你的任务是判断是否可以通过 一条水平或一条垂直分割线 将矩阵分割成两部分,使得:
- 分割后形成的每个部分都是 非空
的。 - 两个部分中所有元素的和 相等 ,或者总共 最多移除一个单元格 (从其中一个部分中)的情况下可以使它们相等。
- 如果移除某个单元格,剩余部分必须保持 连通 。
如果存在这样的分割,返回 true;否则,返回 false。
注意: 如果一个部分中的每个单元格都可以通过向上、向下、向左或向右移动到达同一部分中的其他单元格,则认为这一部分是 连通 的。
示例 1:
输入: grid = [[1,4],[2,3]]
输出: true
解释:

- 在第 0 行和第 1 行之间进行水平分割,结果两部分的元素和为
1 + 4 = 5和2 + 3 = 5,相等。因此答案是true。
示例 2:
输入: grid = [[1,2],[3,4]]
输出: true
解释:

- 在第 0 列和第 1 列之间进行垂直分割,结果两部分的元素和为
1 + 3 = 4和2 + 4 = 6。 - 通过从右侧部分移除
2(6 - 2 = 4),两部分的元素和相等,并且两部分保持连通。因此答案是true。
示例 3:
输入: grid = [[1,2,4],[2,3,5]]
输出: false
解释:

- 在第 0 行和第 1 行之间进行水平分割,结果两部分的元素和为
1 + 2 + 4 = 7和2 + 3 + 5 = 10。 - 通过从底部部分移除
3(10 - 3 = 7),两部分的元素和相等,但底部部分不再连通(分裂为[2]和[5])。因此答案是false。
示例 4:
输入: grid = [[4,1,8],[3,2,6]]
输出: false
解释:
不存在有效的分割,因此答案是 false。
提示:
1 <= m == grid.length <= 10^51 <= n == grid[i].length <= 10^52 <= m * n <= 10^51 <= grid[i][j] <= 10^5
解题思路
如果我们在每次切分矩阵时都去重新算和、重新找元素,时间开销会极大。因此,解题的第一步是“空间换时间”的全局扫描。
在仅遍历一次矩阵的过程中,我们建立两套核心数据:
线性和计算(Prefix Sum):记录每一行的元素总和
row_sum,以及每一列的元素总和col_sum。这样在后续“切蛋糕”时,只需通过简单的加减法,就能在 $O(1)$ 时间内得出任意切分线下方的两块区域的总和。全局极值映射(Bounding Box):对于矩阵中出现的每一个数值 $x$,我们记录下它出现的最上边界 (
min_row)、最下边界 (max_row)、最左边界 (min_col) 和最右边界 (max_col)。- 物理意义:我们不需要知道 $x$ 具体出现在哪些坐标,我们只需要知道它“势力范围”的极限在哪里。
题目要求移除一个元素后,剩余部分必须连通。这个条件乍一看需要用到复杂的图论算法(比如求割点),但其实可以通过几何直觉进行降维打击:
大块区域(行数 $\ge 2$ 且列数 $\ge 2$):在一个宽度和高度都至少为 2 的网格中,你挖掉任意一个格子,剩下的格子总能连在一起(哪怕挖掉正中心,周围依然是一圈连通的)。因此,在这类区域中,连通性约束可以直接忽略,问题退化为“区域内是否存在目标值”。
退化区域(单行或单列):如果切割后某一部分只有一条线($1 \times C$ 或 $R \times 1$),从中间抽走元素会把线切断。因此,这类区域只能移除两端的元素(即矩阵四个角附近的元素)。
有了前面的铺垫,我们就可以开始逐行(或逐列)模拟切割了。
假设我们在第 cut_row 行下方切了一刀(水平分割),上半部分的和为 top_sum,下半部分的和为 bottom_sum。 如果 top_sum > bottom_sum,我们需要从上半部分中精准剔除一个值为 d = top_sum - bottom_sum 的格子。
怎么知道上半部分到底有没有 d? 不用遍历查找,直接利用第一阶段的极值数据进行逻辑判断:
查字典:数值
d必须在全局极值表里存在(即整个矩阵得有这个数)。看上限:直接判断
min_row[d] <= cut_row是否成立。- 逻辑推演:
min_row[d]记录的是数值d在全矩阵中最靠上的一次出场位置。如果这个位置都在切分线cut_row的上方(或刚好在切分线上),那就说明上半部分绝对包含至少一个d!
- 逻辑推演:
过边界:最后套用第二阶段的连通性规则,如果上半部分是一条单线,就特判一下两端的值是不是
d。如果不是单线,直接判定可以安全移除,返回true。
具体代码
class Solution:
def canPartitionGrid(self, grid: List[List[int]]) -> bool:
m, n = len(grid), len(grid[0])
total = 0
row_sum = [0] * m
col_sum = [0] * n
# 记录每个值出现的最小/最大行列
min_row = {}
max_row = {}
min_col = {}
max_col = {}
for i in range(m):
for j in range(n):
x = grid[i][j]
total += x
row_sum[i] += x
col_sum[j] += x
if x not in min_row:
min_row[x] = max_row[x] = i
min_col[x] = max_col[x] = j
else:
if i < min_row[x]:
min_row[x] = i
if i > max_row[x]:
max_row[x] = i
if j < min_col[x]:
min_col[x] = j
if j > max_col[x]:
max_col[x] = j
# 横切:判断是否能从上半部分删一个值为 d 的格子
def can_remove_from_top(cut_row: int, d: int) -> bool:
h, w = cut_row + 1, n
if h == 1 and w == 1:
return False
if h == 1: # 只有一行,只能删两端
return grid[0][0] == d or grid[0][n - 1] == d
if w == 1: # 只有一列,只能删两端
return grid[0][0] == d or grid[cut_row][0] == d
return d in min_row and min_row[d] <= cut_row
# 横切:判断是否能从下半部分删一个值为 d 的格子
def can_remove_from_bottom(cut_row: int, d: int) -> bool:
h, w = m - cut_row - 1, n
if h == 1 and w == 1:
return False
if h == 1: # 只有一行,只能删两端
return grid[m - 1][0] == d or grid[m - 1][n - 1] == d
if w == 1: # 只有一列,只能删两端
return grid[cut_row + 1][0] == d or grid[m - 1][0] == d
return d in max_row and max_row[d] >= cut_row + 1
# 竖切:判断是否能从左半部分删一个值为 d 的格子
def can_remove_from_left(cut_col: int, d: int) -> bool:
h, w = m, cut_col + 1
if h == 1 and w == 1:
return False
if h == 1: # 只有一行,只能删两端
return grid[0][0] == d or grid[0][cut_col] == d
if w == 1: # 只有一列,只能删两端
return grid[0][0] == d or grid[m - 1][0] == d
return d in min_col and min_col[d] <= cut_col
# 竖切:判断是否能从右半部分删一个值为 d 的格子
def can_remove_from_right(cut_col: int, d: int) -> bool:
h, w = m, n - cut_col - 1
if h == 1 and w == 1:
return False
if h == 1: # 只有一行,只能删两端
return grid[0][cut_col + 1] == d or grid[0][n - 1] == d
if w == 1: # 只有一列,只能删两端
return grid[0][n - 1] == d or grid[m - 1][n - 1] == d
return d in max_col and max_col[d] >= cut_col + 1
# 枚举横切
top_sum = 0
for i in range(m - 1):
top_sum += row_sum[i]
bottom_sum = total - top_sum
if top_sum == bottom_sum:
return True
if top_sum > bottom_sum:
d = top_sum - bottom_sum
if can_remove_from_top(i, d):
return True
else:
d = bottom_sum - top_sum
if can_remove_from_bottom(i, d):
return True
# 枚举竖切
left_sum = 0
for j in range(n - 1):
left_sum += col_sum[j]
right_sum = total - left_sum
if left_sum == right_sum:
return True
if left_sum > right_sum:
d = left_sum - right_sum
if can_remove_from_left(j, d):
return True
else:
d = right_sum - left_sum
if can_remove_from_right(j, d):
return True
return False