• + 4 comments

    An O(N) C++ solution (expected time, since it relies on hash sets) in about 30 lines:

    // One tree vertex.
    //   s  : starts as the vertex's own value; dfs() adds the children's sums
    //        to it, turning it into the sum of the subtree rooted here.
    //   v1 : visited flag for the dfs() pass.
    //   v2 : visited flag for the solve() pass.
    //   e  : adjacency list (indices into t), filled in both directions.
    struct node{
        int64_t s;
        bool v1 = false, v2 = false;
        vector<int> e;
        node(int x) : s(x) {}
    };
    
    // Shared state for one balancedForest() call.
    vector<node> t;               // vertices, in the same order as the input values
    unordered_set<int64_t> s;     // subtree sums of vertices solve() has finished
    unordered_set<int64_t> q;     // subtree sums of the vertices on solve()'s current path
    int64_t ans;                  // best weight so far; equal to sum means "none found"
    int64_t sum;                  // subtree sum of vertex 0 (the whole tree if connected)
    // Folds candidate y into the running minimum x, ignoring negative
    // candidates: returns min(x, y) when y >= 0, otherwise x unchanged.
    inline int64_t m(int64_t x, int64_t y){
        if (y < 0) return x;
        return y < x ? y : x;
    }
    // True when y is an element of the set x. The set is taken by const
    // reference: a lookup never modifies it, and const sets are accepted too.
    inline bool find(const unordered_set<int64_t> &x, int64_t y) { return x.count(y) != 0; }
    
    // First pass: adds the subtree sums of p's unvisited neighbours into t[p].s,
    // turning it into the sum of the subtree rooted at p, and returns it.
    // An already-visited vertex returns 0, so the edge back to the parent
    // (stored in both adjacency lists) contributes nothing.
    // Uses an explicit test-and-set: `++` on a bool is ill-formed since C++17.
    int64_t dfs(int p){
        if (t[p].v1) return 0;
        t[p].v1 = true;
        for (int child : t[p].e) t[p].s += dfs(child);
        return t[p].s;
    }
    
    // Second pass; requires dfs() to have stored every subtree sum in t[i].s.
    // For each vertex p it considers cutting the edge above p (piece of sum a)
    // plus one other edge, and lowers `ans` (via m) to the smallest
    // non-negative weight that makes the three pieces equal.
    //   q : subtree sums of p's ancestors (the current path, p excluded)
    //   s : subtree sums of vertices already finished, i.e. subtrees disjoint
    //       from p's subtree
    // Uses explicit test-and-set: `++` on a bool is ill-formed since C++17.
    void solve(int p){
        if (t[p].v2) return;
        t[p].v2 = true;

        const int64_t a = t[p].s;
        const int64_t rest = sum - a;

        // Pieces {a, a, S-2a}: a disjoint subtree of sum a, or an ancestor
        // whose subtree is 2a (ancestor minus p's subtree == a).
        if (find(s, a) || find(q, 2 * a)) ans = m(ans, 3 * a - sum);
        // Pieces {a, S-2a, a}: a disjoint subtree of sum S-2a, or an ancestor
        // whose subtree is S-a; the part left over then also sums to a.
        if (find(s, sum - 2 * a) || find(q, sum - a)) ans = m(ans, 3 * a - sum);
        // Pieces {a, (S-a)/2, (S-a)/2} with a the smallest piece: a disjoint
        // subtree of (S-a)/2, or an ancestor whose subtree is (S+a)/2.
        if (rest % 2 == 0 && (find(s, rest / 2) || find(q, (sum + a) / 2)))
            ans = m(ans, rest / 2 - a);

        q.insert(a);
        for (int child : t[p].e) solve(child);
        q.erase(a);
        s.insert(a);
    }
    
    int64_t balancedForest(vector<int> &c, vector<vector<int>> &e) {
        t.clear(); s.clear();
        for(int i=0;i<c.size();i++) t.push_back(c[i]);
        for(int i=0;i<e.size();i++) t[e[i][0]].e.push_back(e[i][1]), t[e[i][1]].e.push_back(e[i][0]);
        ans=sum=dfs(0);
        solve(0);
        return (ans==sum?-1:ans);
    }
    

    Also, here is a Python version, which is even shorter:

    class Node:
        """A tree vertex.

        ``s`` starts as the vertex's own value; after ``calc_sums`` runs on the
        root it holds the sum of the subtree rooted at this vertex.
        ``children`` is expected to start as the neighbour list with each edge
        recorded in both directions; ``calc_sums`` removes the parent link so it
        then lists only real children.
        """

        def __init__(self, c):
            self.s = c
            self.children = []

        def calc_sums(self):
            """Orient edges away from ``self`` and store every subtree sum in ``s``.

            Iterative (explicit stack), so path-like trees deeper than Python's
            recursion limit do not raise RecursionError.

            Returns:
                ``self.s``, the sum of the tree rooted at ``self``.
            """
            order = []  # pre-order: every node appears before its descendants
            stack = [self]
            while stack:
                node = stack.pop()
                order.append(node)
                for child in node.children:
                    if node in child.children:
                        child.children.remove(node)  # make edges directed
                stack.extend(node.children)
            # Reverse pre-order handles children before parents, so child sums are final.
            for node in reversed(order):
                node.s += sum(child.s for child in node.children)
            return self.s

        def dfs(self, root, sums, root_sums, ans):
            """Collect candidate weights for every cut below ``self``.

            Must run after ``calc_sums``. For each node (piece of sum ``a`` cut
            off above it) it looks for a second cut that yields two equal pieces
            at least as large as the third, and adds the weight needed to raise
            the third piece to ``ans``.

            Args:
                root: the root vertex; ``root.s`` is the total sum.
                sums: subtree sums of nodes already finished; updated in place.
                root_sums: subtree sums of the current node's ancestors; left
                    with its original contents on return.
                ans: set of candidate weights; updated in place.

            Returns:
                ``ans``.
            """
            total = root.s
            # Explicit stack of (node, leaving) events replaces recursion while
            # keeping the exact enter/exit order of the recursive version.
            stack = [(self, False)]
            while stack:
                node, leaving = stack.pop()
                if leaving:
                    root_sums.remove(node.s)
                    sums.add(node.s)
                    continue
                a = node.s
                checks = [a, total - 2 * a]
                if 2 * a <= total <= 3 * a and any(
                        x in sums or x + a in root_sums for x in checks):
                    ans.add(3 * a - total)
                rest = total - a
                # Integer halving: true division goes through float and rounds
                # for sums above 2**53, missing valid answers.
                if total > 3 * a and rest % 2 == 0 and (
                        rest // 2 in sums or (total + a) // 2 in root_sums):
                    ans.add((total - 3 * a) // 2)
                root_sums.add(a)
                stack.append((node, True))
                # Push reversed so children are entered in list order.
                stack.extend((child, False) for child in reversed(node.children))
            return ans
    
    def balancedForest(c, edges):
        """Build the tree and return the smallest candidate weight, or -1.

        ``c`` holds the vertex values; ``edges`` are (u, v) pairs with 1-based
        endpoints. Candidates are the weights ``Node.dfs`` collects from
        vertex 1 after ``Node.calc_sums`` has rooted the tree there.
        """
        nodes = list(map(Node, c))
        for u, v in edges:
            first, second = nodes[u - 1], nodes[v - 1]
            first.children.append(second)
            second.children.append(first)
        root = nodes[0]
        root.calc_sums()
        candidates = root.dfs(root, set(), set(), set())
        if not candidates:
            return -1
        return min(candidates)