307. Range Sum Query - Mutable
Tags: [tree], [segment_tree], [data_structure], [dp]
Link: https://leetcode.com/problems/range-sum-query-mutable/#/description
Given an integer array nums, find the sum of the elements between indices i and j (i≤j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
Note:
- The array is only modifiable by the update function.
- You may assume the calls to update and sumRange are distributed evenly.
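A brute-force reference is useful for checking the faster solutions below. This is a minimal sketch that assumes a plain Python list. The class name `NaiveNumArray` is hypothetical. It replays the example above, with O(1) update and O(j - i + 1) sumRange.

```python
class NaiveNumArray(object):
    """Hypothetical brute-force reference: O(1) update, O(j - i + 1) sumRange."""

    def __init__(self, nums):
        # Copy so later updates do not mutate the caller's list.
        self.nums = list(nums)

    def update(self, i, val):
        self.nums[i] = val

    def sumRange(self, i, j):
        # Sum the inclusive slice nums[i..j].
        return sum(self.nums[i:j + 1])


obj = NaiveNumArray([1, 3, 5])
print(obj.sumRange(0, 2))  # 9
obj.update(1, 2)
print(obj.sumRange(0, 2))  # 8
```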
Solution: Segment Tree
```python
class NumArray(object):

    class SegmentTreeNode(object):
        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.sum = 0
            self.left = None
            self.right = None

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.root = self.build_segment_tree(nums, 0, len(nums) - 1)

    def build_segment_tree(self, nums, start, end):
        # base case: empty range
        if start > end:
            return None
        # base case: a leaf holds a single element
        if start == end:
            root = self.SegmentTreeNode(start, end)
            root.sum = nums[start]
            return root
        root = self.SegmentTreeNode(start, end)
        mid = start + (end - start) // 2
        root.left = self.build_segment_tree(nums, start, mid)
        root.right = self.build_segment_tree(nums, mid + 1, end)
        if root.left:
            root.sum += root.left.sum
        if root.right:
            root.sum += root.right.sum
        return root

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        self.update_segment_tree(self.root, i, val)

    def update_segment_tree(self, root, position, val):
        # base case
        if not root:
            return
        if position == root.start == root.end:
            root.sum = val
            return
        mid = root.start + (root.end - root.start) // 2
        if position <= mid:
            # update the left subtree
            self.update_segment_tree(root.left, position, val)
        else:
            # update the right subtree
            self.update_segment_tree(root.right, position, val)
        # recompute this node's sum from its children
        root.sum = 0
        if root.left:
            root.sum += root.left.sum
        if root.right:
            root.sum += root.right.sum

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        return self.sum_range_helper(self.root, i, j)

    def sum_range_helper(self, root, start, end):
        # base case
        if start > end:
            raise ValueError('start should be less than or equal to end')
        if start == root.start and end == root.end:
            return root.sum
        mid = root.start + (root.end - root.start) // 2
        if end <= mid:
            # query lies entirely in root.left
            return self.sum_range_helper(root.left, start, end)
        elif start >= mid + 1:
            # query lies entirely in root.right
            return self.sum_range_helper(root.right, start, end)
        else:
            # query spans both children: split at mid and add the results
            return (self.sum_range_helper(root.left, start, mid)
                    + self.sum_range_helper(root.right, mid + 1, end))


# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
```
Revelation:
- When we need to answer queries over ranges of values (including two-dimensional or multi-dimensional ranges), consider building a segment tree.
- mid = start + (end - start) / 2 and mid = (start + end) / 2 compute the same value, but the first form is safer. In a fixed-width integer type, start + end can overflow when both indices are large. For non-negative start ≤ end, end - start cannot overflow and start + (end - start) / 2 never exceeds end, so for valid array indices the first form avoids overflow entirely. Python integers are arbitrary precision, so this only matters in languages such as Java or C++; in Python 3, write // for integer division.
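Python integers never overflow, so the overflow difference only shows up in fixed-width arithmetic. The following minimal sketch simulates Java-style 32-bit wrap-around in Python. Note that in C or C++, signed overflow is undefined behavior rather than a guaranteed wrap. The helper `to_int32` and the sample indices are illustrative assumptions.

```python
def to_int32(x):
    """Wrap a Python int into the 32-bit signed range, mimicking Java int wrap-around."""
    return (x + 2**31) % 2**32 - 2**31


start = 2**31 - 10  # 2147483638, a large but valid 32-bit index
end = 2**31 - 2     # 2147483646

# (start + end) / 2: the sum exceeds 2**31 - 1 and wraps to a negative number.
mid_naive = to_int32(start + end) // 2
# start + (end - start) / 2: every intermediate value stays within [0, end].
mid_safe = to_int32(start + to_int32(end - start) // 2)

print(mid_naive)  # -6
print(mid_safe)   # 2147483642
```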
Note:
- Time complexity of initialization = O(n), where n is the number of elements in nums: the tree has 2n - 1 nodes and each one is built exactly once.
- Time complexity of update = O(log n): each call follows a single root-to-leaf path, and the tree height is O(log n).
- Time complexity of sumRange = O(log n). The worst case is a query that straddles the root's midpoint, for example [mid, mid + 1] with mid = start + (end - start) / 2, so at the root we must search both the left and right branches. Below the root, the left-side query is a suffix of the left subtree's range: at each node it either descends into the right child only, or splits into a recursive call on the left child plus an exact match on the right child that returns immediately. The right side is symmetric with a prefix query. So each side walks a single root-to-leaf path, and the time complexity = O(log n + log n) = O(log n). For example, with root range [0, 10] and query range [5, 6], drawing the segment tree shows how the search proceeds.
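To check the O(log n) query bound empirically, the following sketch counts how many nodes a query touches. It uses a hypothetical helper, `count_visits`, that mirrors the branching of `sum_range_helper` but works on ranges only, because node values do not affect which nodes are visited. The O(log n) argument implies at most about 4 visits per tree level.

```python
def count_visits(lo, hi, start, end):
    """Count nodes touched when querying [start, end] from the node covering [lo, hi].

    Mirrors the branching of sum_range_helper; node values are irrelevant here.
    """
    if start == lo and end == hi:
        return 1
    mid = lo + (hi - lo) // 2
    if end <= mid:
        return 1 + count_visits(lo, mid, start, end)
    if start >= mid + 1:
        return 1 + count_visits(mid + 1, hi, start, end)
    return (1 + count_visits(lo, mid, start, mid)
            + count_visits(mid + 1, hi, mid + 1, end))


# The example from the note: root range [0, 10], query [5, 6].
print(count_visits(0, 10, 5, 6))  # 8

# Worst case over every query on n = 128 elements (tree height 7).
n = 128
height = (n - 1).bit_length()
worst = max(count_visits(0, n - 1, i, j) for i in range(n) for j in range(i, n))
print(worst, "<=", 4 * height)
```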
Time Limit Exceeded: DP (prefix and suffix sums)
```python
class NumArray(object):
    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.total_sum = sum(nums)
        # sum_left[k] = nums[0] + ... + nums[k] (prefix sums)
        self.sum_left = self.calculate_sum_left(nums)
        # sum_right[k] = nums[k] + ... + nums[-1] (suffix sums)
        self.sum_right = self.calculate_sum_right(nums)

    def calculate_sum_left(self, nums):
        if not nums:
            return []
        sum_left = [0] * len(nums)
        sum_left[0] = nums[0]
        for i in range(1, len(nums)):
            sum_left[i] = sum_left[i - 1] + nums[i]
        return sum_left

    def calculate_sum_right(self, nums):
        if not nums:
            return []
        sum_right = [0] * len(nums)
        sum_right[-1] = nums[-1]
        for i in range(len(nums) - 2, -1, -1):
            sum_right[i] = sum_right[i + 1] + nums[i]
        return sum_right

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        if i < 0 or i >= len(self.sum_left):
            raise ValueError('the input i is invalid')
        # sum_left[i] + sum_right[i] counts nums[i] twice, so subtracting
        # total_sum recovers the current value of nums[i]
        original_val = (self.sum_left[i] + self.sum_right[i]) - self.total_sum
        # update total_sum
        self.total_sum = self.total_sum - original_val + val
        # every prefix sum that includes index i changes
        for k in range(i, len(self.sum_left)):
            self.sum_left[k] = self.sum_left[k] - original_val + val
        # every suffix sum that includes index i changes
        for k in range(0, i + 1):
            self.sum_right[k] = self.sum_right[k] - original_val + val

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        # total minus everything left of i and everything right of j
        result = self.total_sum
        if i - 1 >= 0:
            result -= self.sum_left[i - 1]
        if j + 1 < len(self.sum_right):
            result -= self.sum_right[j + 1]
        return result


# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
```
Note:
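- Time complexity of initialization = O(n), where n is the number of elements in nums.
- Time complexity of update = O(n), because every prefix sum from index i onward and every suffix sum up to index i must be adjusted.
- Time complexity of sumRange = O(1).

With updates and queries evenly distributed, the O(n) update dominates the running time, which is why this approach exceeds the time limit.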
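The suffix array above is not strictly needed. With a single prefix array padded by a leading 0, sumRange(i, j) = prefix[j + 1] - prefix[i]. Below is a minimal sketch of that variant. The class name `PrefixSumArray` is hypothetical. It keeps the same O(n) update and O(1) query, so it would presumably hit the same time limit, but it needs half the bookkeeping.

```python
class PrefixSumArray(object):
    """Hypothetical prefix-sum variant: O(n) build, O(n) update, O(1) sumRange."""

    def __init__(self, nums):
        self.nums = list(nums)
        # prefix[k] = nums[0] + ... + nums[k - 1], with prefix[0] = 0
        self.prefix = [0] * (len(nums) + 1)
        for k, x in enumerate(self.nums):
            self.prefix[k + 1] = self.prefix[k] + x

    def update(self, i, val):
        delta = val - self.nums[i]
        self.nums[i] = val
        # Every prefix that includes index i shifts by delta: O(n).
        for k in range(i + 1, len(self.prefix)):
            self.prefix[k] += delta

    def sumRange(self, i, j):
        # Sum of nums[i..j] inclusive.
        return self.prefix[j + 1] - self.prefix[i]


obj = PrefixSumArray([1, 3, 5])
print(obj.sumRange(0, 2))  # 9
obj.update(1, 2)
print(obj.sumRange(0, 2))  # 8
print(obj.sumRange(1, 2))  # 7
```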