307. Range Sum Query - Mutable

Tags: [tree], [segment_tree], [data_structure], [dp]

Link: https://leetcode.com/problems/range-sum-query-mutable/#/description

Given an integer array nums, find the sum of the elements between indices i and j (i≤j), inclusive.

The update(i, val) function modifies nums by updating the element at index i to val.

Example:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

Note:

  1. The array is only modifiable by the update function.
  2. You may assume the number of calls to the update and sumRange functions is distributed evenly.
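Note 2 matters because it rules out making only one operation fast. For reference, here is a minimal brute-force baseline (the class name NaiveNumArray is hypothetical and not part of the solutions below). Its update is O(1), but every sumRange call scans the whole range, which is O(n) per query.

```python
class NaiveNumArray:
    """Brute-force baseline: store the array as-is and sum on demand."""

    def __init__(self, nums):
        self.nums = list(nums)  # copy so the caller's list is not mutated

    def update(self, i, val):
        self.nums[i] = val  # O(1)

    def sumRange(self, i, j):
        return sum(self.nums[i:j + 1])  # O(j - i + 1), i.e. O(n) in the worst case


obj = NaiveNumArray([1, 3, 5])
print(obj.sumRange(0, 2))  # 9
obj.update(1, 2)
print(obj.sumRange(0, 2))  # 8
```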

Solution: Segment Tree

class NumArray(object):

    class SegmentTreeNode(object):
        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.sum = 0
            self.left = None
            self.right = None

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.root = self.build_segment_tree(nums, 0, len(nums) - 1)

    def build_segment_tree(self, nums, start, end):
        # base case
        if start > end:
            return None
        if start == end:
            root = self.SegmentTreeNode(start, end)
            root.sum = nums[start]
            return root

        root = self.SegmentTreeNode(start, end)

        mid = start + (end - start) // 2  # integer division, so mid is a valid index
        root.left = self.build_segment_tree(nums, start, mid)
        root.right = self.build_segment_tree(nums, mid + 1, end)

        if root.left:
            root.sum += root.left.sum
        if root.right:
            root.sum += root.right.sum

        return root

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        self.update_segment_tree(self.root, i, val)

    def update_segment_tree(self, root, position, val):
        # base case.
        if not root:
            return
        if position == root.start == root.end:
            root.sum = val
            return

        mid = root.start + (root.end - root.start) // 2
        if position <= mid:
            # update the left subtree
            self.update_segment_tree(root.left, position, val)
        else:
            # update the right subtree
            self.update_segment_tree(root.right, position, val)

        root.sum = 0
        if root.left:
            root.sum += root.left.sum
        if root.right:
            root.sum += root.right.sum

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        return self.sum_range_helper(self.root, i, j)

    def sum_range_helper(self, root, start, end):
        # base case
        if start > end:
            raise ValueError('the start should be less than or equal to end')
        if start == root.start and end == root.end:
            return root.sum

        mid = root.start + (root.end - root.start) // 2
        if end <= mid:
            # Query in the root.left
            return self.sum_range_helper(root.left, start, end)
        elif start >= mid + 1:
            # Query in the root.right
            return self.sum_range_helper(root.right, start, end)
        else:
            # Query both in root.left and root.right, then sum the result together
            return self.sum_range_helper(root.left, start, mid) + self.sum_range_helper(root.right, mid + 1, end)

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
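For comparison, here is a minimal Python 3 sketch of a different segment-tree layout. It is an assumption for illustration, not the author's code; the class name IterativeSegmentTree and the method name sum_range are hypothetical. The tree lives in a flat list of size 2n: the leaves sit at indices n..2n-1, and node k has children 2k and 2k + 1. Build, update, and query are then all loops instead of recursion.

```python
class IterativeSegmentTree:
    """Bottom-up segment tree for range sums stored in a flat list of size 2n."""

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        self.tree[self.n:] = nums  # leaves: tree[n + i] holds nums[i]
        # Fill internal nodes from the bottom up; node k = left child + right child.
        for k in range(self.n - 1, 0, -1):
            self.tree[k] = self.tree[2 * k] + self.tree[2 * k + 1]

    def update(self, i, val):
        k = i + self.n
        self.tree[k] = val
        k //= 2
        # Recompute every ancestor of the changed leaf: O(log n) nodes.
        while k >= 1:
            self.tree[k] = self.tree[2 * k] + self.tree[2 * k + 1]
            k //= 2

    def sum_range(self, i, j):
        # Convert inclusive [i, j] to the half-open leaf interval [l, r).
        l, r = i + self.n, j + self.n + 1
        total = 0
        while l < r:
            if l & 1:  # l is a right child: take it and move past it
                total += self.tree[l]
                l += 1
            if r & 1:  # r - 1 is a left child: take it
                r -= 1
                total += self.tree[r]
            l //= 2
            r //= 2
        return total


st = IterativeSegmentTree([1, 3, 5])
print(st.sum_range(0, 2))  # 9
st.update(1, 2)
print(st.sum_range(0, 2))  # 8
```

Because addition is commutative, this layout works for any n, not only powers of two. It trades the explicit SegmentTreeNode objects for index arithmetic, which avoids recursion and per-node object overhead.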

Revelation:

  • When we need to query aggregate values (such as sums) over ranges, in one dimension or more, while the underlying data is also being updated, consider building a segment tree.
  • mid = start + (end - start) / 2 and mid = (start + end) / 2 give the same result for non-negative indices, but the first form is safer. In languages with fixed-width integers (such as C or Java), start + end can overflow when both values are large. In contrast, start + (end - start) / 2 stays in range whenever 0 ≤ start ≤ end and both fit in the integer type. Python integers are arbitrary-precision, so overflow is not a concern there. In Python 3, however, use // for integer division, because / returns a float.
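Python integers never overflow, so the sketch below simulates signed 32-bit wraparound, which is how Java's int behaves. This simulation is an assumption made for illustration, and the helper names are hypothetical. The point is to show how the two midpoint formulas diverge once start + end exceeds 2^31 - 1.

```python
def to_int32(x):
    """Wrap an arbitrary Python int to a signed 32-bit value (two's-complement wraparound)."""
    return (x + 2**31) % 2**32 - 2**31


def mid_naive(start, end):
    # (start + end) / 2 with the sum computed in 32-bit arithmetic: the sum may wrap
    return to_int32(start + end) // 2


def mid_safe(start, end):
    # start + (end - start) / 2: end - start and the result both fit when 0 <= start <= end
    return start + (end - start) // 2


start, end = 2_000_000_000, 2_100_000_000
print(mid_naive(start, end))  # -97483648 (wrapped, and not even inside [start, end])
print(mid_safe(start, end))   # 2050000000
```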

Note:

  • Time complexity of initialization = O(n), where n is the number of elements in nums. The build creates 2n - 1 nodes (n leaves plus n - 1 internal nodes) and visits each node exactly once.
  • Time complexity of update = O(log n). Each call follows a single root-to-leaf path, choosing one child per level, and the tree height is O(log n).
  • Time complexity of sumRange = O(log n). The query walks down a single path until it reaches a node whose midpoint splits the query range. A representative bad case is a split right at the root: the root range is [start, end] and the query range is [mid, mid + 1], with mid = start + (end - start) / 2. Below the split, the left half-query [i, mid] always ends at the right edge of its current node. So at each level it either matches a whole node, or it recurses into one child (plus at most one fully covered sibling that returns immediately). The right half-query [mid + 1, j] is symmetric. That gives at most two root-to-leaf paths, so the cost is O(log n + log n) = O(log n). For example, with root range [0, 10] and query range [5, 6], drawing the segment tree shows the search descending one path on each side of mid = 5.
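The [0, 10] / [5, 6] example can also be traced without drawing the tree. The helper below is illustrative, and the name trace_query is hypothetical. It follows the same splitting rule as sum_range_helper but stores no sums. It only records the node ranges (lo, hi) that a query visits.

```python
def trace_query(lo, hi, i, j, visited=None):
    """Return the node ranges visited when querying [i, j] from a node covering [lo, hi]."""
    if visited is None:
        visited = []
    visited.append((lo, hi))
    if (i, j) == (lo, hi):  # query matches this node exactly
        return visited
    mid = lo + (hi - lo) // 2
    if j <= mid:  # query lies entirely in the left child
        trace_query(lo, mid, i, j, visited)
    elif i >= mid + 1:  # query lies entirely in the right child
        trace_query(mid + 1, hi, i, j, visited)
    else:  # query straddles mid: split it between both children
        trace_query(lo, mid, i, mid, visited)
        trace_query(mid + 1, hi, mid + 1, j, visited)
    return visited


print(trace_query(0, 10, 5, 6))
# [(0, 10), (0, 5), (3, 5), (5, 5), (6, 10), (6, 8), (6, 7), (6, 6)]
```

After the split at (0, 10), the left side follows (0, 5) → (3, 5) → (5, 5) and the right side follows (6, 10) → (6, 8) → (6, 7) → (6, 6). That is one path per side, as argued above.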

Time Limit Exceeded: DP

class NumArray(object):

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.total_sum = sum(nums)
        self.sum_left = self.calculate_sum_left(nums)
        self.sum_right = self.calculate_sum_right(nums)

    def calculate_sum_left(self, nums):
        if not nums:
            return []

        sum_left = [0 for _ in range(len(nums))]
        sum_left[0] = nums[0]
        for i in range(1, len(nums)):
            sum_left[i] = sum_left[i - 1] + nums[i]

        return sum_left

    def calculate_sum_right(self, nums):
        if not nums:
            return []

        sum_right = [0 for _ in range(len(nums))]
        sum_right[-1] = nums[-1]
        for i in range(len(nums) - 2, -1, -1):
            sum_right[i] = sum_right[i + 1] + nums[i]

        return sum_right

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        if i < 0 or i >= len(self.sum_left):
            raise ValueError('the input i is invalid')

        original_val = (self.sum_left[i] + self.sum_right[i]) - self.total_sum

        # update total_sum
        self.total_sum = self.total_sum - original_val + val

        # update sum_left
        for k in range(i, len(self.sum_left)):
            self.sum_left[k] = self.sum_left[k] - original_val + val

        # update sum_right
        for k in range(0, i + 1):
            self.sum_right[k] = self.sum_right[k] - original_val + val

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        result = self.total_sum
        if i - 1 >= 0:
            result -= self.sum_left[i - 1]
        if j + 1 < len(self.sum_right):
            result -= self.sum_right[j + 1]

        return result

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
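The code above relies on two identities. Written out with T for total_sum and n = len(nums), they are:

$$
\text{sum\_left}[i] = \sum_{k=0}^{i} \text{nums}[k], \qquad \text{sum\_right}[i] = \sum_{k=i}^{n-1} \text{nums}[k]
$$

$$
\text{sum\_left}[i] + \text{sum\_right}[i] = T + \text{nums}[i] \;\Longrightarrow\; \text{nums}[i] = \text{sum\_left}[i] + \text{sum\_right}[i] - T
$$

$$
\text{sumRange}(i, j) = T - \text{sum\_left}[i-1] - \text{sum\_right}[j+1]
$$

nums[i] is counted in both partial sums, so subtracting T leaves exactly the original value that update needs. In the last line, a term whose index falls outside [0, n - 1] (when i = 0 or j = n - 1) is taken as 0, which matches the two if checks in sumRange.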

Note:

  • Time complexity of initialization = O(n), n is the number of elements of the given nums.
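  • Time complexity of update = O(n), n is the number of elements of the given nums. Because update and sumRange calls are equally frequent, this linear update cost is the likely reason the approach exceeded the time limit.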
  • Time complexity of sumRange = O(1).
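The same complexity trade-off holds with a single prefix-sum array in place of separate left and right arrays. Below is a minimal Python 3 sketch of that simplification. It is swapped in for illustration and is not the author's code; the class name PrefixSumNumArray is hypothetical. It keeps a copy of nums, so update can compute the change directly instead of recovering the old value from the partial sums.

```python
class PrefixSumNumArray:
    """Single prefix-sum array: prefix[k] = nums[0] + ... + nums[k - 1]."""

    def __init__(self, nums):
        self.nums = list(nums)
        self.prefix = [0] * (len(nums) + 1)
        for k, x in enumerate(nums):  # O(n) build
            self.prefix[k + 1] = self.prefix[k] + x

    def update(self, i, val):
        delta = val - self.nums[i]
        self.nums[i] = val
        # Every prefix that includes index i shifts by delta: O(n) in the worst case.
        for k in range(i + 1, len(self.prefix)):
            self.prefix[k] += delta

    def sumRange(self, i, j):
        return self.prefix[j + 1] - self.prefix[i]  # O(1)


obj = PrefixSumNumArray([1, 3, 5])
print(obj.sumRange(0, 2))  # 9
obj.update(1, 2)
print(obj.sumRange(0, 2))  # 8
```

The leading 0 in prefix removes the special cases for i = 0 and j = n - 1, but update still touches up to n entries.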
