• + 0 comments
    import sys
    
    class DisjointSet:
        """Union-find structure over the integer slots 0..n.

        ``parent[i]`` holds i's parent pointer; a root points to itself.
        ``size[r]`` holds the member count of the set rooted at ``r`` and is
        only kept up to date while ``r`` is still a root.
        NOTE(review): the n + 1 sizing suggests callers use 1-based ids, with
        slot 0 unused — confirm against callers.
        """

        def __init__(self, n):
            slots = n + 1
            self.size = [1] * slots
            self.parent = list(range(slots))

        def find(self, x):
            """Return the root of x's set, re-pointing every node on the path at it."""
            root = x
            while self.parent[root] != root:
                root = self.parent[root]

            # Second pass: path compression, so later lookups are O(1)-ish.
            node = x
            while node != root:
                following = self.parent[node]
                self.parent[node] = root
                node = following
            return root

        def union(self, x, y):
            """Merge the sets containing x and y; no-op when they already match.

            The smaller set's root is hung under the larger one (union by size),
            with ties going to x's root.
            """
            big, small = self.find(x), self.find(y)
            if big == small:
                return
            if self.size[big] < self.size[small]:
                big, small = small, big
            self.parent[small] = big
            self.size[big] += self.size[small]
        
    def solve(N, queries):
        """Apply merge/size queries to N singleton communities.

        Args:
            N: number of elements; a ``DisjointSet(N)`` is built over them.
            queries: iterable of sequences. ``('M', a, b)`` merges the sets
                containing ``a`` and ``b``; ``('Q', x)`` records the size of
                the set containing ``x``. Index fields may be ints or integer
                strings (e.g. raw tokens from ``line.split()``); both are
                normalized with ``int()``. Queries whose first field is
                neither ``'M'`` nor ``'Q'`` are ignored.

        Returns:
            A list with one set size per ``'Q'`` query, in query order.
        """
        ds = DisjointSet(N)
        output = []

        for query in queries:
            op = query[0]
            if op == 'M':
                # int() is the identity for int indices and parses string tokens,
                # which would otherwise raise TypeError when used as list indices.
                ds.union(int(query[1]), int(query[2]))
            elif op == 'Q':
                output.append(ds.size[ds.find(int(query[1]))])

        return output