Java Visitor Pattern

  • + 0 comments

    I made a slightly shorter version, but your answer was definitely pointing me in the right direction, so thank you! :) This challenge was way over the top for what the original task asks.

    /**
     * Visitor that totals the values of every leaf; non-leaf nodes contribute nothing.
     */
    class SumInLeavesVisitor extends TreeVis {
        // Running sum of the leaf values seen so far (starts at the int default, 0).
        private int leafTotal;

        /** Returns the sum of all leaf values visited so far. */
        public int getResult() {
            return leafTotal;
        }

        /** Non-leaf nodes are ignored by this visitor. */
        public void visitNode(TreeNode node) {
            // intentionally empty
        }

        /** Adds the leaf's value to the running total. */
        public void visitLeaf(TreeLeaf leaf) {
            int value = leaf.getValue();
            leafTotal = leafTotal + value;
        }
    }
    
    /**
     * Visitor that multiplies together the values of every red node, both non-leaf nodes and
     * leaves, reducing the running product modulo {@code 1_000_000_007} after each step.
     */
    class ProductOfRedNodesVisitor extends TreeVis {
        /** Modulus applied after every multiplication; {@code final} so it cannot be reassigned. */
        private static final int MOD = 1_000_000_007;

        // Running product. It is a long because |result| < MOD, so result * (any int) fits in a
        // long before the reduction.
        private long result = 1;

        /**
         * Returns the product of the red values seen so far, modulo {@link #MOD}; 1 if none.
         * The cast is lossless because every update reduces the product modulo MOD.
         */
        public int getResult() {
            return (int) result;
        }

        /** Multiplies this non-leaf node's value into the product if the node is red. */
        public void visitNode(TreeNode node) {
            if (node.getColor() == Color.RED) {
                multiplyIn(node.getValue());
            }
        }

        /** Multiplies this leaf's value into the product if the leaf is red. */
        public void visitLeaf(TreeLeaf leaf) {
            if (leaf.getColor() == Color.RED) {
                multiplyIn(leaf.getValue());
            }
        }

        // Shared multiply-and-reduce step, so both visit methods use identical arithmetic.
        private void multiplyIn(int value) {
            result = (result * value) % MOD;
        }
    }
    
    /**
     * Visitor that reports the absolute difference between the sum of non-leaf node values at
     * even depth and the sum of green leaf values.
     */
    class FancyVisitor extends TreeVis {

        // Both sums start at the int default of 0.
        int greenLeafSum;
        int evenNodeSum;

        /** Returns |evenNodeSum - greenLeafSum|. */
        public int getResult() {
            int difference = evenNodeSum - greenLeafSum;
            return Math.abs(difference);
        }

        /** Adds the node's value to evenNodeSum when the node's depth is even. */
        public void visitNode(TreeNode node) {
            boolean atEvenDepth = node.getDepth() % 2 == 0;
            if (!atEvenDepth) {
                return;
            }
            evenNodeSum += node.getValue();
        }

        /** Adds the leaf's value to greenLeafSum when the leaf is green. */
        public void visitLeaf(TreeLeaf leaf) {
            if (leaf.getColor() != Color.GREEN) {
                return;
            }
            greenLeafSum += leaf.getValue();
        }
    }
    
    public class Solution {
    
        /**
         * Reads a tree description from standard input and builds the tree.
         *
         * <p>Input order: the node count n, then n values, then n colour codes ({@code 0} means
         * red, anything else means green), then n - 1 edges given as pairs of one-based indices.
         *
         * @return the tree rooted at the first node (index 0 after conversion to zero-based)
         */
        public static Tree solve() {
            Scanner in = new Scanner(System.in);
            int n = in.nextInt();

            int[] values = new int[n];
            for (int idx = 0; idx < values.length; idx++) {
                values[idx] = in.nextInt();
            }

            Color[] colors = new Color[n];
            for (int idx = 0; idx < colors.length; idx++) {
                int code = in.nextInt();
                colors[idx] = code == 0 ? Color.RED : Color.GREEN;
            }

            // Undirected adjacency: each edge is recorded in both endpoints' neighbour sets.
            HashMap<Integer, HashSet<Integer>> adjacency = new HashMap<>();
            for (int edge = 1; edge < n; edge++) {
                int a = in.nextInt() - 1; // convert to zero-based
                int b = in.nextInt() - 1;
                adjacency.computeIfAbsent(a, key -> new HashSet<>()).add(b);
                adjacency.computeIfAbsent(b, key -> new HashSet<>()).add(a);
            }

            return createSubTree(0, 0, adjacency, values, colors);
        }
    
        /**
         * Recursively builds the subtree rooted at {@code nodeIdx} from an undirected adjacency map.
         *
         * <p>Before descending into a neighbour, this method removes the edge back to
         * {@code nodeIdx} from that neighbour's set. By the time a node is built, its set
         * therefore holds only its children. A node with no remaining neighbours, or with no map
         * entry at all, becomes a {@code TreeLeaf}; otherwise it becomes a {@code TreeNode} with
         * its children attached. This method mutates {@code map}.
         *
         * <p>NOTE(review): recursion depth equals the tree height, so a very tall (path-like) tree
         * could exhaust the thread stack; consider an iterative build if inputs can be that deep.
         *
         * @param nodeIdx zero-based index of the node to build
         * @param depth   depth assigned to {@code nodeIdx}; {@code solve} passes 0 for the root
         * @param map     undirected adjacency sets; a node without edges may be absent
         * @param values  node values, indexed by node
         * @param colors  node colours, indexed by node
         * @return the built subtree
         */
        private static Tree createSubTree(int nodeIdx, int depth, HashMap<Integer, HashSet<Integer>> map, int[] values, Color[] colors) {
            HashSet<Integer> neighbors = map.get(nodeIdx);
            // A node with no edges at all, such as the root of a single-node tree, has no map
            // entry; treat it as a leaf instead of dereferencing null.
            if (neighbors == null || neighbors.isEmpty()) {
                return new TreeLeaf(values[nodeIdx], colors[nodeIdx], depth);
            }
            TreeNode node = new TreeNode(values[nodeIdx], colors[nodeIdx], depth);
            for (int child : neighbors) {
                // Remove the backward edge so that only the child's own children remain in its set.
                map.get(child).remove(nodeIdx);
                node.addChild(createSubTree(child, depth + 1, map, values, colors));
            }
            return node;
        }