308. Range Sum Query 2D - Mutable

Tags: [matrix], [segment_tree], [left_sum], [pre_sum], [prefix_sum], [binary_indexed_tree], [fenwick_tree]

Com: {g}

Link: https://leetcode.com/problems/range-sum-query-2d-mutable/#/description

Given a 2D matrix matrix, find the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).

[Figure: Range Sum Query 2D]
The above rectangle (with the red border) is defined by (row1, col1) = (2, 1) and (row2, col2) = (4, 3), which contains sum = 8.

Example:
Given matrix = [
  [3, 0, 1, 4, 2],
  [5, 6, 3, 2, 1],
  [1, 2, 0, 1, 5],
  [4, 1, 0, 1, 7],
  [1, 0, 3, 0, 5]
]

sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
Note:
The matrix is only modifiable by the update function.
You may assume the number of calls to update and sumRegion function is distributed evenly.
You may assume that row1 ≤ row2 and col1 ≤ col2.
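A brute-force baseline makes the expected answers easy to check by hand. Rows 2..4 and cols 1..3 hold [2, 0, 1], [1, 0, 1] and [0, 3, 0], which sum to 3 + 2 + 3 = 8. After update(3, 2, 2) the middle row becomes [1, 2, 1], giving 3 + 4 + 3 = 10. The sketch below is a hypothetical reference class (BruteForceNumMatrix is not part of the problem) with O(1) update and O(R * C) sumRegion, handy for cross-checking the faster solutions.

```python
class BruteForceNumMatrix(object):
    """Hypothetical baseline: O(1) update, O(R * C) sumRegion."""

    def __init__(self, matrix):
        # Copy the rows so updates do not modify the caller's matrix.
        self.matrix = [row[:] for row in matrix]

    def update(self, row, col, val):
        self.matrix[row][col] = val

    def sumRegion(self, row1, col1, row2, col2):
        # Add every cell inside the inclusive rectangle.
        return sum(self.matrix[r][c]
                   for r in range(row1, row2 + 1)
                   for c in range(col1, col2 + 1))


matrix = [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5],
]
obj = BruteForceNumMatrix(matrix)
print(obj.sumRegion(2, 1, 4, 3))  # 8
obj.update(3, 2, 2)
print(obj.sumRegion(2, 1, 4, 3))  # 10
```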

Shorter Solution: Binary Indexed Tree

class NumMatrix(object):
    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        self.matrix = matrix
        if not matrix or not len(matrix[0]):
            return

        self.num_of_rows = len(matrix)
        self.num_of_cols = len(matrix[0])

        self.nums = [[0 for _ in xrange(self.num_of_cols)] for _ in xrange(self.num_of_rows)]
        self.bi_tree = [[0 for _ in xrange(self.num_of_cols + 1)] for _ in xrange(self.num_of_rows + 1)]
        for row in xrange(self.num_of_rows):
            for col in xrange(self.num_of_cols):
                self.update(row, col, matrix[row][col])


    def update(self, row, col, val):
        """
        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return

        diff = val - self.nums[row][col]
        self.nums[row][col] = val

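        # Walk up the row chain; for each row, walk the whole column chain.
        bi_row = row + 1
        while bi_row <= self.num_of_rows:
            bi_col = col + 1  # reset before every inner loop
            while bi_col <= self.num_of_cols:
                self.bi_tree[bi_row][bi_col] += diff
                bi_col += bi_col & (-bi_col)  # next node that also covers col
            bi_row += bi_row & (-bi_row)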


    def sumRegion(self, row1, col1, row2, col2):
        """
        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0

        return self.sum_helper(row2, col2) -\
               self.sum_helper(row1 - 1, col2) -\
               self.sum_helper(row2, col1 - 1) +\
               self.sum_helper(row1 - 1, col1 - 1)

    def sum_helper(self, row, col):
        result = 0

        # Sum of nums[0..row][0..col]; returns 0 when row or col is -1.
        bi_row = row + 1
        while bi_row > 0:
            bi_col = col + 1  # reset before every inner loop
            while bi_col > 0:
                result += self.bi_tree[bi_row][bi_col]
                bi_col -= bi_col & (-bi_col)
            bi_row -= bi_row & (-bi_row)

        return result


# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

Solution: Binary Indexed Tree

class NumMatrix(object):

    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        self.matrix = matrix
        if not matrix or not len(matrix[0]):
            return

        self.num_of_rows = len(matrix)
        self.num_of_cols = len(matrix[0])
        self.bi_tree = self.build_binary_indexed_tree(matrix)


    def build_binary_indexed_tree(self, matrix):
        # Add every cell's value along its update path (the same loops as update).
        tree = [[0 for _ in xrange(self.num_of_cols + 1)] for _ in xrange(self.num_of_rows + 1)]
        for row in xrange(self.num_of_rows):
            for col in xrange(self.num_of_cols):
                bi_row = row + 1
                while bi_row <= self.num_of_rows:
                    bi_col = col + 1
                    while bi_col <= self.num_of_cols:
                        tree[bi_row][bi_col] += matrix[row][col]
                        bi_col += bi_col & (-bi_col)
                    bi_row += bi_row & (-bi_row)

        return tree


    def update(self, row, col, val):
        """
        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return

        diff = val - self.matrix[row][col]
        self.matrix[row][col] = val

        bi_row = row + 1
        while bi_row <= self.num_of_rows:
            bi_col = col + 1
            while bi_col <= self.num_of_cols:
                self.bi_tree[bi_row][bi_col] += diff
                bi_col += bi_col & (-bi_col)
            bi_row += bi_row & (-bi_row)



    def sumRegion(self, row1, col1, row2, col2):
        """
        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0

        return self.sum_helper(row2, col2) -\
               self.sum_helper(row1 - 1, col2) -\
               self.sum_helper(row2, col1 - 1) +\
               self.sum_helper(row1 - 1, col1 - 1)


    def sum_helper(self, row, col):
        result = 0
        bi_row = row + 1
        while bi_row > 0:
            bi_col = col + 1
            while bi_col > 0:
                result += self.bi_tree[bi_row][bi_col]
                bi_col -= bi_col & (-bi_col)
            bi_row -= bi_row & (-bi_row)

        return result


# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

Revelation:

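  • A binary indexed tree can be seen as a special case of a segment tree: every problem a binary indexed tree solves can also be solved with a segment tree, but not necessarily the other way around (answering an arbitrary range from prefix sums, as sum_helper does, relies on subtraction undoing addition).
  • Always remember to reset bi_col before each inner while loop starts.
  • When building the tree by repeated updates, iterate over the matrix cell by cell, so that every cell's value is added along its own update path; that is what keeps the answer correct.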
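Why bi_col has to be reset: the tree nodes a 2D update touches are every pair (i, j) where i is on the row chain (start at row + 1, then i += i & -i) and j is on the column chain built the same way. The hypothetical helpers below (lowbit, update_chain, touched_cells_2d) only record those indices instead of modifying a tree, using the same nested loops as update.

```python
def lowbit(i):
    # Lowest set bit of i, e.g. lowbit(6) == 2 because 6 = 0b110.
    return i & (-i)


def update_chain(i, n):
    """1-based indices visited by an update at index i in a 1D tree of size n."""
    chain = []
    while i <= n:
        chain.append(i)
        i += lowbit(i)
    return chain


def touched_cells_2d(row, col, num_rows, num_cols):
    """(bi_row, bi_col) pairs a 2D update at 0-based (row, col) touches."""
    cells = []
    bi_row = row + 1
    while bi_row <= num_rows:
        bi_col = col + 1  # reset for every row, as in update
        while bi_col <= num_cols:
            cells.append((bi_row, bi_col))
            bi_col += lowbit(bi_col)
        bi_row += lowbit(bi_row)
    return cells


print(update_chain(3, 8))                 # [3, 4, 8]
print(update_chain(1, 4))                 # [1, 2, 4]
print(len(touched_cells_2d(2, 0, 8, 4)))  # 9 = 3 rows x 3 cols
```

If bi_col were set only once before the outer loop, the inner loop would already be exhausted after the first row, so in this example only (3, 1), (3, 2) and (3, 4) would be updated.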

Note:

  • Time complexity of initialization = O(R * C * lgR * lgC), R is the number of rows and C is the number of cols, because every cell is added along an update path of O(lgR * lgC) tree nodes.
  • Time complexity of update = O(lgR * lgC).
  • Time complexity of sum range = O(lgR * lgC), since it makes four prefix-sum queries of O(lgR * lgC) each.
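Building the tree by running the update loops once per cell costs O(lgR * lgC) per cell. If initialization time matters, one commonly used alternative (a different technique from the code above) is the linear-time build: copy each cell into the tree, then push every node's partial sum into its parent index i + (i & -i), first along each row and then along each column. The sketch assumes a non-empty rectangular matrix; build_bit_linear and build_bit_by_updates are hypothetical names, and the second one repeats the cell-by-cell build for comparison.

```python
def lowbit(i):
    # Lowest set bit of i, e.g. lowbit(12) == 4.
    return i & (-i)


def build_bit_linear(matrix):
    """O(R * C) build of a 1-based 2D binary indexed tree."""
    num_rows, num_cols = len(matrix), len(matrix[0])
    tree = [[0] * (num_cols + 1) for _ in range(num_rows + 1)]
    for r in range(1, num_rows + 1):
        for c in range(1, num_cols + 1):
            tree[r][c] = matrix[r - 1][c - 1]

    # Pass 1: within each row, add each finished node into its parent column.
    # Children have smaller indices than their parent, so left-to-right order works.
    for r in range(1, num_rows + 1):
        for c in range(1, num_cols + 1):
            parent = c + lowbit(c)
            if parent <= num_cols:
                tree[r][parent] += tree[r][c]

    # Pass 2: within each column, do the same along the rows.
    for c in range(1, num_cols + 1):
        for r in range(1, num_rows + 1):
            parent = r + lowbit(r)
            if parent <= num_rows:
                tree[parent][c] += tree[r][c]
    return tree


def build_bit_by_updates(matrix):
    """Reference: the O(R * C * lgR * lgC) cell-by-cell build."""
    num_rows, num_cols = len(matrix), len(matrix[0])
    tree = [[0] * (num_cols + 1) for _ in range(num_rows + 1)]
    for row in range(num_rows):
        for col in range(num_cols):
            bi_row = row + 1
            while bi_row <= num_rows:
                bi_col = col + 1
                while bi_col <= num_cols:
                    tree[bi_row][bi_col] += matrix[row][col]
                    bi_col += lowbit(bi_col)
                bi_row += lowbit(bi_row)
    return tree


matrix = [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5],
]
print(build_bit_linear(matrix) == build_bit_by_updates(matrix))  # True
```

The two passes work because the 2D tree is separable: node (i, j) covers a row interval ending at i times a column interval ending at j, so the row and column aggregations can be done one after the other.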

Solution: Left Sum

class NumMatrix(object):
    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        self.matrix = matrix
        if not self.matrix or not len(self.matrix[0]):
            return

        for row in xrange(len(self.matrix)):
            for col in xrange(1, len(self.matrix[0])):
                self.matrix[row][col] += self.matrix[row][col - 1]


    def update(self, row, col, val):
        """
        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return

        original_val = self.matrix[row][col]
        if col - 1 >= 0:
            original_val -= self.matrix[row][col - 1]

        diff = val - original_val
        for c in xrange(col, len(self.matrix[0])):
            self.matrix[row][c] += diff


    def sumRegion(self, row1, col1, row2, col2):
        """
        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0

        curr_sum = 0
        for row in xrange(row1, row2 + 1):
            curr_sum += self.matrix[row][col2]
            if col1 - 1 >= 0:
                curr_sum -= self.matrix[row][col1 - 1]

        return curr_sum



# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

Note:

  • Time complexity of initialization = O(n * m), n is the number of rows, m is the number of cols.
  • Time complexity of update = O(m), m is the number of cols.
  • Time complexity of sum range = O(n), n is the number of rows.
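The left-sum update is easiest to see on a single row. The hypothetical helpers below (build_left_sums, update_left_sums, range_in_row) apply the same arithmetic as __init__, update and sumRegion to row 3 of the example, [4, 1, 0, 1, 7]: the original value is recovered as left[col] - left[col - 1], and only the entries from col to the end of the row change, which is where the O(m) update cost comes from.

```python
def build_left_sums(row):
    """Row prefix ("left") sums: left[c] = row[0] + ... + row[c]."""
    left = row[:]
    for c in range(1, len(left)):
        left[c] += left[c - 1]
    return left


def update_left_sums(left, col, val):
    """Set the original row[col] to val by patching the prefix sums in place."""
    original = left[col] - (left[col - 1] if col > 0 else 0)
    diff = val - original
    for c in range(col, len(left)):
        left[c] += diff


def range_in_row(left, col1, col2):
    """Sum of the original row[col1..col2] in O(1)."""
    return left[col2] - (left[col1 - 1] if col1 > 0 else 0)


left = build_left_sums([4, 1, 0, 1, 7])  # row 3 of the example matrix
print(left)                      # [4, 5, 5, 6, 13]
update_left_sums(left, 2, 2)     # same change as update(3, 2, 2)
print(left)                      # [4, 5, 7, 8, 15]
print(range_in_row(left, 1, 3))  # 4, i.e. 1 + 2 + 1
```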

Time Limit Exceeded: Segment Tree

class NumMatrix(object):

    class SegmentTreeNode(object):
        def __init__(self, start, end):
            # start including (row, col)
            # end including (row, col)

            self.start = start
            self.end = end
            self.sum = 0
            self.children = [None for _ in xrange(4)]


    def __init__(self, matrix):
        """
        :type matrix: List[List[int]]
        """
        self.matrix = matrix
        if not self.matrix or not len(self.matrix[0]):
            return

        self.root = self.build_segment_tree(self.matrix, 
                                            (0, 0), (len(self.matrix) - 1, len(self.matrix[0]) - 1))


    def build_segment_tree(self, matrix, start, end):
        # base case
        if start[0] > end[0] or start[1] > end[1]:
            return None
        if start == end:
            root = self.SegmentTreeNode(start, end)
            root.sum = matrix[start[0]][start[1]]
            return root

        root = self.SegmentTreeNode(start, end)
        mid = (start[0] + (end[0] - start[0]) / 2,
               start[1] + (end[1] - start[1]) / 2)

        upper_left = self.build_segment_tree(matrix, start, mid)
        upper_right = self.build_segment_tree(matrix, (start[0], mid[1] + 1), (mid[0], end[1]))
        lower_left = self.build_segment_tree(matrix, (mid[0] + 1, start[1]), (end[0], mid[1]))
        lower_right = self.build_segment_tree(matrix, (mid[0] + 1, mid[1] + 1), end)

        if upper_left:
            root.children[0] = upper_left
            root.sum += upper_left.sum

        if upper_right:
            root.children[1] = upper_right
            root.sum += upper_right.sum

        if lower_left:
            root.children[2] = lower_left
            root.sum += lower_left.sum

        if lower_right:
            root.children[3] = lower_right
            root.sum += lower_right.sum

        return root


    def update(self, row, col, val):
        """
        :type row: int
        :type col: int
        :type val: int
        :rtype: void
        """
        if not self.matrix or not len(self.matrix[0]):
            return

        diff = val - self.matrix[row][col]
        self.matrix[row][col] = val
        self.update_helper(self.root, row, col, diff)

    def update_helper(self, root, row, col, diff):
        # base case
        if not root:
            return
        if root.start == root.end == (row, col):
            root.sum += diff
            return

        root.sum += diff
        area_index = self.get_area_index(root.start, root.end, row, col)
        if area_index < 0:
            raise ValueError('Fail to update')

        self.update_helper(root.children[area_index], row, col, diff)    


    def get_area_index(self, start, end, row, col):
        mid = (start[0] + (end[0] - start[0]) / 2,
               start[1] + (end[1] - start[1]) / 2)

        if start[0] <= row <= mid[0] and start[1] <= col <= mid[1]:
            return 0
        elif start[0] <= row <= mid[0] and mid[1] + 1 <= col <= end[1]:
            return 1
        elif mid[0] + 1 <= row <= end[0] and start[1] <= col <= mid[1]:
            return 2
        elif mid[0] + 1 <= row <= end[0] and mid[1] + 1 <= col <= end[1]:
            return 3
        else:
            return -1


    def sumRegion(self, row1, col1, row2, col2):
        """
        :type row1: int
        :type col1: int
        :type row2: int
        :type col2: int
        :rtype: int
        """
        if not self.matrix or not len(self.matrix[0]):
            return 0

        return self.sum_range_helper(self.root, (row1, col1), (row2, col2))


    def sum_range_helper(self, root, start, end):
        # base case
        if root.start == start and root.end == end:
            return root.sum

        area_index_1 = self.get_area_index(root.start, root.end, start[0], start[1])
        area_index_2 = self.get_area_index(root.start, root.end, end[0], end[1])

        if area_index_1 == area_index_2:
            return self.sum_range_helper(root.children[area_index_1], start, end)
        else:
            mid = (root.start[0] + (root.end[0] - root.start[0]) / 2,
                   root.start[1] + (root.end[1] - root.start[1]) / 2)

            row1, col1 = start
            row2, col2 = end

            if area_index_1 == 0 and area_index_2 == 3:
                # the query area crosses four areas
                return self.sum_range_helper(root.children[0], (row1, col1), mid) +\
                       self.sum_range_helper(root.children[1], (row1, mid[1] + 1), (mid[0], col2)) +\
                       self.sum_range_helper(root.children[2], (mid[0] + 1, col1), (row2, mid[1])) +\
                       self.sum_range_helper(root.children[3], (mid[0] + 1, mid[1] + 1), (row2, col2))
            elif area_index_1 == 0 and area_index_2 == 2:
                # the query area crosses two areas on the left side.
                return self.sum_range_helper(root.children[0], (row1, col1), (mid[0], col2)) +\
                       self.sum_range_helper(root.children[2], (mid[0] + 1, col1), (row2, col2))
            elif area_index_1 == 1 and area_index_2 == 3:
                # query area crosses two areas on the right side.
                return self.sum_range_helper(root.children[1], (row1, col1), (mid[0], col2)) +\
                       self.sum_range_helper(root.children[3], (mid[0] + 1, col1), (row2, col2))
            elif area_index_1 == 0 and area_index_2 == 1:
                # query area crosses two areas on top.
                return self.sum_range_helper(root.children[0], (row1, col1), (row2, mid[1])) +\
                       self.sum_range_helper(root.children[1], (row1, mid[1] + 1), (row2, col2))
            else:
                # query area crosses two areas on bottom.
                return self.sum_range_helper(root.children[2], (row1, col1), (row2, mid[1])) +\
                       self.sum_range_helper(root.children[3], (row1, mid[1] + 1), (row2, col2))



# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

Revelation:

  • The algorithm above is correct; the same idea implemented in Java passes on LeetCode.

  • For range queries, consider a segment tree, especially when the values also need to be updated.

  • When updating the tree, do not forget to update the matrix itself.

  • When updating the tree, only the difference needs to be passed down (diff = new_val - old_val).

Note:

  • Time complexity of initialization = O(n), n is the number of all cells of the given matrix, because the built segment tree has at most 2*n - 1 nodes (every internal node has at least two children).
  • Time complexity of update = O(lgn), n is the number of all cells of the given matrix.
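  • Time complexity of sum range: the O(lgn) bound of the 1D version (explanation: https://zhaonanli.gitbooks.io/leetcode/content/307range-sum-query-mutable.html) does not carry over to 2D. A query one row tall that spans every column of a √n x √n matrix visits Θ(√n) nodes, which may be part of why this Python version exceeds the time limit.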
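To make the node-count and query-cost bounds concrete, the hypothetical helpers below (quadrants, count_nodes, count_visits) replay the same midpoint split as build_segment_tree and the same query decomposition as sum_range_helper on index ranges only, counting nodes instead of summing values. For a 16 x 16 matrix (n = 256) the build creates 341 nodes, a single-cell query visits one node per level, and a query covering just the top row visits 31 = 2 * 16 - 1 nodes, a count that grows linearly with the side length.

```python
def quadrants(r0, c0, r1, c1):
    """Non-empty quadrants of rows r0..r1, cols c0..c1 (inclusive), split at the midpoint."""
    mid_r, mid_c = (r0 + r1) // 2, (c0 + c1) // 2
    candidates = [
        (r0, c0, mid_r, mid_c),          # upper left
        (r0, mid_c + 1, mid_r, c1),      # upper right
        (mid_r + 1, c0, r1, mid_c),      # lower left
        (mid_r + 1, mid_c + 1, r1, c1),  # lower right
    ]
    # A quadrant is empty when that dimension already has size 1.
    return [q for q in candidates if q[0] <= q[2] and q[1] <= q[3]]


def count_nodes(r0, c0, r1, c1):
    """Number of nodes the quadtree build creates for this range."""
    if (r0, c0) == (r1, c1):
        return 1
    return 1 + sum(count_nodes(*q) for q in quadrants(r0, c0, r1, c1))


def count_visits(r0, c0, r1, c1, qr0, qc0, qr1, qc1):
    """Nodes visited when querying rows qr0..qr1, cols qc0..qc1 (already inside the node)."""
    if (r0, c0, r1, c1) == (qr0, qc0, qr1, qc1):
        return 1  # node range equals the query: its stored sum is used directly
    visits = 1
    for a0, b0, a1, b1 in quadrants(r0, c0, r1, c1):
        # Clip the query to this quadrant; skip quadrants it does not touch.
        lo_r, lo_c = max(a0, qr0), max(b0, qc0)
        hi_r, hi_c = min(a1, qr1), min(b1, qc1)
        if lo_r <= hi_r and lo_c <= hi_c:
            visits += count_visits(a0, b0, a1, b1, lo_r, lo_c, hi_r, hi_c)
    return visits


side = 16
print(count_nodes(0, 0, side - 1, side - 1))                      # 341
print(count_visits(0, 0, side - 1, side - 1, 5, 5, 5, 5))         # 5
print(count_visits(0, 0, side - 1, side - 1, 0, 0, 0, side - 1))  # 31
```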
