Kitty's Calculations on a Tree

Sort by

recency

|

67 Discussions

|

  • + 0 comments

    import java.io.*;
    import java.math.*;
    import java.text.*;
    import java.util.*;
    import java.util.regex.*;

    public class Solution {
        static final long MOD = 1000000007L;
        static int n, q, LOG;
        /** Adjacency lists of the input tree, indexed 1..n. */
        static ArrayList<Integer>[] tree;
        /**
         * up[v][i] is the 2^i-th ancestor of v, or -1 once a -1 parent is reached. main roots the
         * tree with dfs(1, 1), so the root is its own parent and jumps saturate at the root.
         */
        static int[][] up;
        static int[] depth;
        /** DFS entry/exit timestamps: u is an ancestor of v iff v's [tin, tout] nests inside u's. */
        static int[] tin, tout;
        static int timer = 0;
        /** Tokens of the current input line, refilled on demand by {@link #next}. */
        private static StringTokenizer tokens = new StringTokenizer("");

        /** Orders vertices by DFS entry time. */
        private static final Comparator<Integer> BY_TIN = new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Integer.compare(tin[a], tin[b]);
            }
        };

        /**
         * Assigns timestamps, depths and binary-lifting rows to the subtree rooted at {@code v},
         * whose parent is {@code p} (main passes {@code p == v} for the root).
         *
         * <p>Uses an explicit stack instead of recursion: a path-shaped tree with ~10^5 vertices
         * would otherwise overflow the default JVM thread stack.
         */
        static void dfs(int v, int p) {
            int size = tree.length;
            int[] parentOf = new int[size];
            int[] nextEdge = new int[size]; // next neighbour index to examine, per vertex
            int[] stack = new int[size];
            int top = 0;

            parentOf[v] = p;
            enter(v, p);
            stack[top++] = v;
            while (top > 0) {
                int cur = stack[top - 1];
                if (nextEdge[cur] < tree[cur].size()) {
                    int to = tree[cur].get(nextEdge[cur]++);
                    if (to == parentOf[cur]) {
                        continue;
                    }
                    parentOf[to] = cur;
                    depth[to] = depth[cur] + 1;
                    enter(to, cur);
                    stack[top++] = to;
                } else {
                    tout[cur] = ++timer;
                    top--;
                }
            }
        }

        /** Stamps tin[v] and fills row up[v]; row up[p] must already be complete. */
        private static void enter(int v, int p) {
            tin[v] = ++timer;
            up[v][0] = p;
            for (int i = 1; i < LOG; i++) {
                int mid = up[v][i - 1];
                up[v][i] = (mid == -1) ? -1 : up[mid][i - 1];
            }
        }

        /** Returns true when u is v or an ancestor of v. */
        static boolean isAncestor(int u, int v) {
            return tin[u] <= tin[v] && tout[v] <= tout[u];
        }

        /** Lowest common ancestor of a and b via binary lifting on the ancestor test. */
        static int lca(int a, int b) {
            if (isAncestor(a, b)) return a;
            if (isAncestor(b, a)) return b;
            // Climb a as high as possible while staying strictly below the LCA.
            for (int i = LOG - 1; i >= 0; i--) {
                if (up[a][i] != -1 && !isAncestor(up[a][i], b)) {
                    a = up[a][i];
                }
            }
            return up[a][0];
        }

        /** Number of edges on the path between a and b. */
        static int dist(int a, int b) {
            int l = lca(a, b);
            return depth[a] + depth[b] - 2 * depth[l];
        }

        /**
         * Builds the virtual (auxiliary) tree of {@code nodes}.
         *
         * <p>{@code nodes} is sorted by tin and extended in place with LCAs of tin-adjacent vertices.
         * {@code adj} receives a fresh child list for every kept vertex.
         *
         * @return the distinct kept vertices sorted by tin (the first one is the virtual root), or
         *     an empty list when {@code nodes} is empty
         */
        static List<Integer> buildVirtualTree(List<Integer> nodes, Map<Integer, ArrayList<Integer>> adj) {
            ArrayList<Integer> vs = new ArrayList<Integer>();
            if (nodes.isEmpty()) {
                return vs;
            }
            Collections.sort(nodes, BY_TIN);

            // LCAs of tin-adjacent vertices close the set under LCA.
            int m = nodes.size();
            for (int i = 0; i + 1 < m; i++) {
                nodes.add(lca(nodes.get(i), nodes.get(i + 1)));
            }
            Collections.sort(nodes, BY_TIN);

            // tin is unique per vertex, so duplicates are adjacent after sorting.
            int last = -1;
            for (int x : nodes) {
                if (x != last) {
                    vs.add(x);
                    last = x;
                }
            }
            for (int v : vs) {
                adj.put(v, new ArrayList<Integer>());
            }

            // The deque holds the current root-to-vertex chain; the top surviving entry is the parent.
            ArrayDeque<Integer> chain = new ArrayDeque<Integer>();
            chain.push(vs.get(0));
            for (int i = 1; i < vs.size(); i++) {
                int v = vs.get(i);
                while (!chain.isEmpty() && !isAncestor(chain.peek(), v)) {
                    chain.pop();
                }
                adj.get(chain.peek()).add(v);
                chain.push(v);
            }
            return vs;
        }

        /**
         * Accumulates, over every virtual edge u -> c below {@code v}, the quantity
         * len(u, c) * S(c) * (total - S(c)) into {@code ansRef[0]} (mod MOD). Here S(c) is the sum
         * of labels in {@code presentSet} inside c's virtual subtree.
         *
         * <p>Iterative, so deep virtual trees (e.g. a path query) do not overflow the stack.
         *
         * @return the (unreduced) sum of present labels in v's virtual subtree
         */
        static long dfsVirtual(int v, Map<Integer, ArrayList<Integer>> adj, Set<Integer> presentSet,
                long totalSum, long[] ansRef) {
            // Pre-order listing: every parent is listed before its children.
            ArrayList<Integer> order = new ArrayList<Integer>();
            ArrayDeque<Integer> pending = new ArrayDeque<Integer>();
            pending.push(v);
            while (!pending.isEmpty()) {
                int u = pending.pop();
                order.add(u);
                ArrayList<Integer> children = adj.get(u);
                if (children != null) {
                    for (int c : children) {
                        pending.push(c);
                    }
                }
            }

            // Walking the listing backwards handles every child before its parent.
            Map<Integer, Long> subSum = new HashMap<Integer, Long>(order.size() * 2);
            for (int i = order.size() - 1; i >= 0; i--) {
                int u = order.get(i);
                long sumSub = presentSet.contains(u) ? u : 0L;
                ArrayList<Integer> children = adj.get(u);
                if (children != null) {
                    for (int c : children) {
                        long childSum = subSum.get(c);
                        // Every pair separated by the compressed edge u -> c pays its full length.
                        long len = depth[c] - depth[u];
                        long contrib = ((len % MOD) * (childSum % MOD)) % MOD;
                        contrib = (contrib * (((totalSum - childSum) % MOD + MOD) % MOD)) % MOD;
                        ansRef[0] = (ansRef[0] + contrib) % MOD;
                        sumSub += childSum;
                    }
                }
                subSum.put(u, sumSub);
            }
            return subSum.get(v);
        }

        /** Returns the next whitespace-separated token, reading further lines (blank ones too) as needed. */
        private static String next(BufferedReader br) throws IOException {
            while (!tokens.hasMoreTokens()) {
                String line = br.readLine();
                if (line == null) {
                    throw new IOException("unexpected end of input");
                }
                tokens = new StringTokenizer(line);
            }
            return tokens.nextToken();
        }

        /** Reads the tree and the queries from stdin and prints one answer per query. */
        public static void main(String[] args) throws Exception {
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            tokens = new StringTokenizer("");
            timer = 0;

            n = Integer.parseInt(next(br));
            int queryCount = Integer.parseInt(next(br));

            // Build the original tree.
            @SuppressWarnings("unchecked")
            ArrayList<Integer>[] adjacency = new ArrayList[n + 1];
            tree = adjacency;
            for (int i = 1; i <= n; i++) {
                tree[i] = new ArrayList<Integer>();
            }
            for (int i = 0; i < n - 1; i++) {
                int u = Integer.parseInt(next(br));
                int v = Integer.parseInt(next(br));
                tree[u].add(v);
                tree[v].add(u);
            }

            // LCA preprocessing; LOG is the smallest value with 2^LOG > n.
            LOG = 1;
            while ((1 << LOG) <= n) LOG++;
            up = new int[n + 1][LOG];
            depth = new int[n + 1];
            tin = new int[n + 1];
            tout = new int[n + 1];
            for (int i = 1; i <= n; i++) {
                Arrays.fill(up[i], -1);
            }
            depth[1] = 0;
            dfs(1, 1);

            StringBuilder out = new StringBuilder();
            for (int qi = 0; qi < queryCount; qi++) {
                // k and the k vertices may be spread over any number of lines.
                int k = Integer.parseInt(next(br));
                ArrayList<Integer> queryNodes = new ArrayList<Integer>(k);
                Set<Integer> presentSet = new HashSet<Integer>(k * 2);
                long totalSum = 0L;
                for (int j = 0; j < k; j++) {
                    int v = Integer.parseInt(next(br));
                    queryNodes.add(v);
                    presentSet.add(v);
                    totalSum = (totalSum + v) % MOD;
                }

                long answer = 0L;
                if (k >= 2) { // fewer than two vertices form no pair
                    Map<Integer, ArrayList<Integer>> adj = new HashMap<Integer, ArrayList<Integer>>();
                    List<Integer> vs = buildVirtualTree(queryNodes, adj);
                    long[] ansRef = new long[1];
                    dfsVirtual(vs.get(0), adj, presentSet, totalSum, ansRef);
                    answer = ansRef[0] % MOD;
                }
                out.append(answer).append('\n');
            }
            System.out.print(out);
        }
    }

  • + 0 comments

    Java 8 code:

    import java.io.; import java.util.;

    public class Solution {
        static final long MOD = 1_000_000_007;
        /** Adjacency lists of the input tree, indexed 1..n. */
        static List<Integer>[] tree;
        /** parent[v] and depth[v] in the tree rooted at vertex 1 (the root is its own parent). */
        static int[] parent, depth;
        /** DFS pre-order index of each vertex; every subtree occupies a contiguous range. */
        static int[] tin;
        static int LOG;
        /** up[v][j] is the 2^j-th ancestor of v, saturating at the root. */
        static int[][] up;

        /** Tokens of the current input line and the read position within them. */
        private static String[] pending = new String[0];
        private static int pendingPos = 0;

        /** Reads the tree and the queries from stdin and prints one answer per query. */
        public static void main(String[] args) throws IOException {
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            pending = new String[0];
            pendingPos = 0;
            int n = nextInt(br);
            int q = nextInt(br);

            @SuppressWarnings("unchecked")
            List<Integer>[] adjacency = new List[n + 1];
            tree = adjacency;
            for (int i = 1; i <= n; i++) tree[i] = new ArrayList<>();
            for (int i = 0; i < n - 1; i++) {
                int a = nextInt(br);
                int b = nextInt(br);
                tree[a].add(b);
                tree[b].add(a);
            }

            LOG = 1;
            while ((1 << LOG) <= n) LOG++;
            parent = new int[n + 1];
            depth = new int[n + 1];
            tin = new int[n + 1];
            up = new int[n + 1][LOG];

            dfs(1, 1);

            // Level j of every vertex is built from level j - 1, so iterate levels outermost.
            for (int j = 1; j < LOG; j++) {
                for (int i = 1; i <= n; i++) {
                    up[i][j] = up[up[i][j - 1]][j - 1];
                }
            }

            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < q; i++) {
                int k = nextInt(br);
                int[] nodes = new int[k];
                for (int j = 0; j < k; j++) {
                    nodes[j] = nextInt(br);
                }
                sb.append(processSet(nodes)).append("\n");
            }

            System.out.print(sb);
        }

        /** Reads the next integer, tolerating blank lines, extra spaces and values split across lines. */
        private static int nextInt(BufferedReader br) throws IOException {
            while (pendingPos == pending.length) {
                String line = br.readLine();
                if (line == null) {
                    throw new IOException("unexpected end of input");
                }
                String trimmed = line.trim();
                pending = trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
                pendingPos = 0;
            }
            return Integer.parseInt(pending[pendingPos++]);
        }

        /**
         * Fills parent, depth, tin and level 0 of up for the subtree rooted at {@code v}, whose
         * parent is {@code p}.
         *
         * <p>Uses an explicit stack: recursion would overflow the JVM stack on path-shaped trees
         * with ~10^5 vertices.
         */
        static void dfs(int v, int p) {
            int size = tree.length;
            int[] stack = new int[size];
            int[] nextEdge = new int[size]; // next neighbour index to examine, per vertex
            int top = 0;
            int clock = 0;

            parent[v] = p;
            up[v][0] = p;
            tin[v] = clock++;
            stack[top++] = v;
            while (top > 0) {
                int cur = stack[top - 1];
                List<Integer> nbrs = tree[cur];
                if (nextEdge[cur] == nbrs.size()) {
                    top--;
                    continue;
                }
                int nxt = nbrs.get(nextEdge[cur]++);
                if (nxt == parent[cur]) {
                    continue;
                }
                parent[nxt] = cur;
                up[nxt][0] = cur;
                depth[nxt] = depth[cur] + 1;
                tin[nxt] = clock++;
                stack[top++] = nxt;
            }
        }

        /** Lowest common ancestor of a and b using the binary-lifting table. */
        static int lca(int a, int b) {
            if (depth[a] < depth[b]) {
                int tmp = a; a = b; b = tmp;
            }
            int diff = depth[a] - depth[b];
            for (int i = 0; i < LOG; i++) {
                if (((diff >> i) & 1) != 0) {
                    a = up[a][i];
                }
            }
            if (a == b) return a;
            for (int i = LOG - 1; i >= 0; i--) {
                if (up[a][i] != up[b][i]) {
                    a = up[a][i];
                    b = up[b][i];
                }
            }
            return parent[a];
        }

        /** Number of edges on the path between a and b. */
        static long dist(int a, int b) {
            int l = lca(a, b);
            return depth[a] + depth[b] - 2L * depth[l];
        }

        /**
         * Returns the sum over unordered pairs {u, v} of {@code nodes} of u * v * dist(u, v), mod MOD.
         *
         * <p>Instead of visiting all O(k^2) pairs, it builds the virtual tree of the query. Each
         * virtual edge (parent -> child) splits the labels into S_in (inside the child's subtree)
         * and total - S_in. Every pair straddling the edge pays its length, so the edge adds
         * S_in * (total - S_in) * length. This runs in O(k log k + k log n).
         */
        static long processSet(int[] nodes) {
            if (nodes.length < 2) return 0;
            long total = 0;
            for (int v : nodes) total += v;

            // Query vertices plus LCAs of tin-adjacent ones form the virtual tree's vertex set.
            int[] sorted = distinctByTin(nodes, nodes.length);
            int[] cand = Arrays.copyOf(sorted, 2 * sorted.length - 1);
            for (int i = 0; i + 1 < sorted.length; i++) {
                cand[sorted.length + i] = lca(sorted[i], sorted[i + 1]);
            }
            int[] vt = distinctByTin(cand, cand.length);

            int size = vt.length;
            int[] vtTin = new int[size];
            for (int i = 0; i < size; i++) vtTin[i] = tin[vt[i]];
            long[] sub = new long[size];
            for (int v : nodes) sub[Arrays.binarySearch(vtTin, tin[v])] += v;

            // In this LCA-closed, tin-sorted set, the virtual parent of vt[i] is lca(vt[i-1], vt[i]).
            // Children follow their parents, so a backwards sweep completes subtree sums bottom-up.
            long res = 0;
            for (int i = size - 1; i > 0; i--) {
                int p = Arrays.binarySearch(vtTin, tin[lca(vt[i - 1], vt[i])]);
                long edge = depth[vt[i]] - depth[vt[p]];
                long inside = sub[i];
                long term = (inside % MOD) * ((total - inside) % MOD) % MOD;
                res = (res + term * (edge % MOD)) % MOD;
                sub[p] += inside;
            }
            return res;
        }

        /** Returns the distinct vertices among vs[0..len) ordered by DFS entry time. */
        private static int[] distinctByTin(int[] vs, int len) {
            // Pack (tin, vertex) into one long so a primitive sort orders by tin.
            long[] keys = new long[len];
            for (int i = 0; i < len; i++) keys[i] = ((long) tin[vs[i]] << 32) | vs[i];
            Arrays.sort(keys);
            int[] out = new int[len];
            int count = 0;
            for (int i = 0; i < len; i++) {
                if (i == 0 || keys[i] != keys[i - 1]) out[count++] = (int) keys[i];
            }
            return Arrays.copyOf(out, count);
        }
    }

  • + 0 comments

    Calculations on a tree involve determining aspects like height, diameter, and volume, often using mathematical formulas or measurement tools. These calculations are essential for managing tree health and planning landscaping work. When trees are removed, stump grinding services Hampton can efficiently handle leftover stumps, ensuring safety and aesthetics. Accurate tree calculations also assist arborists in diagnosing structural stability and growth potential.

  • + 0 comments

    Test case 17 contains a much higher number of tree nodes than specified in the question (2*(10**5)):

    190509 1157

  • + 2 comments
    import sys, math
    
    # Lift CPython's default recursion cap (about 1000 frames) so recursive tree walks
    # in this script may nest up to a million calls deep.
    sys.setrecursionlimit(1_000_000)
    
    def preprocess_original_tree(node, parent):
        """Record entry/exit times, depths and 2^i-ancestors for the subtree rooted at ``node``.

        Fills the module-level ``tin``/``tout`` (advancing the global ``timer``), ``depth`` and
        ``up`` tables by walking ``adj``. It uses an explicit stack rather than recursion, so deep
        (path-like) trees depend neither on the interpreter recursion limit nor on the C stack.

        :param node: root of the subtree to process.
        :param parent: parent of ``node``; the script passes ``-1`` for the real root.
        """
        global timer

        def enter(v, p):
            # Entry time plus the 2^i-ancestor chart; p's row is already complete.
            global timer
            tin[v] = timer
            timer += 1
            up[v][0] = p
            for i in range(1, LOG):
                mid = up[v][i - 1]
                if mid != -1:
                    up[v][i] = up[mid][i - 1]

        enter(node, parent)
        # Each frame is (vertex, its parent, iterator over neighbours not yet examined).
        stack = [(node, parent, iter(adj[node]))]
        while stack:
            v, p, neighbors = stack[-1]
            for nb in neighbors:
                if nb != p:
                    depth[nb] = depth[v] + 1
                    enter(nb, v)
                    stack.append((nb, v, iter(adj[nb])))
                    break
            else:
                # Every neighbour handled: "jump" out of v.
                stack.pop()
                tout[v] = timer
                timer += 1
    
    def lift_node(node, k):
        """Return the ancestor ``k`` levels above ``node`` using the binary-lifting table ``up``.

        Set bits of ``k`` are consumed from the highest level (``LOG - 1``) down to 0.
        """
        bit = LOG - 1
        while bit >= 0:
            if (k >> bit) & 1:
                node = up[node][bit]
            bit -= 1
        return node
    
    def get_lca(u, v):
        """Return the lowest common ancestor of ``u`` and ``v`` in the original tree."""
        # Lift the deeper vertex (``u`` on ties) to the shallower one's depth.
        deeper, shallower = (u, v) if depth[u] >= depth[v] else (v, u)
        deeper = lift_node(deeper, depth[deeper] - depth[shallower])
        if deeper == shallower:
            return deeper

        # Climb both in lockstep while their 2^i ancestors still differ.
        for i in reversed(range(LOG)):
            a, b = up[deeper][i], up[shallower][i]
            if a != b:
                deeper, shallower = a, b
        return up[deeper][0]
    
    def get_distance(u, v):
        """Return the number of edges between ``u`` and ``v`` in the original tree."""
        meet = get_lca(u, v)
        # Each endpoint climbs to the meeting point; the two climbs add up to the path length.
        return (depth[u] - depth[meet]) + (depth[v] - depth[meet])
    
    def build_virtual_tree(nodes):
        """Build the virtual (auxiliary) tree spanning ``nodes``.

        Sorts ``nodes`` in place by entry time and adds the LCA of every tin-adjacent pair. Each
        kept vertex is then linked to its nearest kept ancestor.

        :returns: ``(children, order)``. ``order`` lists the kept vertices by entry time, so
            ``order[0]`` is the root. ``children`` maps each kept vertex to its child list,
            in entry-time order.
        """
        nodes.sort(key=tin.__getitem__)
        kept = set(nodes)
        kept.update(get_lca(a, b) for a, b in zip(nodes, nodes[1:]))
        order = sorted(kept, key=tin.__getitem__)

        children = {v: [] for v in order}
        chain = []  # kept ancestors of the current vertex, root first
        for v in order:
            # Drop chain entries whose subtree has already closed before v starts.
            while chain and tout[chain[-1]] < tin[v]:
                chain.pop()
            if chain:
                children[chain[-1]].append(v)
            chain.append(v)
        return children, order
    
    def solve_query(query_nodes):
        """Return the sum over unordered pairs {u, v} of ``query_nodes`` of u * v * dist(u, v), mod ``MOD``.

        A virtual-tree edge (parent ``u`` -> child ``v``) separates the labels inside ``v``'s
        subtree (sum ``sub``) from the rest (sum ``S_total - sub``). Every pair split by that edge
        pays its length, so the answer is the sum over edges of ``sub * (S_total - sub) * length``.
        """
        if len(query_nodes) < 2:
            return 0
        S_total = sum(query_nodes)
        query_val = {node: node for node in query_nodes}
        vt, vt_nodes = build_virtual_tree(query_nodes)

        # vt_nodes is ordered by entry time, so children follow their parents. Sweeping it
        # backwards finishes every subtree before its parent, without recursion.
        subtree_sum = {}
        res = 0
        for u in reversed(vt_nodes):
            s = query_val.get(u, 0)
            for v in vt[u]:
                sub = subtree_sum[v]
                # u is an ancestor of v, so the edge length is just the depth difference.
                res = (res + sub * (S_total - sub) % MOD * (depth[v] - depth[u])) % MOD
                s += sub
            subtree_sum[u] = s
        return res
    
    MOD = 10**9 + 7
    timer = 0  # DFS clock advanced by preprocess_original_tree

    # The whole input is consumed as one stream of whitespace-separated tokens.
    data = sys.stdin.read().split()
    it = iter(data)
    n, q = int(next(it)), int(next(it))

    # Per-vertex tables (index 0 unused) and the number of 2^i ancestor levels.
    LOG = int(math.log2(n)) + 1
    up = [[-1] * LOG for _ in range(n + 1)]
    depth = [0] * (n + 1)
    tin = [0] * (n + 1)
    tout = [0] * (n + 1)

    # Undirected adjacency lists of the original tree.
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = int(next(it)), int(next(it))
        adj[a].append(b)
        adj[b].append(a)

    # Root the tree at vertex 1; -1 marks "no parent".
    preprocess_original_tree(1, -1)

    res = []
    for _ in range(q):
        k = int(next(it))
        members = [int(next(it)) for _ in range(k)]
        res.append(str(solve_query(members)))

    sys.stdout.write("\n".join(res))