308. Range Sum Query 2D - Mutable
Tags: [matrix], [segment_tree], [left_sum], [pre_sum], [prefix_sum], [binary_indexed_tree], [fenwick_tree]
Com: {g}
Link: https://leetcode.com/problems/range-sum-query-2d-mutable/#/description
Given a 2D matrix matrix, find the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).
[Figure: Range Sum Query 2D (image not included)]
In that figure, the rectangle with the red border is defined by (row1, col1) = (2, 1) and (row2, col2) = (4, 3), which contains sum = 8.
Example:
Given matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5]
]
sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
Note:
The matrix is only modifiable by the update function.
You may assume the number of calls to update and sumRegion function is distributed evenly.
You may assume that row1 ≤ row2 and col1 ≤ col2.
Shorter Solution: Binary Indexed Tree
class NumMatrix(object):
    """2D range-sum structure backed by a 2D binary indexed (Fenwick) tree.

    Callers address cells 0-indexed; the tree is 1-indexed internally, so
    ``bi_tree`` has one extra leading row and column of zeros. ``nums`` keeps
    the current value of every cell so ``update`` can compute the delta.
    """

    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        self.matrix = matrix
        # Initialize every attribute before the early return so that an empty
        # matrix yields a usable (empty) object instead of an AttributeError
        # when sumRegion later inspects self.bi_tree.
        self.num_of_rows = 0
        self.num_of_cols = 0
        self.nums = []
        self.bi_tree = []
        if not matrix or not len(matrix[0]):
            return
        self.num_of_rows = len(matrix)
        self.num_of_cols = len(matrix[0])
        self.nums = [[0] * self.num_of_cols for _ in range(self.num_of_rows)]
        self.bi_tree = [[0] * (self.num_of_cols + 1)
                        for _ in range(self.num_of_rows + 1)]
        # Build by inserting each cell as a point update; nums starts at zero,
        # so each update's diff equals the cell's value.
        for row in range(self.num_of_rows):
            for col in range(self.num_of_cols):
                self.update(row, col, matrix[row][col])

    def update(self, row, col, val):
        """Set cell (row, col) to val.

        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.bi_tree:
            return
        diff = val - self.nums[row][col]
        self.nums[row][col] = val
        bi_row = row + 1
        while bi_row <= self.num_of_rows:
            # The column index must restart for every tree row visited.
            bi_col = col + 1
            while bi_col <= self.num_of_cols:
                self.bi_tree[bi_row][bi_col] += diff
                bi_col += bi_col & (-bi_col)
            bi_row += bi_row & (-bi_row)

    def sumRegion(self, row1, col1, row2, col2):
        """Return the sum of the inclusive rectangle (row1, col1)-(row2, col2).

        Uses inclusion-exclusion over four prefix-rectangle sums; returns 0 for
        an empty matrix.

        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.bi_tree:
            return 0
        return (self.sum_helper(row2, col2)
                - self.sum_helper(row1 - 1, col2)
                - self.sum_helper(row2, col1 - 1)
                + self.sum_helper(row1 - 1, col1 - 1))

    def sum_helper(self, row, col):
        """Return the sum of the prefix rectangle (0, 0)-(row, col).

        A negative row or col yields 0 because the loops below do not run.
        """
        result = 0
        bi_row = row + 1
        while bi_row > 0:
            bi_col = col + 1
            while bi_col > 0:
                result += self.bi_tree[bi_row][bi_col]
                bi_col -= bi_col & (-bi_col)
            bi_row -= bi_row & (-bi_row)
        return result
# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)
Solution: Binary Indexed Tree
class NumMatrix(object):
    """2D range-sum structure backed by a 2D binary indexed (Fenwick) tree.

    Callers address cells 0-indexed; the tree is 1-indexed, so ``bi_tree``
    carries an extra leading row and column of zeros.

    Complexity (R rows, C cols, n = R * C cells):
      - construction: O(n), via the linear-time build below
      - update / sumRegion: O(log R * log C)
    """

    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        # Private copy: update() writes into self.matrix, which must not alias
        # (and silently modify) the caller's lists.
        self.matrix = [list(row) for row in matrix] if matrix else matrix
        if not matrix or not len(matrix[0]):
            return
        self.num_of_rows = len(matrix)
        self.num_of_cols = len(matrix[0])
        self.bi_tree = self.build_binary_indexed_tree(matrix)

    def build_binary_indexed_tree(self, matrix):
        """Return the 1-indexed Fenwick table for ``matrix`` in O(R * C).

        tree[i][j] holds the sum over rows (i - lowbit(i), i] and columns
        (j - lowbit(j), j]. That definition is separable, so the table is built
        with the linear 1D Fenwick construction applied first along every row,
        then along every column.
        """
        rows, cols = self.num_of_rows, self.num_of_cols
        tree = [[0] * (cols + 1) for _ in range(rows + 1)]
        for row in range(rows):
            for col in range(cols):
                tree[row + 1][col + 1] = matrix[row][col]
        # Pass 1: inside each row, add every node into its parent column.
        # Visiting columns in increasing order guarantees a node is complete
        # before it is propagated, since all of its contributors are smaller.
        for bi_row in range(1, rows + 1):
            tree_row = tree[bi_row]
            for bi_col in range(1, cols + 1):
                parent = bi_col + (bi_col & -bi_col)
                if parent <= cols:
                    tree_row[parent] += tree_row[bi_col]
        # Pass 2: the same construction across rows, moving whole rows that
        # are already aggregated along the columns.
        for bi_row in range(1, rows + 1):
            parent = bi_row + (bi_row & -bi_row)
            if parent <= rows:
                source, target = tree[bi_row], tree[parent]
                for bi_col in range(1, cols + 1):
                    target[bi_col] += source[bi_col]
        return tree

    def update(self, row, col, val):
        """Set cell (row, col) to val.

        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return
        diff = val - self.matrix[row][col]
        self.matrix[row][col] = val
        bi_row = row + 1
        while bi_row <= self.num_of_rows:
            # The column index must restart for every tree row visited.
            bi_col = col + 1
            while bi_col <= self.num_of_cols:
                self.bi_tree[bi_row][bi_col] += diff
                bi_col += bi_col & (-bi_col)
            bi_row += bi_row & (-bi_row)

    def sumRegion(self, row1, col1, row2, col2):
        """Return the sum of the inclusive rectangle (row1, col1)-(row2, col2).

        Inclusion-exclusion over four prefix-rectangle sums; 0 for an empty
        matrix.

        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0
        return (self.sum_helper(row2, col2)
                - self.sum_helper(row1 - 1, col2)
                - self.sum_helper(row2, col1 - 1)
                + self.sum_helper(row1 - 1, col1 - 1))

    def sum_helper(self, row, col):
        """Return the sum of the prefix rectangle (0, 0)-(row, col).

        A negative row or col yields 0 because the loops below do not run.
        """
        result = 0
        bi_row = row + 1
        while bi_row > 0:
            bi_col = col + 1
            while bi_col > 0:
                result += self.bi_tree[bi_row][bi_col]
                bi_col -= bi_col & (-bi_col)
            bi_row -= bi_row & (-bi_row)
        return result
# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)
Revelation:
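- Binary indexed tree 是 segment tree 的一种特例,能用 binary indexed tree 解决的问题,都能用 segment tree 解决,但反过来不一定. (Every problem a binary indexed tree solves can also be solved with a segment tree, but not vice versa.)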
- 一定记住每次内层的 while 开始前要重设 bi_col.
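- build tree 时,最简单且一定正确的做法是按 matrix 一个 cell 一个 cell 地做 point update;若要做 O(n) 的线性建树,必须保证每个节点的值累加完整之后再加到它的父节点上. (The simplest always-correct build inserts the matrix cell by cell as point updates. A linear-time build also works, but each node must be complete before it is propagated to its parent.)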
Note:
- Time complexity of initialization = O(n), n is the number of elements of matrix.
- Time complexity of update = O(lgR · lgC), R and C are the numbers of rows and columns of matrix (at most O(lg²n), n is the number of elements of matrix).
- Time complexity of sum range = O(lgR · lgC): four prefix queries, each O(lgR · lgC).
Solution: Left Sum
class NumMatrix(object):
    """2D range-sum structure based on per-row prefix ("left") sums.

    ``self.matrix[r][c]`` holds ``sum(original[r][0..c])``. An update
    rewrites the tail of one row, O(cols). A query adds one prefix difference
    per row, O(rows).
    """

    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        # Build the prefix sums on a copy; the original converted the caller's
        # matrix in place, silently corrupting the input.
        self.matrix = [list(row) for row in matrix] if matrix else matrix
        if not self.matrix or not len(self.matrix[0]):
            return
        for prefix in self.matrix:
            for col in range(1, len(prefix)):
                prefix[col] += prefix[col - 1]

    def update(self, row, col, val):
        """Set cell (row, col) to val by shifting that row's prefix tail.

        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return
        prefix = self.matrix[row]
        # Recover the current cell value from adjacent prefix sums.
        original_val = prefix[col] - (prefix[col - 1] if col > 0 else 0)
        diff = val - original_val
        for c in range(col, len(prefix)):
            prefix[c] += diff

    def sumRegion(self, row1, col1, row2, col2):
        """Return the sum of the inclusive rectangle (row1, col1)-(row2, col2).

        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0
        curr_sum = 0
        for row in range(row1, row2 + 1):
            prefix = self.matrix[row]
            curr_sum += prefix[col2]
            if col1 > 0:
                curr_sum -= prefix[col1 - 1]
        return curr_sum
# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)
Note:
- Time complexity of initialization = O(n * m), n is the number of rows, m is the number of cols.
- Time complexity of update = O(m), m is the number of cols.
- Time complexity of sum range = O(n), n is the number of rows.
Time Limit Exceeded: Segment Tree
class NumMatrix(object):
    """2D range-sum structure backed by a quadtree-style segment tree.

    Every node covers an inclusive rectangle ``[start, end]`` of cells, where
    ``start`` and ``end`` are ``(row, col)`` tuples, and stores that
    rectangle's sum. A node is split at its midpoint into four quadrants kept
    in ``children`` as [upper-left, upper-right, lower-left, lower-right]. A
    quadrant that contains no cells stays ``None``.
    """

    class SegmentTreeNode(object):
        def __init__(self, start, end):
            # start / end: inclusive (row, col) corners of the covered rectangle
            self.start = start
            self.end = end
            self.sum = 0
            # [upper-left, upper-right, lower-left, lower-right]
            self.children = [None] * 4

    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        # Private copy: update() writes into self.matrix and must not modify
        # the caller's lists.
        self.matrix = [list(row) for row in matrix] if matrix else matrix
        if not self.matrix or not len(self.matrix[0]):
            return
        self.root = self.build_segment_tree(
            self.matrix, (0, 0), (len(self.matrix) - 1, len(self.matrix[0]) - 1))

    @staticmethod
    def _mid(start, end):
        """Return the (row, col) split point of the rectangle [start, end].

        Floor division keeps the coordinates integral on Python 2 and 3 alike.
        With plain ``/``, Python 3 would produce floats and indexing would fail.
        """
        return (start[0] + (end[0] - start[0]) // 2,
                start[1] + (end[1] - start[1]) // 2)

    def build_segment_tree(self, matrix, start, end):
        """Build the subtree covering [start, end]; return None if it is empty."""
        if start[0] > end[0] or start[1] > end[1]:
            return None
        root = self.SegmentTreeNode(start, end)
        if start == end:
            root.sum = matrix[start[0]][start[1]]
            return root
        mid = self._mid(start, end)
        quadrants = (
            self.build_segment_tree(matrix, start, mid),
            self.build_segment_tree(matrix, (start[0], mid[1] + 1), (mid[0], end[1])),
            self.build_segment_tree(matrix, (mid[0] + 1, start[1]), (end[0], mid[1])),
            self.build_segment_tree(matrix, (mid[0] + 1, mid[1] + 1), end),
        )
        for index, child in enumerate(quadrants):
            if child is not None:
                root.children[index] = child
                root.sum += child.sum
        return root

    def update(self, row, col, val):
        """Set cell (row, col) to val.

        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return
        diff = val - self.matrix[row][col]
        self.matrix[row][col] = val
        self.update_helper(self.root, row, col, diff)

    def update_helper(self, root, row, col, diff):
        """Add diff to every node on the path from root down to leaf (row, col)."""
        if root is None:
            return
        root.sum += diff
        if root.start == root.end == (row, col):
            return
        area_index = self.get_area_index(root.start, root.end, row, col)
        if area_index < 0:
            raise ValueError('Fail to update')
        self.update_helper(root.children[area_index], row, col, diff)

    def get_area_index(self, start, end, row, col):
        """Return the quadrant (0-3) of [start, end] containing (row, col), or -1."""
        mid = self._mid(start, end)
        if start[0] <= row <= mid[0] and start[1] <= col <= mid[1]:
            return 0
        elif start[0] <= row <= mid[0] and mid[1] + 1 <= col <= end[1]:
            return 1
        elif mid[0] + 1 <= row <= end[0] and start[1] <= col <= mid[1]:
            return 2
        elif mid[0] + 1 <= row <= end[0] and mid[1] + 1 <= col <= end[1]:
            return 3
        else:
            return -1

    def sumRegion(self, row1, col1, row2, col2):
        """Return the sum of the inclusive rectangle (row1, col1)-(row2, col2).

        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0
        return self.sum_range_helper(self.root, (row1, col1), (row2, col2))

    def sum_range_helper(self, root, start, end):
        """Sum the query rectangle [start, end], which must lie inside root's area.

        The query is split along root's midpoint into the quadrants it touches.
        Because start is the upper-left corner and end the lower-right, the
        only possible (start, end) quadrant pairs are: equal, (0, 1), (0, 2),
        (1, 3), (2, 3) and (0, 3).
        """
        if root.start == start and root.end == end:
            return root.sum
        area_index_1 = self.get_area_index(root.start, root.end, start[0], start[1])
        area_index_2 = self.get_area_index(root.start, root.end, end[0], end[1])
        if area_index_1 == area_index_2:
            return self.sum_range_helper(root.children[area_index_1], start, end)
        mid = self._mid(root.start, root.end)
        row1, col1 = start
        row2, col2 = end
        if area_index_1 == 0 and area_index_2 == 3:
            # the query area crosses all four quadrants
            return (self.sum_range_helper(root.children[0], (row1, col1), mid)
                    + self.sum_range_helper(root.children[1], (row1, mid[1] + 1), (mid[0], col2))
                    + self.sum_range_helper(root.children[2], (mid[0] + 1, col1), (row2, mid[1]))
                    + self.sum_range_helper(root.children[3], (mid[0] + 1, mid[1] + 1), (row2, col2)))
        elif area_index_1 == 0 and area_index_2 == 2:
            # the query area crosses the two quadrants on the left side
            return (self.sum_range_helper(root.children[0], (row1, col1), (mid[0], col2))
                    + self.sum_range_helper(root.children[2], (mid[0] + 1, col1), (row2, col2)))
        elif area_index_1 == 1 and area_index_2 == 3:
            # the query area crosses the two quadrants on the right side
            return (self.sum_range_helper(root.children[1], (row1, col1), (mid[0], col2))
                    + self.sum_range_helper(root.children[3], (mid[0] + 1, col1), (row2, col2)))
        elif area_index_1 == 0 and area_index_2 == 1:
            # the query area crosses the two quadrants on top
            return (self.sum_range_helper(root.children[0], (row1, col1), (row2, mid[1]))
                    + self.sum_range_helper(root.children[1], (row1, mid[1] + 1), (row2, col2)))
        else:
            # the query area crosses the two quadrants on the bottom
            return (self.sum_range_helper(root.children[2], (row1, col1), (row2, mid[1]))
                    + self.sum_range_helper(root.children[3], (row1, mid[1] + 1), (row2, col2)))
# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)
Revelation:
The above algorithm is correct, the same idea in Java can pass Leetcode.
For range querying, we can think about using a segment tree to solve the problem, especially when we also need to update the values.
When we update the tree, do not forget to update the matrix itself.
When we update the tree, we just need to pass the diff (diff = new_val - old_val).
Note:
- Time complexity of initialization = O(n), n is the number of all cells of the given matrix, because there are at most 2*n - 1 nodes in the built segment tree (every internal node has at least two children).
- Time complexity of update = O(lgn), n is the number of all cells of the given matrix.
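- Time complexity of sum range is not O(lgn) in the worst case for this quadtree-style tree, where n is the number of all cells of the given matrix. A query covering one full row of a square matrix never matches an internal node exactly, so it visits roughly 2√n nodes; the worst case therefore grows at least like √n. The O(lgn) argument for the 1D segment tree (problem 307) does not carry over directly: https://zhaonanli.gitbooks.io/leetcode/content/307range-sum-query-mutable.html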