// Kitty's Calculations on a Tree

    import java.io.*;
    import java.math.*;
    import java.text.*;
    import java.util.*;
    import java.util.regex.*;

    /**
     * Kitty's Calculations on a Tree.
     *
     * <p>Reads a tree on nodes {@code 1..n} given as {@code n - 1} undirected edges, then {@code q}
     * queries. Each query is {@code k} distinct node labels. For each query the program prints,
     * modulo {@link #MOD}, the sum over every unordered pair {@code {u, v}} of query nodes of
     * {@code u * v * dist(u, v)}.
     *
     * <p>Per query, the nodes are compressed into a virtual tree: the query nodes plus the LCAs
     * needed to connect them. A virtual edge of length {@code len} splits the query nodes into two
     * groups with label sums {@code A} and {@code T - A}. Every pair split by that edge pays
     * {@code len} for it, so the edge contributes {@code len * A * (T - A)}.
     *
     * <p>Both tree walks use explicit stacks, so path-shaped inputs cannot overflow the thread stack.
     */
    public class Solution {
        static final long MOD = 1000000007L;
        static int n, q, LOG;
        /** Adjacency lists of the input tree, indexed by node label (index 0 unused). */
        static ArrayList<Integer>[] tree;
        /** {@code up[v][i]} is the 2^i-th ancestor of {@code v}; main makes root 1 its own parent. */
        static int[][] up;
        static int[] depth;
        /** DFS entry/exit timestamps; {@link #isAncestor} compares these intervals. */
        static int[] tin, tout;
        static int timer = 0;

        /** Orders node labels by DFS entry time. */
        private static final Comparator<Integer> BY_TIN = new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Integer.compare(tin[a], tin[b]);
            }
        };

        /** Token cursor used by {@link #nextInt}; reset at the start of {@link #main}. */
        private static StringTokenizer tokens;

        /**
         * Walks the tree from {@code v} without recursion. Along the way it assigns entry/exit times
         * ({@code tin}/{@code tout}), fills the binary-lifting table {@code up}, and sets
         * {@code depth} for every node below {@code v}.
         *
         * <p>An explicit stack replaces recursion. On a path-shaped tree the recursion depth would
         * equal {@code n} and overflow the default thread stack. Neighbours are visited in list
         * order, so the timestamps match a recursive DFS.
         *
         * @param v root of the walk; its own {@code depth} entry must already be set by the caller
         * @param p value stored as {@code v}'s parent. When {@code p == v} (as {@code main} does),
         *     every lifting step from the root stays at the root.
         */
        static void dfs(int v, int p) {
            int size = tree.length;
            int[] stack = new int[size];
            int[] parentOf = new int[size];
            int[] nextEdge = new int[size];
            int top = 0;
            stack[0] = v;
            parentOf[v] = p;
            enter(v, p);
            while (top >= 0) {
                int cur = stack[top];
                ArrayList<Integer> neighbours = tree[cur];
                if (nextEdge[cur] < neighbours.size()) {
                    int to = neighbours.get(nextEdge[cur]++);
                    if (to == parentOf[cur]) {
                        continue;
                    }
                    depth[to] = depth[cur] + 1;
                    parentOf[to] = cur;
                    enter(to, cur);
                    stack[++top] = to;
                } else {
                    // All children finished: close v's timestamp interval.
                    tout[cur] = ++timer;
                    top--;
                }
            }
        }

        /**
         * Records {@code v}'s entry time and fills its binary-lifting row. {@code p}'s row must
         * already be complete, unless {@code p} is {@code v} itself or {@code -1}.
         */
        private static void enter(int v, int p) {
            tin[v] = ++timer;
            up[v][0] = p;
            for (int i = 1; i < LOG; i++) {
                int mid = up[v][i - 1];
                up[v][i] = (mid == -1) ? -1 : up[mid][i - 1];
            }
        }

        /** Returns true when {@code u} is {@code v} or an ancestor of {@code v}. */
        static boolean isAncestor(int u, int v) {
            return tin[u] <= tin[v] && tout[v] <= tout[u];
        }

        /**
         * Returns the lowest common ancestor of {@code a} and {@code b} using the lifting table.
         * Requires {@link #dfs} to have run.
         */
        static int lca(int a, int b) {
            if (isAncestor(a, b)) {
                return a;
            }
            if (isAncestor(b, a)) {
                return b;
            }
            // Climb from a as far as possible while staying strictly below the LCA.
            for (int i = LOG - 1; i >= 0; i--) {
                int jump = up[a][i];
                if (jump != -1 && !isAncestor(jump, b)) {
                    a = jump;
                }
            }
            return up[a][0];
        }

        /** Returns the number of edges on the tree path between {@code a} and {@code b}. */
        static int dist(int a, int b) {
            return depth[a] + depth[b] - 2 * depth[lca(a, b)];
        }

        /**
         * Builds the virtual tree of {@code nodes}. It contains the nodes themselves plus the LCA of
         * every pair that is adjacent in DFS entry order; that set is closed under LCA. The input
         * list is copied, not modified.
         *
         * @param nodes distinct node labels; an empty list yields an empty result
         * @param adj receives a fresh list of virtual-tree children (in increasing entry time) for
         *     each virtual-tree node; existing entries for those keys are replaced
         * @return the virtual-tree nodes sorted by entry time; element 0 is the virtual-tree root
         */
        static List<Integer> buildVirtualTree(List<Integer> nodes, Map<Integer, ArrayList<Integer>> adj) {
            if (nodes.isEmpty()) {
                return new ArrayList<Integer>();
            }
            List<Integer> sorted = new ArrayList<Integer>(nodes);
            Collections.sort(sorted, BY_TIN);

            List<Integer> all = new ArrayList<Integer>(2 * sorted.size());
            all.addAll(sorted);
            for (int i = 0; i + 1 < sorted.size(); i++) {
                all.add(lca(sorted.get(i), sorted.get(i + 1)));
            }
            Collections.sort(all, BY_TIN);

            // Each label has a unique tin, so after sorting any duplicates are adjacent.
            ArrayList<Integer> vs = new ArrayList<Integer>(all.size());
            int last = -1;
            for (int x : all) {
                if (x != last) {
                    vs.add(x);
                    last = x;
                }
            }

            for (int v : vs) {
                adj.put(v, new ArrayList<Integer>());
            }

            // The deque holds the current root-to-node chain; pop until its top is an ancestor of v.
            Deque<Integer> chain = new ArrayDeque<Integer>();
            chain.push(vs.get(0));
            for (int i = 1; i < vs.size(); i++) {
                int v = vs.get(i);
                while (!chain.isEmpty() && !isAncestor(chain.peek(), v)) {
                    chain.pop();
                }
                adj.get(chain.peek()).add(v);
                chain.push(v);
            }
            return vs;
        }

        /**
         * Walks the virtual tree rooted at {@code v} without recursion. For each edge it adds that
         * edge's pair-distance contribution to {@code ansRef[0]} (mod {@link #MOD}).
         *
         * @param presentSet labels that are real query nodes (LCA-only nodes carry weight 0)
         * @param totalSum sum of all query labels; may already be reduced mod {@link #MOD}
         * @param ansRef one-element accumulator, updated in place
         * @return unreduced label sum of the query nodes in {@code v}'s virtual subtree
         */
        static long dfsVirtual(int v, Map<Integer, ArrayList<Integer>> adj, Set<Integer> presentSet,
                long totalSum, long[] ansRef) {
            // Pre-order walk: each node is recorded before any of its descendants.
            ArrayList<Integer> order = new ArrayList<Integer>();
            ArrayList<Integer> parentPos = new ArrayList<Integer>();
            Deque<int[]> pending = new ArrayDeque<int[]>();
            pending.push(new int[] {v, -1});
            while (!pending.isEmpty()) {
                int[] entry = pending.pop();
                int pos = order.size();
                order.add(entry[0]);
                parentPos.add(entry[1]);
                ArrayList<Integer> children = adj.get(entry[0]);
                if (children != null) {
                    for (int c : children) {
                        pending.push(new int[] {c, pos});
                    }
                }
            }

            // Reverse pre-order visits children before parents, so sub[i] is complete when used.
            long[] sub = new long[order.size()];
            for (int i = order.size() - 1; i >= 0; i--) {
                int node = order.get(i);
                if (presentSet.contains(node)) {
                    sub[i] += node;
                }
                int pi = parentPos.get(i);
                if (pi >= 0) {
                    int parent = order.get(pi);
                    long contrib = edgeContribution(depth[node] - depth[parent], sub[i], totalSum);
                    ansRef[0] = (ansRef[0] + contrib) % MOD;
                    sub[pi] += sub[i];
                }
            }
            return sub[0];
        }

        /**
         * Returns {@code len * below * (totalSum - below) mod MOD}: the total of {@code u * v * len}
         * over all pairs separated by one edge. {@code below} is the label sum on the child side and
         * need not be reduced; {@code totalSum} may be reduced.
         */
        private static long edgeContribution(long len, long below, long totalSum) {
            long inside = below % MOD;
            long outside = ((totalSum - below) % MOD + MOD) % MOD;
            return (len % MOD) * inside % MOD * outside % MOD;
        }

        /**
         * Returns the next whitespace-separated integer from {@code in}. Line breaks and blank lines
         * are skipped.
         *
         * @throws EOFException if the input ends before another token is found
         */
        private static int nextInt(BufferedReader in) throws IOException {
            while (tokens == null || !tokens.hasMoreTokens()) {
                String line = in.readLine();
                if (line == null) {
                    throw new EOFException("unexpected end of input");
                }
                tokens = new StringTokenizer(line);
            }
            return Integer.parseInt(tokens.nextToken());
        }

        /**
         * Reads one query ({@code k}, then {@code k} labels spread over any number of lines) and
         * returns its answer mod {@link #MOD}.
         */
        private static long answerQuery(BufferedReader in) throws IOException {
            int k = nextInt(in);
            ArrayList<Integer> queryNodes = new ArrayList<Integer>(k);
            Set<Integer> presentSet = new HashSet<Integer>(2 * k);
            long totalSum = 0L;
            for (int i = 0; i < k; i++) {
                int v = nextInt(in);
                queryNodes.add(v);
                presentSet.add(v);
                totalSum = (totalSum + v) % MOD;
            }

            Map<Integer, ArrayList<Integer>> adj = new HashMap<Integer, ArrayList<Integer>>();
            List<Integer> vs = buildVirtualTree(queryNodes, adj);
            long[] ansRef = {0L};
            if (!vs.isEmpty()) {
                dfsVirtual(vs.get(0), adj, presentSet, totalSum, ansRef);
            }
            return ansRef[0] % MOD;
        }

        /**
         * Reads {@code n q}, then {@code n - 1} edges, then {@code q} queries from stdin, and prints
         * one answer per line.
         */
        public static void main(String[] args) throws Exception {
            BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
            tokens = null;
            timer = 0;

            n = nextInt(in);
            q = nextInt(in);
            @SuppressWarnings("unchecked") // generic array creation is not expressible in Java
            ArrayList<Integer>[] adjacency = new ArrayList[n + 1];
            for (int i = 1; i <= n; i++) {
                adjacency[i] = new ArrayList<Integer>();
            }
            for (int i = 0; i < n - 1; i++) {
                int u = nextInt(in);
                int v = nextInt(in);
                adjacency[u].add(v);
                adjacency[v].add(u);
            }
            tree = adjacency;

            // Smallest LOG with 2^LOG > n, enough lifting levels for any depth.
            LOG = 1;
            while ((1 << LOG) <= n) {
                LOG++;
            }
            up = new int[n + 1][LOG];
            depth = new int[n + 1];
            tin = new int[n + 1];
            tout = new int[n + 1];
            for (int[] row : up) {
                Arrays.fill(row, -1);
            }
            depth[1] = 0;
            // Root at node 1 and make it its own parent so lifting never leaves the tree.
            dfs(1, 1);

            StringBuilder out = new StringBuilder();
            for (int qi = 0; qi < q; qi++) {
                out.append(answerQuery(in)).append('\n');
            }
            System.out.print(out);
            System.out.flush();
        }
    }