SUM相关数据结构总结

BIT, Segment Tree, UnionFind

Binary Indexed Tree

class BIT:
    """1-D Binary Indexed (Fenwick) tree: point add and prefix sum.

    Positions are 1-based: tree position ``i`` corresponds to index
    ``i - 1`` of the original 0-based array.
    """

    def __init__(self, n):
        """Create an all-zero tree over positions 1..n."""
        self.size = n
        # bit[0] is unused padding so that positions can start at 1.
        self.bit = [0] * (self.size + 1)

    def update(self, index, val):
        """Add ``val`` at 1-based position ``index``.

        Raises:
            IndexError: if ``index`` is outside ``1..size``. Without this
                check, any ``index <= 0`` loops forever (the lowest-set-bit
                step ``index & -index`` is 0 once index reaches 0), and
                ``index > size`` would be silently ignored.
        """
        if not 1 <= index <= self.size:
            raise IndexError(f"BIT index {index} out of range 1..{self.size}")
        while index <= self.size:
            self.bit[index] += val
            # Climb to the next node whose range also covers this position.
            index += index & -index

    def query(self, index):
        """Return the sum of positions 1..index (0 when ``index <= 0``)."""
        res = 0
        while index > 0:
            res += self.bit[index]
            # Drop the lowest set bit to move to the preceding range.
            index -= index & -index
        return res

#2D
class BIT:
    """2-D Binary Indexed (Fenwick) tree: point add and prefix-rectangle sum.

    Rows are 1..m and columns are 1..n (1-based, like the 1-D tree).
    """

    def __init__(self, m, n):
        """Create an all-zero tree with ``m`` rows and ``n`` columns."""
        self.m = m
        self.n = n
        # Row 0 and column 0 are unused padding for 1-based indexing.
        self.bit = [[0] * (self.n + 1) for _ in range(self.m + 1)]

    def update(self, r, c, val):
        """Add ``val`` at cell (r, c), both 1-based.

        Raises:
            IndexError: if (r, c) lies outside rows 1..m / columns 1..n.
                Without this check, ``r <= 0`` or ``c <= 0`` loops forever
                (the step ``i & -i`` is 0 at 0), and cells past the end would
                be silently ignored.
        """
        if not (1 <= r <= self.m and 1 <= c <= self.n):
            raise IndexError(
                f"BIT cell ({r}, {c}) out of range 1..{self.m} x 1..{self.n}"
            )
        i = r
        while i <= self.m:
            j = c
            while j <= self.n:
                self.bit[i][j] += val
                j += j & -j
            i += i & -i

    def query(self, r, c):
        """Return the sum of cells (1..r, 1..c); 0 if ``r <= 0`` or ``c <= 0``."""
        res = 0
        i = r
        while i > 0:
            j = c
            while j > 0:
                res += self.bit[i][j]
                j -= j & -j
            i -= i & -i
        return res
        

update / query 注意：index 从 1 开始（1-based）

Segment Tree

class ST:
    """Iterative (bottom-up) segment tree: point add and inclusive range sum.

    Element indices are 0-based. Leaves live at ``tree[size:2*size]``;
    internal node ``p`` stores ``tree[2p] + tree[2p+1]``; ``tree[0]`` is
    unused. Layout for size == 4::

                1
              2   3
             4 5 6 7
    """

    def __init__(self, n):
        """Create an all-zero tree over ``n`` elements."""
        self.size = n
        self.tree = [0] * (2 * self.size)

    def _check_index(self, ind):
        """Raise IndexError unless ``0 <= ind < size``.

        Without this, a negative index would be offset onto an internal node
        (e.g. -1 -> tree[size-1]) and silently corrupt the sums.
        """
        if not 0 <= ind < self.size:
            raise IndexError(f"ST index {ind} out of range 0..{self.size - 1}")

    def update(self, ind, val):
        """Add ``val`` to element ``ind`` (0-based) and refresh its ancestors."""
        self._check_index(ind)
        # Original values are stored in the leaves, offset by size.
        ind += self.size
        self.tree[ind] += val
        ind //= 2
        while ind > 0:
            self.tree[ind] = self.tree[2 * ind] + self.tree[2 * ind + 1]
            ind //= 2

    def query(self, left, right):
        """Return the sum of elements ``left..right`` inclusive (0-based).

        An empty range (``left > right``) returns 0.

        Raises:
            IndexError: if a bound of a non-empty range is outside 0..size-1.
        """
        if left > right:
            return 0
        self._check_index(left)
        self._check_index(right)
        left += self.size
        right += self.size
        res = 0
        while left <= right:
            # An even right index is a left child: its parent would
            # overshoot the range, so take it now and step inward.
            if right % 2 == 0:
                res += self.tree[right]
                right -= 1
            # An odd left index is a right child: take it and step inward.
            if left % 2 == 1:
                res += self.tree[left]
                left += 1
            left //= 2
            right //= 2
        return res

原始值保存在叶子节点中，所以 ind 需要加上 self.size 作为偏移。

UnionFind

     class UnionFind:
        """Disjoint-set forest with path compression and union by rank.

        Attributes (as maintained below):
            parent: parent[x] is x's parent; a root is its own parent.
            rank: per-root rank used to decide which root absorbs the other.
            n: number of disjoint sets; starts at n, decremented per merge.
        """

        def __init__(self, n):
            self.parent = list(range(n))
            self.rank = [0] * n
            self.n = n

        def find(self, x):
            """Return the root of x's set, pointing every node on the path at it."""
            root = x
            while self.parent[root] != root:
                root = self.parent[root]
            # Second pass: compress the path so later lookups are direct.
            node = x
            while node != root:
                next_node = self.parent[node]
                self.parent[node] = root
                node = next_node
            return root

        def union(self, x, y):
            """Merge the sets containing x and y; no-op if already joined."""
            root_x, root_y = self.find(x), self.find(y)
            if root_x == root_y:
                return
            self.n -= 1
            # Ensure root_x is the root with the larger-or-equal rank.
            if self.rank[root_x] < self.rank[root_y]:
                root_x, root_y = root_y, root_x
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1