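/*
 * Count inversions with a modified merge sort.
 *
 * The algorithm is the same as a regular merge sort. The only addition is in
 * the merge step: when left[i] > right[j], every element in the range
 * left[i..n1-1] is also greater than right[j], because the left half is
 * already sorted. So:
 *
 *     count += n1 - i;
 *
 * Example: the left half is 7 8 9 9 (indices 0 1 2 3), i = 1, and the
 * current right[j] is 7. Then 8 and every element after it are greater than
 * right[j], so count increases by (n1 - 1) - i + 1 = n1 - i (here 3).
 */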
class Solution {
public:
    // Running inversion total for the current inversionCount() call.
    // Stored as long long because the count can reach n*(n-1)/2, which
    // overflows a 32-bit int once n exceeds roughly 65,536.
    long long count = 0;

    // Merges the sorted halves arr[l..m] and arr[m+1..r] back into arr[l..r]
    // (ascending). For every pair (x from the left half, y from the right
    // half) with x > y, adds one to `count`.
    void merge(vector<int>& arr, int l, int m, int r) {
        // Copy both halves so arr[l..r] can be overwritten during the merge.
        vector<int> left(arr.begin() + l, arr.begin() + m + 1);
        vector<int> right(arr.begin() + m + 1, arr.begin() + r + 1);
        const int n1 = static_cast<int>(left.size());
        const int n2 = static_cast<int>(right.size());

        int i = 0;
        int j = 0;
        int k = l;
        while (i < n1 && j < n2) {
            if (left[i] > right[j]) {
                // left is sorted, so left[i..n1-1] are all greater than
                // right[j]: each of those n1 - i elements forms an inversion.
                count += n1 - i;
                arr[k++] = right[j++];
            } else {
                // Equal values are not an inversion (strict > above).
                arr[k++] = left[i++];
            }
        }
        // At most one of these loops copies anything: the leftover tail.
        while (i < n1) {
            arr[k++] = left[i++];
        }
        while (j < n2) {
            arr[k++] = right[j++];
        }
    }

    // Sorts arr[l..r] ascending and adds the inversions found in that range
    // to `count`. Returns the running value of `count`, which covers
    // everything accumulated so far, not only this range.
    long long mergeSort(vector<int>& arr, int l, int r) {
        if (l < r) {
            const int m = l + (r - l) / 2;  // avoids overflow of l + r
            mergeSort(arr, l, m);
            mergeSort(arr, m + 1, r);
            merge(arr, l, m, r);
        }
        return count;
    }

    // Returns the number of pairs (i, j) with i < j and arr[i] > arr[j].
    // Side effect: arr is sorted ascending in place.
    // Safe to call repeatedly on the same Solution object: the counter is
    // reset on every call.
    long long inversionCount(vector<int>& arr) {
        count = 0;
        if (arr.size() < 2) {
            // Also avoids arr.size() - 1 wrapping around for an empty vector.
            return 0;
        }
        return mergeSort(arr, 0, static_cast<int>(arr.size()) - 1);
    }
};