# Tree Pruning

I have the following algorithm. I run a DFS on the tree and build an array, let's call it postOrder, holding the nodes in post-order, so the root is the rightmost element of the array. I also build an array of subtree sizes, let's call it subtreeSize, such that subtreeSize[i] is the size of the subtree rooted at the i-th element. Then comes the dynamic programming part, where I work through the postOrder array from left to right, so that the root (node 1) is the last element handled. I initialize a DP array of dimensions (n + 1) by (k + 1) and set dp[0][j] = 0 for all j. The recurrence I use, with i indexing positions in postOrder, is:

`dp[i][j] = max(dp[i-1][j] + weight[i], dp[i-subtreeSize[i]][j-1])`

The rationale is that either I don't remove the i-th element (dp[i-1][j] + weight[i]) or I do (dp[i-subtreeSize[i]][j-1]). Removing it also removes its whole subtree, which in post-order is the contiguous block of subtreeSize[i] elements ending at position i. Four of the test cases pass and the rest fail. When I debug my DP table on test case 3, I see that dp[n][k] = 16511421728, while the expected answer is 16234296119, which happens to be the value in my dp[n][k-1]. Any suggestions as to what may be wrong?
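A minimal sketch of the preprocessing described above, assuming nodes are numbered 1..n and each edge line is an undirected pair. It uses an iterative DFS (to stay clear of Python's recursion limit on deep trees) and returns the post-order plus the subtree sizes; the name `post_order_with_sizes` is hypothetical. In the resulting order, every node's subtree occupies the contiguous block of `size[v]` positions ending at `v`, which is what makes the jump back to `dp[i - subtreeSize[i]]` valid.

```python
from typing import List, Tuple


def post_order_with_sizes(n: int, edges: List[Tuple[int, int]], root: int = 1):
    """Return (order, size) for the tree on nodes 1..n rooted at `root`.

    order: nodes in post-order, so the root comes last.
    size[v]: number of nodes in the subtree rooted at v (size[0] is unused).
    """
    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for a, b in edges:              # edges are undirected pairs
        adj[a].append(b)
        adj[b].append(a)

    parent = [0] * (n + 1)
    parent[root] = -1
    size = [1] * (n + 1)
    size[0] = 0
    order: List[int] = []
    stack = [(root, False)]         # (node, children_already_pushed)
    while stack:
        v, expanded = stack.pop()
        if expanded:
            order.append(v)         # every child has finished, so emit v
            if parent[v] > 0:
                size[parent[v]] += size[v]
            continue
        stack.append((v, True))     # revisit v after its children
        for u in adj[v]:
            if u != parent[v]:
                parent[u] = v
                stack.append((u, False))
    return order, size
```

To line this up with the recurrence, the per-position values are `[weight[v] for v in order]` and `[size[v] for v in order]`, where `weight` is a hypothetical 1-indexed list of node weights.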

EDIT: I just found out what the problem was. I had forgotten one more initialization step, namely setting dp[i][0] = sumweight[i-1] + weight[i]: with zero removals nothing can be cut, so dp[i][0] is simply the total weight of the first i elements. To do this, I first had to build the sumweight array, where sumweight[i] stores the sum of the weights of the first i nodes of the postOrder array. When I added these changes to my code, all the test cases passed. I had written everything down on paper but forgot to actually put all of it into the code. I hope this explanation helps others solve the problem.
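Putting the recurrence and the missing `dp[i][0]` initialization together, a minimal sketch might look like this. It assumes the weights and subtree sizes are already listed in post-order (Python index i - 1 holds the i-th element), and `max_after_pruning` is a hypothetical name. Leaving `dp[i][0]` at 0 would pretend that the first i elements sum to 0 with no removals, so every later cut would build on a wrong base value. The full (n + 1) by (k + 1) table is kept here for clarity, not for memory efficiency.

```python
from typing import List


def max_after_pruning(w: List[int], sz: List[int], k: int) -> int:
    """w[i - 1] and sz[i - 1] are the weight and subtree size of the i-th
    node in post-order. dp[i][j] is the best total weight of the first i
    post-order nodes using at most j subtree removals; dp[n][k] is the answer.
    """
    n = len(w)
    prefix = [0] * (n + 1)                   # prefix[i] = w[0] + ... + w[i - 1]
    for i in range(1, n + 1):
        prefix[i] = prefix[i - 1] + w[i - 1]

    dp = [[0] * (k + 1) for _ in range(n + 1)]   # dp[0][j] = 0 for all j
    for i in range(1, n + 1):
        dp[i][0] = prefix[i]                 # no removals: keep all first i nodes
        for j in range(1, k + 1):
            keep = dp[i - 1][j] + w[i - 1]   # keep the i-th node
            # remove the i-th node's subtree: the block of sz[i - 1] nodes
            # ending at position i (at i == n this cuts the root, giving 0)
            cut = dp[i - sz[i - 1]][j - 1]
            dp[i][j] = max(keep, cut)
    return dp[n][k]
```

On the sample tree, the post-order is 5, 4, 3, 2, 1, so `w = [-1, -1, -1, 1, 1]` and `sz = [1, 2, 1, 2, 5]`; with `k = 2` the function returns 2 (cut node 4's subtree and node 3).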

Interesting problem. I couldn't meet the time limit on two of the tests in C++ and was banging my head against the wall. I changed a bunch of my STL containers to plain arrays and, although the code got a little faster, still no luck. Then I realized that my inner loop iterated over the first index of my 2D array and my outer loop over the second. I simply swapped the two loops and passed the time limit easily. I'll have to remember that going forward :)
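The speedup from swapping the loops most likely comes from memory layout: a built-in C/C++ 2D array is stored in row-major order, so `a[i][j]` sits at flat offset `i * cols + j`. The hypothetical helper `visit_offsets` below only lists the offsets each loop order touches; it does not time anything. With the second index in the inner loop, consecutive accesses are 1 apart, while the swapped order jumps by `cols` on most steps, which is much less cache-friendly.

```python
def visit_offsets(rows: int, cols: int, row_outer: bool) -> list:
    """Flat offsets touched when walking a row-major rows x cols array.

    In a row-major layout, element [i][j] lives at offset i * cols + j.
    """
    offsets = []
    if row_outer:
        for i in range(rows):          # outer loop over the first index
            for j in range(cols):      # inner loop over the second index
                offsets.append(i * cols + j)
    else:
        for j in range(cols):          # outer loop over the second index
            for i in range(rows):      # inner loop over the first index
                offsets.append(i * cols + j)
    return offsets
```

For a 2 x 3 array, the first order visits offsets 0, 1, 2, 3, 4, 5 in sequence, while the swapped order visits 0, 3, 1, 4, 2, 5.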

The test cases seem to be weak: my O(n*k^2) solution passed all of them within the time limit. Admin, would you please add a few stronger test cases?

Is the input format correct? The statement says the tree is rooted at vertex 1, but the sample input seems to have edges directed toward vertex 1. The sample input is `5 2`, then the weights `1 1 -1 -1 -1`, then the edges `1 2`, `2 3`, `4 1` and `4 5`. If I construct the tree with each edge directed from the first vertex to the second, 4 becomes the root. Please clarify.
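The PyPy 3 solution in this thread appends every edge in both directions and then roots the tree at vertex 1 with a DFS, which sidesteps the question of edge direction. A minimal sketch of that idea, assuming each edge line is an undirected pair: `root_at` is a hypothetical helper that orients the edges away from the root with a BFS and returns a parent array.

```python
from collections import deque
from typing import List, Tuple


def root_at(n: int, edges: List[Tuple[int, int]], root: int = 1) -> List[int]:
    """Return parent[v] for the tree rooted at `root` (parent[root] = 0).

    Each input pair is treated as an undirected edge, so "4 1" and "1 4"
    describe the same edge.
    """
    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[root] = True
    queue = deque([root])
    while queue:
        v = queue.popleft()
        for u in adj[v]:
            if not seen[u]:        # first time we reach u: v is its parent
                seen[u] = True
                parent[u] = v
                queue.append(u)
    return parent
```

On the sample edges this gives node 4 the parent 1 and node 5 the parent 4, so vertex 1 stays the root.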

There you go, PyPy 3 solution :-)

```python
class Graph:
    def __init__(self, weights):
        self.n = len(weights)
        self.weights = [0] + weights              # 1-indexed weights
        self.edges = [[] for _ in range(self.n + 1)]


def max_sum(g, k):
    # Iterative DFS from the root. Edges back to an already-visited node
    # (the parent) are deleted, so afterwards g.edges[v] holds only v's children.
    visits = []
    seen = [False] * (g.n + 1)
    stack = [1]
    while stack:
        v = stack.pop()
        visits.append(v)
        seen[v] = True
        i = 0
        while i < len(g.edges[v]):
            nv = g.edges[v][i]
            if seen[nv]:
                g.edges[v].pop(i)
            else:
                stack.append(nv)
                i += 1

    # sums[v][c]: best weight of v's subtree when c cuts are spent inside it
    sums = [[] for _ in range(g.n + 1)]
    for v in reversed(visits):                    # children before parents
        children = g.edges[v]
        if not children:  # leaf: no edges to remove below it
            sums[v] = [g.weights[v]]
        else:
            total_cuts = sum(len(sums[c]) - 1 for c in children)
            combined_sums = sums[children[0]]
            for child_idx in range(1, len(children)):
                child_sums = sums[children[child_idx]]
                new_combined_sums = [-(1 << 64)] * (total_cuts + 1)
                # best distribution of cuts among children [:child_idx + 1]
                for cuts in range(total_cuts + 1):
                    for cuts1 in range(cuts + 1):
                        if cuts1 >= len(combined_sums) or cuts - cuts1 >= len(child_sums):
                            continue
                        new_combined_sums[cuts] = max(new_combined_sums[cuts],
                                                      combined_sums[cuts1] + child_sums[cuts - cuts1])
                combined_sums = new_combined_sums
            sums[v] = [s + g.weights[v] for s in combined_sums]
        if v > 1:  # we could cut off this node if not root
            if len(sums[v]) == 1:
                if sums[v][0] < 0:
                    sums[v].append(0)
            elif sums[v][1] < 0:
                sums[v][1] = 0
            # more cuts for less never makes sense
            max_sum_idx = 1
            for i in range(1, len(sums[v])):
                sums[v][i] = max(sums[v][i], 0)
                if sums[v][i] > sums[v][max_sum_idx]:
                    max_sum_idx = i
            sums[v] = sums[v][:min(max_sum_idx, k) + 1]
    return sums[1][min(k, len(sums[1]) - 1)]


n, k = map(int, input().split())
w = list(map(int, input().split()))
g = Graph(w)
for _ in range(n - 1):
    a, b = map(int, input().split())
    g.edges[a].append(b)  # edges are undirected: store both directions
    g.edges[b].append(a)
print(max_sum(g, k))
```
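The double loop over `cuts` and `cuts1` in the solution above combines two children's tables by a max-plus convolution. Below is a standalone sketch of that single step; the name `merge_cut_tables` and the cap at `k + 1` entries are my own choices, not part of the posted code. Each merge costs O(len(a) * len(b)), so capping the tables at k + 1 entries bounds each merge by O(k^2), and the n - 1 merges over the whole tree by O(n*k^2).

```python
from typing import List


def merge_cut_tables(a: List[int], b: List[int], k: int) -> List[int]:
    """Combine two independent parts of a tree.

    a[c] and b[c] are the best weights of each part when c cuts are spent in
    it. The result r[c] is the best combined weight for c cuts in total,
    computed for c = 0 .. min(k, len(a) + len(b) - 2) (max-plus convolution).
    """
    size = min(k + 1, len(a) + len(b) - 1)
    out = []
    for c in range(size):
        # c1 cuts go to `a` and c - c1 to `b`; both indices must be in range
        lo = max(0, c - (len(b) - 1))
        hi = min(c, len(a) - 1)
        out.append(max(a[c1] + b[c - c1] for c1 in range(lo, hi + 1)))
    return out
```

For example, merging `[1, 2]` with `[3, 5, 4]` yields `[4, 6, 7, 6]`, and with `k = 1` only the first two entries, `[4, 6]`, are kept.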
