Leetcode 2021-12-03

301. Remove Invalid Parentheses (Hard)

Given a string s that contains parentheses and letters, remove the minimum number of invalid parentheses to make the input string valid.
Return all the possible results. You may return the answer in any order.

class Solution(object):
    def removeInvalidParentheses(self, s):
        """
        :type s: str
        :rtype: List[str]
        """
        
        res=[] 
        if not s: return [s]
        visited=set()
        q=[]
        
        q.append(s)
        visited.add(s)
        
      
        
        def isValid(s):
            c=0
            for e in s:
                if e=='(':c+=1
                if e==')':c-=1
                
                if c<0: return False
            return c==0

        found=False
        
        while q:
            s=q.pop(0)
            
            if isValid(s):
                res.append(s)
                found=True
            
            if found: continue
            
            for i in range(len(s)):
                if s[i]!='(' and s[i]!=')': continue
                t=s[:i]+s[i+1:]
                
                if not t in visited:
                    q.append(t)
                    visited.add(t)
        
        return res
    
#         def valid(s):
#             stack=[]
#             ss=[ i for i in s if (i in '()') ]
#             for e in ss:
#                 if e==')':
#                     if (not stack) or (not stack.pop()=='('):
#                         return False
#                 else:
#                     stack.append(e)
#             return not stack

#ANSWER
class Solution:
    
  
    
    def removeInvalidParentheses(self, s: str) -> List[str]:
        #based on the hint
        #1) count how many misplaced left and right parentheses must be removed
        left=0
        right=0
        for i,p in enumerate(s):
            if p=='(':
                left+=1
            elif p==')':
                if left>0:
                    left-=1
                else:
                    right+=1
        
        #2)recursion process 
        res =set()
        def bt(left_rem,right_rem,left,right,tmp,index):
            #left_rem: number of misplaced '(' still to be removed
            #right_rem: number of misplaced ')' still to be removed
            #left is # of ( in current expression tmp
            #right is # of ) in current expression tmp
            #tmp is current expression
            #index is the index of char in original string s
            
            #BASE CASE
            if index==len(s):     
                #print(left_rem,right_rem,left,right,tmp)
                if left_rem==0 and right_rem==0:
                    res.add("".join(tmp))
                return
            
            #discard current
            if (s[index]=='(' and left_rem>0) or (s[index]==')' and right_rem>0):
                bt(left_rem-(s[index]=='('),right_rem-(s[index]==')'),left,right,tmp,index+1)

            #add current
            tmp.append(s[index])

            #if current is not in {()}
            if s[index] not in ["(",")"]:
                bt(left_rem,right_rem,left,right,tmp,index+1)
            elif s[index]=='(':
                #consider an opening bracket
                bt(left_rem,right_rem,left+1,right,tmp,index+1)
            elif s[index]==')' and left>right:
                # consider a closing bracket
                bt(left_rem,right_rem,left,right+1,tmp,index+1)

            tmp.pop()                 

         
        bt(left,right,0,0,[],0)
        return list(res)

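Approach 1 (the fastest, and really cool): BFS-like, using a queue to test every candidate string for validity. Because we want the minimum number of removals, once any string becomes valid after removal, stop adding further strings to the queue. Approach 2: backtracking. First count the misplaced left and right parentheses, then track left_remain, right_remain, left_counter, right_counter, tmp=[], and index. This one is more complicated.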
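The queue approach can also be written level by level. Level k is the set of strings with exactly k parentheses removed, so the first level that contains a valid string gives every answer with the minimum number of removals. This is a sketch under that framing; the function name remove_invalid_parentheses is hypothetical. A per-level set replaces the global visited set, because strings on different levels have different lengths and can never collide.

```python
from typing import List


def remove_invalid_parentheses(s: str) -> List[str]:
    def is_valid(t: str) -> bool:
        balance = 0
        for ch in t:
            if ch == '(':
                balance += 1
            elif ch == ')':
                balance -= 1
                if balance < 0:  # a ')' with no matching '(' before it
                    return False
        return balance == 0

    level = {s}  # strings with the same number of removals
    while level:
        valid = [t for t in level if is_valid(t)]
        if valid:  # first level with a valid string = minimum removals
            return valid
        next_level = set()
        for t in level:
            for i, ch in enumerate(t):
                if ch in '()':  # only parentheses may be removed
                    next_level.add(t[:i] + t[i + 1:])
        level = next_level
    return [""]
```

The loop always ends, because removing every parenthesis leaves a string that is valid.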

302. Smallest Rectangle Enclosing Black Pixels (Hard)

You are given an m x n binary matrix image where 0 represents a white pixel and 1 represents a black pixel.
The black pixels are connected (i.e., there is only one black region). Pixels are connected horizontally and vertically.
Given two integers x and y that represent the location of one of the black pixels, return the area of the smallest (axis-aligned) rectangle that encloses all black pixels.
You must write an algorithm with less than O(mn) runtime complexity.

class Solution:
    minx=float('inf')
    miny=float('inf')
    maxx=float('-inf')
    maxy=float('-inf')
    def minArea(self, image: List[List[str]], x: int, y: int) -> int:
        
        m=len(image)
        n=len(image[0])
        def dfs(x,y):
            self.minx=min(self.minx,x)
            self.miny=min(self.miny,y)
            self.maxx=max(self.maxx,x)
            self.maxy=max(self.maxy,y)
            image[x][y]='#'
            for xx,yy in [(x+1,y),(x-1,y),(x,y+1),(x,y-1)]:
                if xx>=0 and xx<m and yy>=0 and yy<n and image[xx][yy]=='1':
                    dfs(xx,yy)
        dfs(x,y)
        return (self.maxx-self.minx+1)*(self.maxy-self.miny+1)
#ANSWER WAY of WRITING
import bisect  # bisect_left's key= argument requires Python 3.10+

class Solution:
    def minArea(self, image: List[List[str]], x: int, y: int) -> int:
        m, n = len(image), len(image[0])
        def has_one(i, is_row=True):
            return any([(image[i][j] if is_row else image[j][i]) == "1" for j in range(n if is_row else m)])

        top = bisect.bisect_left(range(x+1), 1, key=lambda i: has_one(i))
        bottom = bisect.bisect_left(range(x, m), 1, key=lambda i: not has_one(i))+x
        left = bisect.bisect_left(range(y+1), 1, key=lambda i: has_one(i, False))
        right = bisect.bisect_left(range(y, n), 1, key=lambda i: not has_one(i, False))+y

        return (bottom-top)*(right-left)

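Besides BFS/DFS, which visit every black pixel (O(mn) in the worst case), the answer projects the image onto 1D (rows and columns), uses binary search to find the top/bottom and left/right bounds, and then computes the area.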
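Here is a sketch of the projection plus binary search idea that does not use bisect's key= argument, so it should also run on Python versions before 3.10. first_true and min_area are hypothetical helper names. The search depends on the black region being connected. That makes "this row contains a black pixel" monotone on each side of x, and the same holds for columns on each side of y. Each probe scans a whole row or column, so the total cost is O(n log m + m log n).

```python
from typing import Callable, List


def min_area(image: List[List[str]], x: int, y: int) -> int:
    m, n = len(image), len(image[0])

    def row_has_black(r: int) -> bool:
        return any(image[r][c] == "1" for c in range(n))

    def col_has_black(c: int) -> bool:
        return any(image[r][c] == "1" for r in range(m))

    def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
        # smallest i in [lo, hi) with pred(i) True, or hi if there is none;
        # pred must look like False...False True...True on [lo, hi)
        while lo < hi:
            mid = (lo + hi) // 2
            if pred(mid):
                hi = mid
            else:
                lo = mid + 1
        return lo

    top = first_true(0, x, row_has_black)  # first row containing black
    bottom = first_true(x + 1, m, lambda r: not row_has_black(r))  # first all-white row below x
    left = first_true(0, y, col_has_black)
    right = first_true(y + 1, n, lambda c: not col_has_black(c))
    return (bottom - top) * (right - left)
```

bottom and right are exclusive bounds, which is why the area needs no +1.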

303. Range Sum Query - Immutable (Easy)

class NumArray:

    def __init__(self, nums: List[int]):
        
        # prefix sums built in place (this mutates the caller's nums list)
        for i,n in enumerate(nums):
            if i==0: continue
            nums[i]+=nums[i-1]
        self.cumsum=nums

    def sumRange(self, left: int, right: int) -> int:
        # 1 2 3
        # 1 3 6
        return self.cumsum[right]-self.cumsum[left-1] if left-1>=0 else self.cumsum[right]


# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(left,right)

#OR: the answer's way of writing it
class NumArray(object):

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        
        nums = [0] + nums
        
        for i in range(1,len(nums)):
            nums[i] +=nums[i-1]
        
        self.nums=nums
     

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        return self.nums[j+1]-self.nums[i]

304. Range Sum Query 2D - Immutable (Medium)

Given a 2D matrix matrix, handle multiple queries of the following type:
Calculate the sum of the elements of matrix inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).

class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        # original matrix
        # 1 2 3
        # 4 5 6
        # 7 8 9
        #
        #cumsum matrix
        #             c1
        #        0 0  0    0
        #        0 1  3    6
        # row1   0 5  #12  21
        # row2   0 12 27  #45 
        #                 c2
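        # step 1: prefix sums along each row
        # step 2: prefix sums along each column
        # with P padded by a zero row/column (P[r][c] = sum of matrix[:r][:c]):
        #   sum = P[row2+1][col2+1] - P[row1][col2+1] - P[row2+1][col1] + P[row1][col1]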
        self.matrix = [[0]*(len(matrix[0])+1)]
        for row in matrix:
            self.matrix.append([0]+row)
        
    
        for row in range(1,len(matrix)+1):
            for col in range(1,len(matrix[0])+1):
                self.matrix[row][col]+=self.matrix[row][col-1]
        
        for row in range(1,len(matrix)+1):
            for col in range(1,len(matrix[0])+1):
                self.matrix[row][col]+=self.matrix[row-1][col]
        
        
        
        
    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        return self.matrix[row2+1][col2+1]-self.matrix[row1][col2+1]-self.matrix[row2+1][col1]+self.matrix[row1][col1]
        


# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# param_1 = obj.sumRegion(row1,col1,row2,col2)

305. Number of Islands II (Hard)

You are given an empty 2D binary grid grid of size m x n. The grid represents a map where 0's represent water and 1's represent land. Initially, all the cells of grid are water cells (i.e., all the cells are 0's).
We may perform an add land operation which turns the water at position into a land. You are given an array positions where positions[i] = [ri, ci] is the position (ri, ci) at which we should operate the ith operation.
Return an array of integers answer where answer[i] is the number of islands after turning the cell (ri, ci) into a land.

class Solution:
    
    class unionfind:
        
        def __init__(self,m,n):
            self.m=m
            self.n=n
            self.parent=[-1]*(m*n) 
            self.rank = [0]*(m*n)
            self.counter=0
           
        def isvalid(self,row,col):
            ind=row*self.n+col
            return self.parent[ind]>=0
        
        def setparent(self,row,col):
            ind=row*self.n+col
            if self.parent[ind]==-1:
                self.parent[ind]=ind
                self.counter+=1
            
        def find(self,row,col):
            ind= row*self.n+col
            if self.parent[ind]!=ind:
                new_col = self.parent[ind]%self.n
                new_row = (self.parent[ind]-new_col)//self.n
                self.parent[ind]=self.find(new_row,new_col)
            return self.parent[ind]
            
            
        def union(self,p1,p2):
            x1,y1=p1
            x2,y2=p2
            root1=self.find(x1,y1)
            root2=self.find(x2,y2)
            if root1!=root2:
                self.counter-=1
                if self.rank[root1]>self.rank[root2]:
                    self.parent[root2]=root1
                elif self.rank[root1]<self.rank[root2]:
                    self.parent[root1]=root2
                else:
                    self.parent[root1]=root2
                    self.rank[root2]+=1
                    
        def getcount(self):
            return self.counter
         
                    
                
            
    
    def numIslands2(self, m: int, n: int, positions: List[List[int]]) -> List[int]:
        
       
        res=[]
        
        uf=self.unionfind(m,n)
    
        def get_nei(pos,uf):
            nei = []
            x,y=pos
            for xx,yy in [(x+1,y),(x-1,y),(x,y+1),(x,y-1)]:
                if xx>=0 and xx<m and yy>=0 and yy<n:
                    if uf.isvalid(xx,yy):
                        nei.append((xx,yy))
            return nei
    
        for pos in positions:
            uf.setparent(*pos)
            for nei in get_nei(pos,uf):
                uf.union(nei,pos)
            res.append(uf.getcount())
            
        return res
            

#######################MY ANSWER
class Solution:
    def numIslands2(self, m: int, n: int, positions: List[List[int]]) -> List[int]:
        class UF:
            def __init__(self,size):
                self.rank = [0]*size
                self.parent = [i for i in range(size)]
                self.c=0
            
            def index(self,i,j):
                return i*n+j
            
            def revindex(self,num):
                
                i = num//n
                j = num-i*n
                return i,j
            
            def find(self, i,j ):
                idx = self.index(i,j)
                #print(i,j,idx,self.parent[idx])
                if self.parent[idx]!=idx:
                    self.parent[idx] = self.find(*self.revindex(self.parent[idx]))
                return self.parent[idx]

            def union(self, i,j, ii,jj):
                rootA = self.find(i,j)
                rootB = self.find(ii,jj)
                if rootA!=rootB:
                    if self.rank[rootA]>self.rank[rootB]:
                        self.parent[rootB] = rootA
                    elif self.rank[rootB] > self.rank[rootA]:
                        self.parent[rootA] = rootB
                    else:
                        self.parent[rootA] = rootB
                        self.rank[rootB]+=1
                    self.c-=1
            
            def getc(self):
                return self.c
        

        uf = UF(m*n)
        res = []
        c = 0

        visited = set()
        for i, (x,y) in enumerate(positions):
            if (x,y) in visited:
                res.append(res[-1])
                continue
            
            c+=1
            for xx,yy in [(x+1,y),(x-1,y),(x,y+1),(x,y-1)]:
                if m>xx>=0 and n>yy>=0 and (xx,yy) in visited:
                    uf.union(x,y,xx,yy)
                   
                    
                 
            res.append(c+uf.getc())
            visited.add((x,y))
        return res

         


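I tried a neighbor-list approach, but it could not tell whether a newly added cell merges 2 or 3 different islands. That idea was wrong, and this problem calls for union-find. I still can't write a union-find class fluently, and this one is a variant. When the same cell is added again, setparent must check it: only when self.parent[i] == -1 may it do self.counter += 1. Approach: for each added cell, call setparent (counter++), find all of its valid (land) neighbors, and union with each of them. Every union that actually merges two different roots does counter--.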
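Here is a more compact sketch of the same union-find idea. It assumes cells are keyed by (row, col) tuples in a dict rather than stored in a flat array. Only land cells exist in the structure, so an out-of-bounds neighbor can never match and no bounds check is needed; m and n are then unused. It uses path halving in place of union by rank. num_islands2 is a hypothetical name.

```python
from typing import Dict, List, Tuple

Cell = Tuple[int, int]


def num_islands2(m: int, n: int, positions: List[List[int]]) -> List[int]:
    parent: Dict[Cell, Cell] = {}  # only land cells are keys
    count = 0

    def find(cell: Cell) -> Cell:
        while parent[cell] != cell:
            parent[cell] = parent[parent[cell]]  # path halving
            cell = parent[cell]
        return cell

    res = []
    for r, c in positions:
        cell = (r, c)
        if cell not in parent:  # adding the same cell twice changes nothing
            parent[cell] = cell
            count += 1
            for nb in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
                if nb in parent:  # out-of-bounds cells are never keys
                    root_a, root_b = find(cell), find(nb)
                    if root_a != root_b:
                        parent[root_a] = root_b
                        count -= 1  # each successful union merges two islands
        res.append(count)
    return res
```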

306. Additive Number (Medium)

An additive number is a string whose digits can form an additive sequence.
A valid additive sequence should contain at least three numbers. Except for the first two numbers, each subsequent number in the sequence must be the sum of the preceding two.
Given a string containing only digits, return true if it is an additive number or false otherwise.

class Solution:
    def isAdditiveNumber(self, num: str) -> bool:
        if not num or len(num)<3: 
            return False
        
        mem=dict()
        def helper(num, fixed_a=0, fixed_b=0):
            # fixed_a / fixed_b: required lengths of the first two numbers (0 = try every length).
            # After the first step they are forced to (len(b), len(c)); letting them vary again
            # would accept wrong splits, e.g. "112113" (1,1,2 and then 12,1,13).
            if not num: return True
            key=(num,fixed_a,fixed_b)
            if key in mem: return mem[key]
            Found=False
            for len_a in ([fixed_a] if fixed_a else range(1,len(num))):
                if len_a>1 and num[0]=='0': continue
                for len_b in ([fixed_b] if fixed_b else range(1,len(num))):
                    if len_a+len_b>=len(num):continue
                    if len_b>1 and num[len_a]=='0': continue
                    temp_flag=False
                    a = num[:len_a]
                    b = num[len_a:len_a+len_b]
                    c = str(int(a)+int(b))
                    len_c= len(c)
                    if len_a+len_b+len_c>len(num):continue
                    if len_a+len_b+len_c==len(num) and num[len_a+len_b:len_a+len_b+len_c]==c:
                        return True
                    elif num[len_a+len_b:len_a+len_b+len_c]==c:
                        # the next pair is forced to be (b, c)
                        temp_flag = helper(num[len_a:], len_b, len_c)
                    Found=Found or temp_flag
            mem[key]=Found
            return Found
        
        return helper(num)
        
#ANSWER WAY
import itertools

class Solution:
    def isAdditiveNumber(self, num: str) -> bool:
        n = len(num)
        # i = length of a, j = length of a + b
        for i, j in itertools.combinations(range(1, n), 2):
            a, b = num[:i], num[i:j]
            # skip numbers with leading zeros ("0" itself is fine)
            if a != str(int(a)) or b != str(int(b)):
                continue
            while j < n:
                c = str(int(a) + int(b))
                if not num.startswith(c, j):
                    break
                j += len(c)
                a, b = b, c
            if j == n:
                return True
        return False

#MY SOLUTION
class Solution:
     
    def isAdditiveNumber(self, num: str) -> bool:
        
        res=False
        def helper(A,B,rest):
            nonlocal res
            if res:
                return
            AB = str(int(A)+int(B)) if not (len(A)>1 and A[0]=='0' or len(B)>1 and B[0]=='0') else None
            if AB is None: return 
            if len(AB)<=len(rest) and rest[:len(AB)]==AB:
                A = B
                B = AB
                rest = rest[len(AB):]
                if not rest:
                    res=True
                else:
                    helper(A,B,rest)  
        

        for i in range(1,len(num)):
            A = num[:i]
            for j in range(i+1,len(num)+1):
                B = num[i:j]
                rest = num[j:]
                helper(A,B,rest)

        return res

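Solved it myself with recursion plus memoization; leading zeros are handled by the two continue checks.
The expert answer uses itertools.combinations: i is the length of a and j is the length of a+b. Once a and b are fixed, c can be computed. If the rest of the string does not start with c, break out of the while loop. Otherwise advance j by len(c) and shift (a, b) forward one step for the next check. If j == n when the loop exits, a valid sequence has been found.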
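Here is a variant of the itertools.combinations approach that returns the sequence it found rather than just True/False, which makes it easier to see how i and j split the string. additive_split is a hypothetical name, and it returns None when no split works.

```python
import itertools
from typing import List, Optional


def additive_split(num: str) -> Optional[List[int]]:
    n = len(num)
    # i = length of the first number, j = combined length of the first two numbers
    for i, j in itertools.combinations(range(1, n), 2):
        a, b = num[:i], num[i:j]
        # reject numbers with leading zeros ("0" itself is allowed)
        if a != str(int(a)) or b != str(int(b)):
            continue
        seq = [int(a), int(b)]
        pos = j
        while pos < n:
            c = str(seq[-2] + seq[-1])
            if not num.startswith(c, pos):
                break
            seq.append(int(c))
            pos += len(c)
        if pos == n:  # j < n, so at least one sum was matched: len(seq) >= 3
            return seq
    return None
```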

307. Range Sum Query - Mutable (Medium)

Given an integer array nums, handle multiple queries of the following types:
Update the value of an element in nums.
Calculate the sum of the elements of nums between indices left and right inclusive where left <= right.

 
#METHOD 1  SQRT DECOMPOSITION
import math

class NumArray:

    def __init__(self, nums: List[int]):
        self.nums=nums
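        # block size, roughly sqrt(n); it is also used as the number of block slots in b,
        # which is always enough (ceil(n / self.len) <= self.len)
        self.len = len(nums)//int(math.sqrt(len(nums)))+1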
        self.b=[0]*self.len
        for i,n in enumerate(nums):
            self.b[i//self.len]+=n
         

    def update(self, index: int, val: int) -> None:
        block_index = index//self.len
        self.b[block_index] = self.b[block_index] - self.nums[index]+val
        self.nums[index]=val

    def sumRange(self, left: int, right: int) -> int:
        res=0
        startblock=left//self.len
        endblock=right//self.len
        if startblock==endblock:
            for ind in range(left,right+1):
                res+=self.nums[ind]
        else:
            for ind in range(left,(startblock+1)*self.len):
                res+=self.nums[ind]
            for block_ind in range(startblock+1,endblock):
                res+=self.b[block_ind]
            for ind in range(endblock*self.len,right+1):
                res+=self.nums[ind]
        return res
        
# METHOD 2 SEGMENT TREE
class ST:
    def __init__(self,n):
        self.size = n
        self.tree = [0]*(2*self.size)
    
    def update(self,ind,val):
        #       1
        #     2   3
        #    4 5 6 7
        #  self.tree (index 0 is unused):
        #  @ 1 2 3 4 5 6 7

        # offset the index by size: leaves store the values, internal nodes store the sums of their children
        ind+=self.size
        self.tree[ind] += val
        while ind>0:
            left=ind
            right=ind
            if ind%2==0:
                right+=1
            else:
                left-=1
            if ind//2>0:
                self.tree[ind//2]=self.tree[left]+self.tree[right]
            ind //=2


    def query(self,left,right):
        left+=self.size
        right+=self.size
        res=0
        while left<=right:
            if right%2==0:
                res+=self.tree[right]
                right-=1
            if left%2==1:
                res+=self.tree[left]
                left+=1
            left//=2
            right//=2
        return res
class NumArray:

    def __init__(self, nums: List[int]):
        self.st = ST(len(nums))
        for i,n in enumerate(nums):
            self.st.update(i,n)
        
    def update(self, index: int, val: int) -> None:
        old_val = self.st.query(index,index)
        self.st.update(index,-old_val+val)

    def sumRange(self, left: int, right: int) -> int:
        return self.st.query(left,right)

        

#MY ANSWER
class BIT:
    def __init__(self,n):
        self.size = n
        self.bit = [0]*(self.size+1)
    
    def update(self,index,val):
        while index<=self.size:
            self.bit[index]+=val
            index += index&-index
    def query(self,index):
        res = 0 
        while index>0:
            res+=self.bit[index]
            index-= index&-index
        return res

class NumArray:

    def __init__(self, nums: List[int]):
        self.bit = BIT(len(nums))
        for i,n in enumerate(nums):
            self.bit.update(i+1,n)
        
    def update(self, index: int, val: int) -> None:
        ind = index+1
        old_val = self.bit.query(ind)-self.bit.query(ind-1)
        self.bit.update(ind,-old_val+val)

    def sumRange(self, left: int, right: int) -> int:
        return self.bit.query(right+1)-self.bit.query(left)


# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(index,val)
# param_2 = obj.sumRange(left,right)

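Approach 1: sqrt decomposition, which splits nums into about sqrt(len) blocks and keeps each block's sum. Approach 2: segment tree. I hadn't seen this one before; it is the optimal solution, and its construction is clever enough to be worth memorizing. My answer uses a Fenwick tree (BIT).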
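The BIT in MY ANSWER is filled with n separate update calls, which is O(n log n). A Fenwick tree can also be built in O(n) by pushing each node's value into its parent i + (i & -i) exactly once. The sketch below assumes that linear build. It also keeps a copy of the values so that update() can compute the delta without two prefix queries. FenwickLinearBuild is a hypothetical name.

```python
from typing import List


class FenwickLinearBuild:
    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.vals = list(nums)  # current values, so update() can compute the delta
        self.tree = [0] + list(nums)  # 1-indexed; tree[i] starts as nums[i-1]
        # O(n) build: once tree[i] is complete, add it into its parent i + lowbit(i)
        for i in range(1, self.n + 1):
            parent = i + (i & -i)
            if parent <= self.n:
                self.tree[parent] += self.tree[i]

    def _prefix(self, i: int) -> int:
        # sum of the first i values (vals[0:i])
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total

    def update(self, index: int, val: int) -> None:
        delta = val - self.vals[index]
        self.vals[index] = val
        i = index + 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i

    def sumRange(self, left: int, right: int) -> int:
        return self._prefix(right + 1) - self._prefix(left)
```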

308. Range Sum Query 2D - Mutable (Hard)

The 2D version of Range Sum Query - Mutable: extend the 1D approach to 2D.

class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        self.matrix=matrix
        self.m=len(matrix)
        self.n=len(matrix[0])
        self.trees=[self.buildtree(nums) for nums in matrix]
        
        #for tree in self.trees:
        #    print(tree)
    
    def buildtree(self,nums):
        #   1
        #  2  3
        # 4 5 6 7
        # 1 2 3 4 5 6 7
        tree=['#']*self.n*2
        for i,n in enumerate(nums):
            ind=i+self.n
            tree[ind]=n
        
        for ind in range(self.n-1,0,-1):
            tree[ind]=tree[2*ind]+tree[2*ind+1]
        return tree
    
    def update(self, row: int, col: int, val: int) -> None:
        tree=self.trees[row]
        #   1
        #  2  3
        # 4 5 6 7
        # 1   2  3  4 5 6 7
        # 22  9  13 4 5 6 7
        ind = self.n+col
        self.matrix[row][col]=val
        tree[ind]=val
        while ind>0:
            left=ind
            right=ind
            if ind%2==0:
                right+=1
            else:
                left-=1
            if ind//2>0:
                tree[ind//2]=tree[left]+tree[right]
            ind=ind//2
        self.trees[row]=tree
      

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        res = 0
        for row in range(row1,row2+1):
            tree=self.trees[row]
            #   1
            #  2  3
            # 4 5 6 7
            # 1   2  3  4 5 6 7
            # 22  9  13 4 5 6 7
            l=col1+self.n
            r=col2+self.n
            while l<=r:
                if l%2==1:
                    res+=tree[l]
                    l+=1
                if r%2==0:
                    res+=tree[r]
                    r-=1
                l//=2
                r//=2
        return res
            
# Your NumMatrix object will be instantiated and called as such:
# obj = NumMatrix(matrix)
# obj.update(row,col,val)
# param_2 = obj.sumRegion(row1,col1,row2,col2)

#MY ANSWER
class BIT:
    def __init__(self,m,n):
        self.m= m
        self.n = n
        self.bit = [[0]*(self.n+1) for _ in range(self.m+1)]
    
    def update(self,r,c,val):
        i = r
        while i<=self.m:
            j = c
            while j<=self.n:
                self.bit[i][j]+=val
                j+= j&-j
            i+= i&-i

    def query(self,r,c):
        res = 0
        i=r
        while i>0:
            j = c
            while j>0:
                res+=self.bit[i][j]
                j-= j&-j
            i-= i&-i
        return res
        

class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        m = len(matrix)
        n = len(matrix[0])
        self.bit = BIT(m,n)
        for i in range(m):
            for j in range(n):
                ii = i+1
                jj = j+1
                self.bit.update(ii,jj,matrix[i][j])
        
    def update(self, row: int, col: int, val: int) -> None:
        i = row+1
        j = col+1
        old_val = self.sumRegion(row,col,row,col)
        self.bit.update(i,j,-old_val+val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        i1=row1+1
        j1=col1+1
        i2=row2+1
        j2=col2+1
        a = self.bit.query(i2,j2)
        b = self.bit.query(i1-1,j1-1)
        c = self.bit.query(i2,j1-1)
        d = self.bit.query(i1-1,j2)
        return a-c-d+b

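I reused the earlier 1D segment tree for the 2D problem (one tree per row). It passes, but it is not optimal. The answer uses a Fenwick tree (Binary Indexed Tree), https://www.youtube.com/watch?v=uSFzHCZ4E-8, which I hadn't seen before, and the final version is a 2D Binary Indexed Tree. Under the bit operations, X and Y do not affect each other: the position in the tree is determined by the bits of the index, and the stored value is a partial sum.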
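The "partial sum" remark can be made concrete. In a 1-indexed 2D Fenwick tree, cell (i, j) stores the sum over rows (i - lowbit(i), i] and columns (j - lowbit(j), j], and the row range and column range are chosen independently. The sketch below builds the tree by brute force straight from that definition, then checks that the usual two nested query loops recover every prefix sum. The helper names lowbit, build_by_definition and prefix_sum are hypothetical.

```python
from typing import List


def lowbit(i: int) -> int:
    return i & -i


def build_by_definition(matrix: List[List[int]]) -> List[List[int]]:
    # bit[i][j] (1-indexed) = sum of matrix cells with row in (i - lowbit(i), i]
    # and column in (j - lowbit(j), j]; the two ranges are chosen independently
    m, n = len(matrix), len(matrix[0])
    bit = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            bit[i][j] = sum(
                matrix[r - 1][c - 1]
                for r in range(i - lowbit(i) + 1, i + 1)
                for c in range(j - lowbit(j) + 1, j + 1)
            )
    return bit


def prefix_sum(bit: List[List[int]], i: int, j: int) -> int:
    # sum of matrix[0:i][0:j]: the usual Fenwick walk, once per dimension
    total = 0
    while i > 0:
        jj = j
        while jj > 0:
            total += bit[i][jj]
            jj -= lowbit(jj)
        i -= lowbit(i)
    return total
```

A region sum then follows by the same inclusion-exclusion used in sumRegion.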

309. Best Time to Buy and Sell Stock with Cooldown (Medium)

You are given an array prices where prices[i] is the price of a given stock on the ith day.
Find the maximum profit you can achieve. You may complete as many transactions as you like (i.e., buy one and sell one share of the stock multiple times) with the following restrictions:
After you sell your stock, you cannot buy stock on the next day (i.e., cooldown one day).

class Solution:
    def maxProfit(self, prices: List[int]) -> int: 
        
    
        """
        state machine  
                  O held    -sell->  sold 
                    /\                |
                    |_ buy_ O reset <-rest
       
       O marks a self-loop (rest and stay in the same state)
       
       DP:
         sold[i]=held[i-1]+price[i]
         held[i]=max(held[i-1],reset[i-1]-price[i])
         reset[i]=max(reset[i-1],sold[i-1])
         
         return max(reset[n],sold[n])
         """
        sold, held, reset = float('-inf'), float('-inf'), 0
        for price in prices:
            pre_sold = sold
            sold = held + price               #sell operation
            held = max(held, reset - price)   #buy operation
            reset = max(reset, pre_sold)      # do nothing operation
            
        return max(sold, reset)
 
    def maxProfitGeneral(self, prices: List[int]) -> int:
        #309: I went straight to the answer.
        # Alternative: the general DP framework for the whole stock-problem series.
        # T[i][k][0] / T[i][k][1]: max profit at the end of day i with at most k transactions,
        # holding 0 / 1 share after day i ends
        #T[i][k][0] = max(T[i-1][k][0], T[i-1][k][1] + prices[i])  # sell
        #T[i][k][1] = max(T[i-1][k][1], T[i-2][k][0] - prices[i])  # buy
        #Base cases:
        # T[-1][k][0] = 0, T[-1][k][1] = -Infinity
        # T[i][0][0] = 0, T[i][0][1] = -Infinity
        # Cooldown constraint: a buy on day i must start from T[i-2][k][0].
        # k is unlimited here, so it drops out, and T_ik0_pre holds T[i-2][k][0].
        
        T_ik0=0
        T_ik0_pre=0
        T_ik1=-float('inf')
        
        for p in prices:
            T_ik0_old=T_ik0
            T_ik0 = max(T_ik0, T_ik1+p)
            T_ik1 = max(T_ik1, T_ik0_pre-p)
            T_ik0_pre=T_ik0_old
        
        return T_ik0


This is a family of problems with many variants: https://leetcode.com/problems/best-time-to-buy-and-sell-stock-with-cooldown/discuss/75924/Most-consistent-ways-of-dealing-with-the-series-of-stock-problems. The answer's state-machine way of thinking is very interesting.
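Here is a top-down version of the same cooldown recurrence. It is handy for cross-checking the O(1)-space state machine on small inputs. max_profit_cooldown is a hypothetical name. The recursion depth grows with len(prices), so with Python's default recursion limit the sketch is only meant for short price lists.

```python
from functools import lru_cache
from typing import List


def max_profit_cooldown(prices: List[int]) -> int:
    n = len(prices)

    @lru_cache(maxsize=None)
    def best(day: int, holding: bool) -> int:
        # best profit obtainable from `day` onward, given whether we hold a share
        if day >= n:
            return 0
        rest = best(day + 1, holding)  # do nothing today
        if holding:
            # sell today; the next day is a forced cooldown, so jump to day + 2
            return max(rest, prices[day] + best(day + 2, False))
        # buy today
        return max(rest, -prices[day] + best(day + 1, True))

    return best(0, False)
```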

310. Minimum Height Trees (Medium)

A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.

Given a tree of n nodes labelled from 0 to n - 1, and an array of n - 1 edges where edges[i] = [ai, bi] indicates that there is an undirected edge between the two nodes ai and bi in the tree, you can choose any node of the tree as the root. When you select a node x as the root, the result tree has height h. Among all possible rooted trees, those with minimum height (i.e. min(h)) are called minimum height trees (MHTs).

Return a list of all MHTs' root labels. You can return the answer in any order.

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        # repeatedly strip the current leaves; the last layer stripped holds the MHT roots
        dic= {i:[] for i in range(n)}
        q = []
        for n1,n2 in edges:
            dic[n1].append(n2)
            dic[n2].append(n1)
        for i in range(n):
            if len(dic[i])==1:
                q.append(i)
        
        res = q[:]
        while q:
            l = len(q)
            res = q[:]
            for _ in range(l):
                cur = q.pop(0)
                for nei in dic[cur]:
                    dic[nei].remove(cur)
                    if len(dic[nei])==1:
                        q.append(nei)
        
        return res if res else [0]
    

Solved only after looking at the hint. The idea is topological-sort style: cut off the leaves layer by layer.
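This sketch uses the same leaf-trimming idea but stops as soon as at most two nodes remain. A tree has at most two centers, so there are at most two MHT roots. It uses adjacency sets in place of lists so that removing an edge is O(1). find_min_height_trees is a hypothetical name.

```python
from typing import List


def find_min_height_trees(n: int, edges: List[List[int]]) -> List[int]:
    if n <= 2:
        return list(range(n))  # every node is a center
    adj = [set() for _ in range(n)]
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    leaves = [v for v in range(n) if len(adj[v]) == 1]
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            nb = adj[leaf].pop()  # a leaf has exactly one neighbor left
            adj[nb].remove(leaf)
            if len(adj[nb]) == 1:
                new_leaves.append(nb)
        leaves = new_leaves
    return leaves
```

While more than two nodes remain, no two current leaves can be adjacent in a connected tree. That is why each leaf still has its single neighbor when it is popped.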