Merge Sort: Counting Inversions
Make sure that you use a long to track the swaps! If you use an int, you'll get the wrong answer due to integer overflow in the last third or so of the cases.
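To see why an int is not enough: a reverse-sorted array of n elements has every one of its n(n - 1)/2 pairs out of order. Assuming n around 100,000 (a hypothetical size, for illustration), that is 4,999,950,000 inversions, well above Integer.MAX_VALUE (2,147,483,647). The class and method names below are made up for this sketch.

```java
class InversionOverflowDemo {
    // Largest possible inversion count for n elements (a reverse-sorted array):
    // every one of the n * (n - 1) / 2 pairs is out of order.
    static long maxInversions(int n) {
        return (long) n * (n - 1) / 2;  // widen to long before multiplying
    }

    public static void main(String[] args) {
        int n = 100_000;  // hypothetical array length, for illustration only
        long correct = maxInversions(n);
        int wrapped = n * (n - 1) / 2;  // int arithmetic silently overflows here
        System.out.println(correct);                      // prints 4999950000
        System.out.println(correct > Integer.MAX_VALUE);  // prints true
        System.out.println(wrapped);                      // a wrong, wrapped-around value
    }
}
```

Note that declaring the counter as long is not enough on its own if an intermediate product is still computed in int, hence the cast before multiplying.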
If you're struggling with Python timing out on the last 3 test cases, just switch to PyPy as the language: you can use the same code, but it runs faster. Assuming you've implemented the O(n log n) algorithm, all the tests will pass.
I don't really understand the example. Why do we have to do 4 swaps for the array 2 1 3 1 2? If we swap
a[0] <-> a[3] ==> 1 1 3 2 2, and
a[2] <-> a[4] ==> 1 1 2 2 3,
we get the sorted array.
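The two swaps above exchange elements that are not next to each other (positions 0 and 3, then 2 and 4). The expected answer of 4 comes from counting swaps of adjacent elements only, the bubble-sort-like process the problem describes. A minimal sketch, assuming that reading, counts the adjacent swaps an insertion sort performs; the class and method names are hypothetical.

```java
import java.util.Arrays;

class AdjacentSwapCount {
    // Sorts a copy of arr using only swaps of neighbouring elements
    // (insertion sort) and returns how many such swaps were needed.
    static long countAdjacentSwaps(int[] arr) {
        int[] a = Arrays.copyOf(arr, arr.length);  // leave the caller's array untouched
        long swaps = 0;
        for (int i = 1; i < a.length; i++) {
            // Move a[i] left while its left neighbour is strictly larger.
            for (int j = i; j > 0 && a[j - 1] > a[j]; j--) {
                int tmp = a[j - 1];
                a[j - 1] = a[j];
                a[j] = tmp;
                swaps++;
            }
        }
        return swaps;
    }

    public static void main(String[] args) {
        System.out.println(countAdjacentSwaps(new int[] {2, 1, 3, 1, 2}));  // prints 4
    }
}
```

The four swaps correspond to the four out-of-order pairs in 2 1 3 1 2: (2, 1), (2, 1), (3, 1) and (3, 2).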
I don't understand how this problem is a merge sort. It is more of a bubble sort, as it swaps adjacent elements that are out of order. In merge sort, while merging two halves, we compare the elements at the front of each half and pick the smaller one to rebuild the array, which is not the swap operation described in the problem statement. I understand merge sort has a better runtime, but the problem is confusing and contradictory.
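The link to merge sort: swapping an out-of-order adjacent pair removes exactly one inversion (a pair i < j with arr[i] > arr[j]), so the number of adjacent swaps equals the number of inversions, and merge sort can count inversions in O(n log n) without performing any swaps. During a merge, whenever an element from the right half is taken before the remaining elements of the left half, it "jumps over" all of them, which stands for that many adjacent swaps. The sketch below shows one merge step on the sample 2 1 3 1 2 split into [2, 1, 3] (already sorted to [1, 2, 3]) and [1, 2]; `mergeCount` is a hypothetical helper name.

```java
import java.util.Arrays;

class MergeStepDemo {
    // Merges two already-sorted arrays into out and returns the number of
    // pairs (x from left, y from right) with x > y, i.e. the inversions
    // that straddle the two halves.
    static long mergeCount(int[] left, int[] right, int[] out) {
        long count = 0;
        int i = 0, j = 0, k = 0;
        while (i < left.length && j < right.length) {
            if (left[i] <= right[j]) {
                out[k++] = left[i++];  // equal values are not inversions
            } else {
                // right[j] is smaller than every remaining left element,
                // so it jumps over all left.length - i of them.
                count += left.length - i;
                out[k++] = right[j++];
            }
        }
        while (i < left.length) out[k++] = left[i++];
        while (j < right.length) out[k++] = right[j++];
        return count;
    }

    public static void main(String[] args) {
        int[] left = {1, 2, 3};  // [2, 1, 3] after sorting (1 inversion inside)
        int[] right = {1, 2};    // [1, 2] (no inversions inside)
        int[] out = new int[left.length + right.length];
        long cross = mergeCount(left, right, out);
        System.out.println(cross + " " + Arrays.toString(out));  // prints 3 [1, 1, 2, 2, 3]
    }
}
```

Adding the one inversion inside the left half (2 before 1) to the 3 cross inversions gives 4, the expected answer for this array.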
A great explanation by Tim Roughgarden can be found on Coursera: https://www.coursera.org/learn/algorithm-design-analysis
See Week 1, "O(n log n) Algorithm for Counting Inversions I" and "II".
We can avoid allocating and copying multiple arrays by using a single aux array of size n (where n is the size of the original array). The roles of the two arrays are swapped on each recursive call: the two halves are sorted into aux, then merged back into arr.

```java
import java.util.Scanner;

public class Solution {

    private static long countInversions(int[] arr) {
        // aux must start as an exact copy of arr: the recursion relies on
        // both arrays holding the same values in every range it visits.
        int[] aux = arr.clone();
        return countInversions(arr, 0, arr.length - 1, aux);
    }

    // Sorts arr[lo..hi] and returns the number of inversions in that range.
    // The halves are sorted into aux (roles swapped), then merged back into arr.
    private static long countInversions(int[] arr, int lo, int hi, int[] aux) {
        if (lo >= hi) return 0;
        int mid = lo + (hi - lo) / 2;  // avoids int overflow of (lo + hi) / 2
        long count = 0;
        count += countInversions(aux, lo, mid, arr);
        count += countInversions(aux, mid + 1, hi, arr);
        count += merge(arr, lo, mid, hi, aux);
        return count;
    }

    // Merges the sorted runs aux[lo..mid] and aux[mid+1..hi] into arr[lo..hi],
    // counting the inversions that straddle the two runs.
    private static long merge(int[] arr, int lo, int mid, int hi, int[] aux) {
        long count = 0;
        int i = lo, j = mid + 1, k = lo;
        while (i <= mid || j <= hi) {
            if (i > mid) {
                arr[k++] = aux[j++];
            } else if (j > hi) {
                arr[k++] = aux[i++];
            } else if (aux[i] <= aux[j]) {  // <= keeps equal values from counting
                arr[k++] = aux[i++];
            } else {
                arr[k++] = aux[j++];
                count += mid + 1 - i;  // aux[j] jumps over the rest of the left run
            }
        }
        return count;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int t = in.nextInt();
        for (int a0 = 0; a0 < t; a0++) {
            int n = in.nextInt();
            int[] arr = new int[n];
            for (int arr_i = 0; arr_i < n; arr_i++) {
                arr[arr_i] = in.nextInt();
            }
            System.out.println(countInversions(arr));
        }
    }
}
```