Kitty's Calculations on a Tree

  • + 0 comments

    Java 8 code:

    import java.io.*;
    import java.util.*;

    /**
     * Solution for "Kitty's Calculations on a Tree".
     *
     * <p>Reads a tree with {@code n} nodes (labelled 1..n and rooted at node 1) followed by
     * {@code q} queries. Each query is a set of nodes. For every set the program prints the sum,
     * over all unordered pairs {u, v} in the set, of {@code u * v * dist(u, v)} modulo
     * {@link #MOD}. Distances are computed with a binary-lifting LCA table.
     */
    public class Solution {
        /** Modulus applied to every query answer. */
        static final long MOD = 1_000_000_007;
        /** Adjacency lists indexed 1..n (index 0 is unused). */
        static List<Integer>[] tree;
        /** Parent of each node (the root is its own parent) and its depth below the root. */
        static int[] parent, depth;
        /** Number of lifting levels: the smallest value >= 1 with {@code 2^LOG > n}. */
        static int LOG;
        /** {@code up[v][j]} is the 2^j-th ancestor of {@code v}, saturating at the root. */
        static int[][] up;

        /**
         * Reads the tree and the queries from standard input and prints one answer per line.
         *
         * <p>Input is consumed as a stream of whitespace-separated integers, so repeated spaces
         * and node lists wrapped over several lines are accepted.
         *
         * @param args unused
         * @throws IOException if reading fails, or the input ends early or contains a non-number
         */
        public static void main(String[] args) throws IOException {
            StreamTokenizer in =
                    new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
            int n = nextInt(in);
            int q = nextInt(in);

            int[][] edges = new int[Math.max(0, n - 1)][2];
            for (int[] edge : edges) {
                edge[0] = nextInt(in);
                edge[1] = nextInt(in);
            }
            build(n, edges);

            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < q; i++) {
                int k = nextInt(in);
                int[] nodes = new int[k];
                for (int j = 0; j < k; j++) {
                    nodes[j] = nextInt(in);
                }
                sb.append(processSet(nodes)).append("\n");
            }

            System.out.print(sb);
        }

        /** Reads the next integer token, failing loudly on EOF or a non-numeric token. */
        private static int nextInt(StreamTokenizer in) throws IOException {
            if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
                throw new IOException("expected an integer but found token type " + in.ttype);
            }
            return (int) in.nval;
        }

        /**
         * Builds the adjacency lists, parent/depth arrays and binary-lifting table for a tree
         * rooted at node 1. This replaces any previously built tree.
         *
         * @param n number of nodes, labelled 1..n (must be at least 1)
         * @param edges undirected edges as {@code {a, b}} pairs
         */
        static void build(int n, int[][] edges) {
            // Generic array creation is impossible in Java; the raw array only ever holds
            // List<Integer> instances created below.
            @SuppressWarnings("unchecked")
            List<Integer>[] adj = new ArrayList[n + 1];
            for (int i = 1; i <= n; i++) {
                adj[i] = new ArrayList<>();
            }
            for (int[] edge : edges) {
                adj[edge[0]].add(edge[1]);
                adj[edge[1]].add(edge[0]);
            }
            tree = adj;

            LOG = 1;
            while ((1 << LOG) <= n) {
                LOG++;
            }
            parent = new int[n + 1];
            depth = new int[n + 1];
            up = new int[n + 1][LOG];

            dfs(1, 1);

            // up[i][j] = ancestor 2^(j-1) levels above the 2^(j-1)-th ancestor.
            for (int j = 1; j < LOG; j++) {
                for (int i = 1; i <= n; i++) {
                    up[i][j] = up[up[i][j - 1]][j - 1];
                }
            }
        }

        /**
         * Fills {@code parent}, {@code up[*][0]} and {@code depth} for the subtree hanging from
         * {@code v}, treating {@code p} as the parent of {@code v}.
         *
         * <p>This uses an explicit stack instead of recursion, so very deep trees (for example a
         * long path) do not overflow the JVM call stack. {@code depth[v]} itself is left as is;
         * each descendant gets its parent's depth plus one.
         *
         * @param v start node of the traversal
         * @param p node regarded as {@code v}'s parent (the root passes itself)
         */
        static void dfs(int v, int p) {
            parent[v] = p;
            up[v][0] = p;
            // Every node is pushed at most once in a tree, so tree.length slots suffice.
            int[] stack = new int[tree.length];
            int top = 0;
            stack[top++] = v;
            while (top > 0) {
                int cur = stack[--top];
                for (int nxt : tree[cur]) {
                    if (nxt != parent[cur]) {
                        parent[nxt] = cur;
                        up[nxt][0] = cur;
                        depth[nxt] = depth[cur] + 1;
                        stack[top++] = nxt;
                    }
                }
            }
        }

        /**
         * Returns the lowest common ancestor of {@code a} and {@code b}.
         *
         * <p>The deeper node is first lifted to the other's depth. Both nodes are then raised
         * together, from the largest jump down, while their ancestors still differ.
         */
        static int lca(int a, int b) {
            if (depth[a] < depth[b]) {
                int tmp = a;
                a = b;
                b = tmp;
            }
            int diff = depth[a] - depth[b];
            for (int i = 0; i < LOG; i++) {
                if (((diff >> i) & 1) != 0) {
                    a = up[a][i];
                }
            }
            if (a == b) {
                return a;
            }
            for (int i = LOG - 1; i >= 0; i--) {
                if (up[a][i] != up[b][i]) {
                    a = up[a][i];
                    b = up[b][i];
                }
            }
            return parent[a];
        }

        /** Returns the number of edges on the path between {@code a} and {@code b}. */
        static long dist(int a, int b) {
            int l = lca(a, b);
            return depth[a] + depth[b] - 2L * depth[l];
        }

        /**
         * Returns the sum over unordered pairs {u, v} of {@code u * v * dist(u, v)}, mod
         * {@link #MOD}. Returns 0 for fewer than two nodes. The caller's array is not modified.
         *
         * <p>NOTE(review): this is pairwise, O(k^2 log n) for k nodes. If the total query size is
         * large, an auxiliary (virtual) tree formulation is needed to stay near-linear.
         */
        static long processSet(int[] nodes) {
            long res = 0;
            for (int i = 0; i < nodes.length; i++) {
                for (int j = i + 1; j < nodes.length; j++) {
                    long weight = (long) nodes[i] * nodes[j] % MOD;
                    // weight < MOD and dist <= n, so the product stays far below Long.MAX_VALUE.
                    res = (res + weight * dist(nodes[i], nodes[j])) % MOD;
                }
            }
            return res;
        }
    }