Java Visitor Pattern

  • + 0 comments

    /**
     * Visitor that adds up the values of the leaves it visits.
     * Non-leaf nodes are visited but contribute nothing to the total.
     */
    class SumInLeavesVisitor extends TreeVis {
        /** Running total of the values of all leaves visited so far. */
        private int leafTotal;

        /**
         * @return the total of the leaf values accumulated so far
         */
        public int getResult() {
            return leafTotal;
        }

        /**
         * Intentionally empty: only leaves take part in the sum.
         *
         * @param node the non-leaf node being visited (ignored)
         */
        public void visitNode(TreeNode node) {
        }

        /**
         * Adds the value of the visited leaf to the running total.
         *
         * @param leaf the leaf being visited
         */
        public void visitLeaf(TreeLeaf leaf) {
            leafTotal += leaf.getValue();
        }
    }

    /**
     * Visitor that multiplies together the values of every red node and red leaf it visits.
     * The running product is reduced modulo 1000000007 after each multiplication.
     */
    class ProductOfRedNodesVisitor extends TreeVis {
        /** Modulus applied to the running product. */
        private final int MOD = 1000000007;

        /** Product of the red values seen so far, kept reduced modulo {@code MOD}; starts at 1. */
        private long redProduct = 1;

        /**
         * @return the product of the red values seen so far, modulo {@code MOD}
         */
        public int getResult() {
            long reduced = redProduct % MOD;
            return (int) reduced;
        }

        /**
         * Folds the node's value into the product when the node is red.
         *
         * @param node the non-leaf node being visited
         */
        public void visitNode(TreeNode node) {
            if (node.getColor() != Color.RED) {
                return;
            }
            // Multiply first, then reduce, so the stored product stays below MOD.
            redProduct *= node.getValue();
            redProduct %= MOD;
        }

        /**
         * Folds the leaf's value into the product when the leaf is red.
         *
         * @param leaf the leaf being visited
         */
        public void visitLeaf(TreeLeaf leaf) {
            if (leaf.getColor() != Color.RED) {
                return;
            }
            redProduct *= leaf.getValue();
            redProduct %= MOD;
        }
    }

    /**
     * Visitor that tracks two totals: the values of non-leaf nodes at even depth and the
     * values of green leaves. The result is the absolute difference between the two.
     */
    class FancyVisitor extends TreeVis {
        /** Total of the values of visited non-leaf nodes whose depth is even. */
        private int evenDepthNodeTotal;

        /** Total of the values of visited leaves whose color is green. */
        private int greenLeafTotal;

        /**
         * @return the absolute difference between the even-depth node total and the
         *         green-leaf total
         */
        public int getResult() {
            int difference = evenDepthNodeTotal - greenLeafTotal;
            return Math.abs(difference);
        }

        /**
         * Adds the node's value to the even-depth total when its depth is even.
         *
         * @param node the non-leaf node being visited
         */
        public void visitNode(TreeNode node) {
            if (node.getDepth() % 2 != 0) {
                return;
            }
            evenDepthNodeTotal += node.getValue();
        }

        /**
         * Adds the leaf's value to the green-leaf total when the leaf is green.
         *
         * @param leaf the leaf being visited
         */
        public void visitLeaf(TreeLeaf leaf) {
            if (leaf.getColor() != Color.GREEN) {
                return;
            }
            greenLeafTotal += leaf.getValue();
        }
    }

    public class Solution { private static int[] values; private static Color[] colors; private static Map> edges = new HashMap<>();

    /**
     * Reads a tree description from standard input and builds the tree.
     *
     * <p>Input order: the node count, then one value per node, then one color code per node
     * (0 maps to {@code Color.RED}, anything else to {@code Color.GREEN}), then
     * {@code count - 1} edges given as pairs of one-based node indices. Each edge is stored
     * in both directions in {@code edges} using zero-based indices.
     *
     * @return the tree built by {@code buildTree} starting from index 0 at depth 0
     */
    public static Tree solve() {
        Scanner in = new Scanner(System.in);
        final int nodeCount = in.nextInt();

        values = new int[nodeCount];
        colors = new Color[nodeCount];

        for (int idx = 0; idx < nodeCount; idx++) {
            values[idx] = in.nextInt();
        }

        for (int idx = 0; idx < nodeCount; idx++) {
            int colorCode = in.nextInt();
            if (colorCode == 0) {
                colors[idx] = Color.RED;
            } else {
                colors[idx] = Color.GREEN;
            }
        }

        for (int remaining = nodeCount - 1; remaining > 0; remaining--) {
            // Input indices are one-based; shift them to match the zero-based arrays.
            int first = in.nextInt() - 1;
            int second = in.nextInt() - 1;

            Set<Integer> firstAdjacent = edges.get(first);
            if (firstAdjacent == null) {
                firstAdjacent = new HashSet<Integer>();
                edges.put(first, firstAdjacent);
            }
            firstAdjacent.add(second);

            Set<Integer> secondAdjacent = edges.get(second);
            if (secondAdjacent == null) {
                secondAdjacent = new HashSet<Integer>();
                edges.put(second, secondAdjacent);
            }
            secondAdjacent.add(first);
        }
        in.close();

        return buildTree(0, 0);
    }
    
    /**
     * Recursively builds the subtree rooted at {@code nodeIndex}.
     *
     * <p>A node with no remaining neighbors becomes a {@code TreeLeaf}; any other node
     * becomes a {@code TreeNode} whose children are built from its neighbors at
     * {@code depth + 1}. Before descending into a neighbor, the link from that neighbor back
     * to this node is removed from {@code edges}, so each child only sees its own
     * descendants.
     *
     * <p>NOTE: the recursion goes one call deeper per tree level, so a very deep tree can
     * exhaust the call stack.
     *
     * @param nodeIndex zero-based index of the node to build
     * @param depth depth assigned to this node
     * @return the leaf or internal node built for {@code nodeIndex}
     */
    private static Tree buildTree(int nodeIndex, int depth) {
        Set<Integer> neighbors = edges.get(nodeIndex);

        // A node that never appeared in any edge (e.g. the only node of a one-node tree) has
        // no entry in the map; treat it as a leaf instead of dereferencing null.
        if (neighbors == null || neighbors.isEmpty()) {
            return new TreeLeaf(values[nodeIndex], colors[nodeIndex], depth);
        }

        TreeNode node = new TreeNode(values[nodeIndex], colors[nodeIndex], depth);

        for (int neighbor : neighbors) {
            // Drop the back-link so the child does not treat its parent as a child.
            edges.get(neighbor).remove(nodeIndex);
            node.addChild(buildTree(neighbor, depth + 1));
        }

        return node;
    }