Java Visitor Pattern

  • + 0 comments
    
    

    import java.io.*;
    import java.util.*;

    /** Colour tag carried by every tree element. */
    enum Color {
        RED,
        GREEN
    }

    /**
     * Base type for every element of the tree: stores a value, a colour and
     * a depth, and lets a {@link TreeVis} be applied through {@link #accept}.
     */
    abstract class Tree {

        // Assigned once in the constructor and never modified afterwards.
        private final int value;
        private final Color color;
        private final int depth;

        /**
         * @param value the number stored in this element
         * @param color the colour of this element
         * @param depth the depth recorded for this element
         */
        public Tree(int value, Color color, int depth) {
            this.value = value;
            this.color = color;
            this.depth = depth;
        }

        /** @return the number stored in this element */
        public int getValue() {
            return this.value;
        }

        /** @return the colour of this element */
        public Color getColor() {
            return this.color;
        }

        /** @return the depth recorded for this element */
        public int getDepth() {
            return this.depth;
        }

        /**
         * Applies the given visitor to this element.
         *
         * @param visitor the visitor to apply
         */
        public abstract void accept(TreeVis visitor);
    }

    class TreeNode extends Tree { private final List children = new ArrayList<>();

    public TreeNode(int value, Color color, int depth) {
        super(value, color, depth);
    }
    
    public void accept(TreeVis visitor) {
        visitor.visitNode(this);
        for (Tree child : children) {
            child.accept(visitor);
        }
    }
    
    public void addChild(Tree child) {
        children.add(child);
    }
    

    }

    class TreeLeaf extends Tree { public TreeLeaf(int value, Color color, int depth) { super(value, color, depth); }

    public void accept(TreeVis visitor) {
        visitor.visitLeaf(this);
    }
    

    }

    /**
     * Visitor over tree elements: one callback per element kind plus an
     * accessor for the value the visitor has computed.
     */
    abstract class TreeVis {

        /** @return the value computed by this visitor */
        public abstract int getResult();

        /**
         * Called for a {@link TreeNode}.
         *
         * @param node the node being visited
         */
        public abstract void visitNode(TreeNode node);

        /**
         * Called for a {@link TreeLeaf}.
         *
         * @param leaf the leaf being visited
         */
        public abstract void visitLeaf(TreeLeaf leaf);
    }

    // Visitor 1: Sum of all leaf values class SumInLeavesVisitor extends TreeVis { private int sum = 0;

    public int getResult() {
        return sum;
    }
    
    public void visitNode(TreeNode node) {}
    
    public void visitLeaf(TreeLeaf leaf) {
        sum += leaf.getValue();
    }
    

    }

    // Visitor 2: Product of red node values class ProductOfRedNodesVisitor extends TreeVis { private long product = 1; private final int MOD = 1_000_000_007;

    public int getResult() {
        return (int) product;
    }
    
    public void visitNode(TreeNode node) {
        if (node.getColor() == Color.RED) {
            product = (product * node.getValue()) % MOD;
        }
    }
    
    public void visitLeaf(TreeLeaf leaf) {
        if (leaf.getColor() == Color.RED) {
            product = (product * leaf.getValue()) % MOD;
        }
    }
    

    }

    // Visitor 3: Difference between even depth node sum and green leaf sum class FancyVisitor extends TreeVis { private int sumEvenDepthNonLeaf = 0; private int sumGreenLeaf = 0;

    public int getResult() {
        return Math.abs(sumEvenDepthNonLeaf - sumGreenLeaf);
    }
    
    public void visitNode(TreeNode node) {
        if (node.getDepth() % 2 == 0) {
            sumEvenDepthNonLeaf += node.getValue();
        }
    }
    
    public void visitLeaf(TreeLeaf leaf) {
        if (leaf.getColor() == Color.GREEN) {
            sumGreenLeaf += leaf.getValue();
        }
    }
    

    }

    /**
     * Entry point: reads a tree description from standard input, builds the
     * {@link Tree} structure rooted at the first vertex, runs the three
     * visitors over it and prints their results, one per line.
     */
    public class Solution {

        private static int[] values;
        private static Color[] colors;
        // Undirected adjacency lists keyed by 0-based vertex index.
        private static final Map<Integer, List<Integer>> treeMap = new HashMap<>();

        public static void main(String[] args) {
            Tree root = solve();

            TreeVis vis1 = new SumInLeavesVisitor();
            TreeVis vis2 = new ProductOfRedNodesVisitor();
            TreeVis vis3 = new FancyVisitor();

            root.accept(vis1);
            root.accept(vis2);
            root.accept(vis3);

            System.out.println(vis1.getResult());
            System.out.println(vis2.getResult());
            System.out.println(vis3.getResult());
        }

        /**
         * Reads the input and builds the tree.
         *
         * <p>Input layout as consumed by this method: an integer {@code n},
         * then {@code n} values, then {@code n} colour codes ({@code 0} maps to
         * {@link Color#RED}, anything else to {@link Color#GREEN}), then
         * {@code n - 1} pairs of 1-based vertex numbers describing the edges.
         *
         * @return the tree rooted at vertex index 0 with depth 0
         */
        public static Tree solve() {
            // Not closed on purpose: closing the Scanner would close System.in.
            Scanner sc = new Scanner(System.in);
            int n = sc.nextInt();

            // Drop edges left over from a previous call so that repeated
            // invocations do not mix two inputs in the shared static map.
            treeMap.clear();

            values = new int[n];
            for (int i = 0; i < n; i++) {
                values[i] = sc.nextInt();
            }

            colors = new Color[n];
            for (int i = 0; i < n; i++) {
                colors[i] = sc.nextInt() == 0 ? Color.RED : Color.GREEN;
            }

            for (int i = 0; i < n - 1; i++) {
                // Convert the 1-based vertex numbers to 0-based indices.
                int u = sc.nextInt() - 1;
                int v = sc.nextInt() - 1;

                // Edge direction is not relied upon; record both ways and let
                // buildTree orient the edges away from the root.
                treeMap.computeIfAbsent(u, x -> new ArrayList<>()).add(v);
                treeMap.computeIfAbsent(v, x -> new ArrayList<>()).add(u);
            }

            return buildTree(0, 0, -1);
        }

        /**
         * Recursively builds the subtree rooted at {@code node}.
         *
         * <p>A vertex with no neighbour other than {@code parent} becomes a
         * {@link TreeLeaf}; any other vertex becomes a {@link TreeNode} whose
         * children follow the adjacency-list order. The adjacency lists are
         * only read, never modified.
         *
         * @param node   0-based index of the vertex to build
         * @param depth  depth to record on the created element
         * @param parent index of the vertex we came from, or {@code -1} for the root
         * @return the element created for {@code node}
         */
        private static Tree buildTree(int node, int depth, int parent) {
            List<Integer> neighbours = treeMap.get(node);
            TreeNode inner = null;

            if (neighbours != null) {
                for (int neighbour : neighbours) {
                    if (neighbour == parent) {
                        continue;
                    }
                    // Create the inner node lazily, on the first real child.
                    if (inner == null) {
                        inner = new TreeNode(values[node], colors[node], depth);
                    }
                    inner.addChild(buildTree(neighbour, depth + 1, node));
                }
            }

            if (inner == null) {
                return new TreeLeaf(values[node], colors[node], depth);
            }
            return inner;
        }
    }