• + 5 comments

    An O(N)-time, O(N)-space solution that passes all test cases and does not use recursion:

    The basic idea is to compute the minimal spanning subtree, i.e. the smallest subtree that contains all the cities we need to deliver letters to. This can be done in O(N) time by putting every non-letter city of degree 1 in a queue, then repeatedly "pruning" them (i.e. marking them as removed in a boolean mask). Whenever pruning a leaf drops its neighbour's remaining degree to 1 and that neighbour is also a non-letter city, we add the neighbour to the queue. The resulting subtree has no non-letter leaves and still contains every letter city, so by definition it is the minimal spanning subtree (in graph terms, the Steiner tree of the letter cities). Note that this is unrelated to the standard minimum spanning tree algorithm.

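    Once we have a mask telling us which cities are in the subtree, we find the diameter of the subtree with two breadth-first searches. The first, started from any letter city, finds a farthest city u; the second, started from u, finds the largest distance, which is the diameter. This trick is usually explained for unit-weight edges, but it generalizes directly to arbitrary positive weights: simply record each city's distance from the source while traversing (paths in a tree are unique, so the visiting order does not affect the distances). Finally, the answer, i.e. the minimum distance needed to travel through the entire subtree, is twice the total weight of the subtree minus the diameter. Every edge must be walked twice, except the edges on the path between the starting and ending cities, which are walked once; choosing that path to be the diameter saves the most.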

    import java.util.Scanner;
    import java.util.Queue;
    import java.util.ArrayDeque;
    
    /**
     * Computes the shortest walk that visits every "letter" city of a weighted tree.
     *
     * <p>The tree is first pruned to the minimal subtree spanning all letter cities. The answer
     * is then {@code 2 * (total subtree weight) - (subtree diameter)}. Every subtree edge is walked
     * twice, except the edges on the path between the start and end cities, which are walked
     * once; choosing that path to be the diameter saves the most.
     *
     * <p>All weights and distances are summed in {@code int}; callers must ensure that twice the
     * total edge weight fits in an {@code int}.
     */
    class Solution {

        /**
         * Repeatedly strips non-letter leaves until every remaining leaf is a letter city.
         *
         * @param adj    adjacency lists; {@code adj[v][i]} is a {@code {neighbour, weight}} pair
         * @param isLett {@code isLett[v]} is true when city {@code v} receives a letter
         * @return mask in which {@code true} marks a city that was pruned away
         */
        static boolean[] prune(int[][][] adj, boolean[] isLett) {
            int n = adj.length;
            // degree[v] tracks how many not-yet-pruned neighbours v still has.
            int[] degree = new int[n];
            for (int i = 0; i < n; ++i) {
                degree[i] = adj[i].length;
            }
            boolean[] rem = new boolean[n];
            Queue<Integer> q = new ArrayDeque<>();
            for (int i = 0; i < n; ++i) {
                if (!isLett[i] && degree[i] == 1) {
                    q.add(i);
                }
            }
            while (!q.isEmpty()) {
                int leaf = q.remove();
                rem[leaf] = true;
                // A queued leaf has at most one surviving neighbour: find it and update its
                // degree. Letter cities are never pruned, so their degree need not be tracked.
                for (int[] edge : adj[leaf]) {
                    int node = edge[0];
                    if (rem[node]) {
                        continue;
                    }
                    if (!isLett[node] && --degree[node] == 1) {
                        q.add(node);
                    }
                    break;
                }
            }
            return rem;
        }

        /**
         * Traverses the cities not marked in {@code rem}, starting from {@code source}, and
         * records each city's weighted distance from the source.
         *
         * @param adj    adjacency lists; {@code adj[v][i]} is a {@code {neighbour, weight}} pair
         * @param rem    mask of cities to ignore
         * @param source starting city; expected not to be marked in {@code rem}
         * @return {@code {total weight of edges traversed, distance to farthest city, farthest city}}
         */
        static int[] bfs(int[][][] adj, boolean[] rem, int source) {
            final int unvis = -1;
            int n = adj.length;
            int[] dist = new int[n];
            for (int i = 0; i < n; ++i) {
                dist[i] = unvis;
            }
            Queue<Integer> q = new ArrayDeque<>();
            q.add(source);
            dist[source] = 0;
            // Seed with the source itself: if no edge is reachable (single-city subtree), the
            // farthest city is the source at distance 0 -- not city 0, which may be pruned.
            int best = source;
            int total = 0;
            while (!q.isEmpty()) {
                int x = q.remove();
                for (int[] edge : adj[x]) {
                    int to = edge[0];
                    if (rem[to] || dist[to] != unvis) {
                        continue;
                    }
                    int weight = edge[1];
                    // In a tree each edge is discovered exactly once, so this sums the subtree.
                    total += weight;
                    dist[to] = dist[x] + weight;
                    if (dist[to] > dist[best]) {
                        best = to;
                    }
                    q.add(to);
                }
            }
            return new int[] {total, dist[best], best};
        }

        /**
         * Returns the minimum distance needed to visit every city listed in {@code lett}.
         *
         * @param adj  adjacency lists; {@code adj[v][i]} is a {@code {neighbour, weight}} pair
         * @param lett 0-based letter cities (duplicates allowed); empty means nothing to deliver
         * @return {@code 2 * total subtree weight - subtree diameter}, or 0 if {@code lett} is empty
         */
        static int solve(int[][][] adj, int[] lett) {
            if (lett.length == 0) {
                return 0;
            }
            boolean[] isLett = new boolean[adj.length];
            for (int city : lett) {
                isLett[city] = true;
            }
            boolean[] rem = prune(adj, isLett);
            // First sweep: total subtree weight plus one endpoint of a longest path.
            int[] first = bfs(adj, rem, lett[0]);
            int totalWeight = first[0];
            int sink = first[2];
            // Second sweep from that endpoint: its farthest distance is the diameter.
            int diameter = bfs(adj, rem, sink)[1];
            return 2 * totalWeight - diameter;
        }

        /**
         * Builds undirected adjacency lists from parallel edge arrays.
         *
         * @param n    number of cities
         * @param from 0-based first endpoint of each edge
         * @param to   0-based second endpoint of each edge
         * @param d    weight of each edge
         * @return {@code adj[v]} holds one {@code {neighbour, weight}} pair per edge incident to v
         */
        static int[][][] weightedAdjacency(int n, int[] from, int[] to, int[] d) {
            int[] count = new int[n];
            for (int f : from) {
                ++count[f];
            }
            for (int t : to) {
                ++count[t];
            }
            int[][][] adj = new int[n][][];
            for (int i = 0; i < n; ++i) {
                adj[i] = new int[count[i]][];
            }
            // Fill each list from the back, reusing count[] as the next free slot.
            for (int i = 0; i < from.length; ++i) {
                adj[from[i]][--count[from[i]]] = new int[] {to[i], d[i]};
                adj[to[i]][--count[to[i]]] = new int[] {from[i], d[i]};
            }
            return adj;
        }

        /**
         * Reads {@code n k}, then k 1-based letter cities, then n-1 lines {@code u v w}, and
         * prints the minimum travel distance.
         */
        public static void main(String[] args) {
            try (Scanner sc = new Scanner(System.in)) {
                int n = sc.nextInt();
                int k = sc.nextInt();
                int[] lett = new int[k];
                for (int i = 0; i < k; ++i) {
                    lett[i] = sc.nextInt() - 1; // input is 1-based
                }
                int[] from = new int[n - 1];
                int[] to = new int[n - 1];
                int[] d = new int[n - 1];
                for (int i = 0; i < n - 1; ++i) {
                    from[i] = sc.nextInt() - 1;
                    to[i] = sc.nextInt() - 1;
                    d[i] = sc.nextInt();
                }
                int[][][] adj = weightedAdjacency(n, from, to, d);
                System.out.println(solve(adj, lett));
            }
        }
    }