Understanding Merge Sort
Merge Sort is a divide-and-conquer algorithm. It works by recursively dividing the input array into smaller subarrays, sorting them individually, and then merging the sorted subarrays to produce the final sorted output. Its time complexity is O(n log n).
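The O(n log n) bound can be motivated with a short recurrence. This is only a sketch: it assumes n is a power of two and that merging a range of n elements costs c·n for some constant c (the constant is a placeholder, not a measured value).

```latex
T(n) = 2\,T\!\left(\tfrac{n}{2}\right) + c\,n, \qquad T(1) = c
\quad\Longrightarrow\quad
T(n) = c\,n\log_2 n + c\,n = O(n \log n)
```

Each of the log2(n) levels of the recursion tree does c·n work in total, which is where the n log n term comes from.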
Visual Example
graph TD
A(27, 10, 12, 20, 25, 13, 15, 22) --> B1(27, 10, 12, 20)
A --> B2(25, 13, 15, 22)
B1 --> C1(27, 10)
B1 --> C2(12, 20)
B2 --> C3(25, 13)
B2 --> C4(15, 22)
C1 --> D1(27)
C1 --> D2(10)
C2 --> D3(12)
C2 --> D4(20)
C3 --> D5(25)
C3 --> D6(13)
C4 --> D7(15)
C4 --> D8(22)
D1 --> E1(10, 27)
D2 --> E1
D3 --> E2(12, 20)
D4 --> E2
D5 --> E3(13, 25)
D6 --> E3
D7 --> E4(15, 22)
D8 --> E4
E1 --> F1(10, 12, 20, 27)
E2 --> F1
E3 --> F2(13, 15, 22, 25)
E4 --> F2
F1 --> G(10, 12, 13, 15, 20, 22, 25, 27)
F2 --> G
The Algorithm
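#include <vector>

// merge() is defined later in this article; declared here so mergeSort() compiles.
void merge(std::vector<int> &list, int start, int mid, int end);

// Sorts list[start..end] (both indices inclusive) in place.
void mergeSort(std::vector<int> &list, int start, int end) {
    if (start >= end) return;             // zero or one element: already sorted
    int mid = start + (end - start) / 2;  // midpoint of the range
    mergeSort(list, start, mid);          // sort the left half
    mergeSort(list, mid + 1, end);        // sort the right half
    merge(list, start, mid, end);         // merge the two sorted halves
}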
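The function above writes the midpoint as start + (end - start) / 2, while the step-by-step snippets in the next section use (start + end) / 2. The sketch below compares the two forms. The helper names midSafe, midNaive and sumOverflowsInt are made up for this illustration and are not part of the article's code.

```cpp
#include <climits>

// Hypothetical helper: the form used in mergeSort() above.
int midSafe(int start, int end) { return start + (end - start) / 2; }

// Hypothetical helper: the shorter form. Only call it with small indices,
// because start + end can exceed INT_MAX, which is undefined behaviour for int.
int midNaive(int start, int end) { return (start + end) / 2; }

// Reports whether start + end would not fit in an int, using a wider type
// so that the check itself cannot overflow.
bool sumOverflowsInt(int start, int end) {
    return static_cast<long long>(start) + static_cast<long long>(end) > INT_MAX;
}
```

For the indices used in this article (0 to 7) both forms give the same result, so the choice only matters for very large arrays.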
Recursion Explanation
Consider the following list (the same one as in the visual example): 27, 10, 12, 20, 25, 13, 15, 22
We pass this list to the mergeSort() function together with the indices of the start and the end of the list.
mergeSort(list, 0, list.size() - 1);
In the case of our list, the initial values for start and end are 0 and 7 respectively.
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
Let's now go step by step while keeping track of the variables. We follow the call that handles the left half of the list (indices 0 to 3).
We will call this layer B1 (B for base):
start = 0
end = 3
mid = 1
if (start >= end) return ;
int mid = (start + end) / 2;
> mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
B1("B1: start=0, end=3, mid=1")
L1("L1: start=0, end=1")
R1("R1")
B1 --calls--> L1
B1 -.-> R1
This is the layer responsible for sorting the left portion of B1's range (indices 0 to 1). Let's call this layer L1:
start = 0
end = 1
mid = 0
if (start >= end) return ;
int mid = (start + end) / 2;
> mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
L1's first call to mergeSort() returns immediately because start >= end:
start = 0
end = 0
> if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
L1("L1: start=0, end=1, mid=0")
L1_1("L1_L: start=0, end=0 (Base Case)")
L1 --calls--> L1_1
L1_1 --returns--> L1
Now we are back in L1, which calls mergeSort() for the second time:
start = 0
end = 1
mid = 0
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
> mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
L1's second call to mergeSort() also returns immediately because start >= end:
start = 1
end = 1
> if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
L1("L1: start=0, end=1, mid=0")
L1_2("L1_R: start=1, end=1 (Base Case)")
L1 --calls--> L1_2
L1_2 --returns--> L1
Back in L1, the merge() function gets called, sorting the range L1 is responsible for (indices 0 to 1):
start = 0
end = 1
mid = 0
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
> merge(list, start, mid, end);
graph LR
L1("L1: start=0, end=1, mid=0")
L1_1("L1_L: start=0, end=0 (Base Case)")
L1_2("L1_R: start=1, end=1 (Base Case)")
L1_Merge("L1_Merge: start=0, end=1, mid=0 (Merge)")
L1 --calls--> L1_1
L1 --calls--> L1_2
L1 --calls--> L1_Merge
L1_1 --returns--> L1
L1_2 --returns--> L1
L1_Merge --returns--> L1
After running merge(), the first segment of the list is sorted:
List Before (indices 0 to 1): 27, 10
List After (indices 0 to 1): 10, 27
After merging, L1 returns to B1, which in turn calls R1:
start = 0
end = 3
mid = 1
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
> mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
B1("B1: start=0, end=3, mid=1")
L1("L1: start=0, end=1")
R1("R1")
B1 --calls--> L1
L1 --returns--> B1
B1 ==calls==> R1
R1 calls mergeSort() to sort its left portion:
start = 2
end = 3
mid = 2
if (start >= end) return ;
int mid = (start + end) / 2;
> mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
It returns immediately because start >= end:
start = 2
end = 2
> if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
R1("R1: start=2, end=3, mid=2")
R1_L("R1_L: start=2, end=2 (Base Case)")
R1 --calls--> R1_L
R1_L --returns--> R1
R1 calls mergeSort() to sort its right portion:
start = 2
end = 3
mid = 2
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
> mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
It also returns immediately because start >= end:
start = 3
end = 3
> if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
merge(list, start, mid, end);
graph LR
R1("R1: start=2, end=3, mid=2")
R1_R["R1_R: start=3, end=3 (Base Case)"]
R1 --calls--> R1_R
R1_R --returns--> R1
R1 calls merge() to sort the range it is responsible for (indices 2 to 3):
start = 2
end = 3
mid = 2
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
> merge(list, start, mid, end);
graph LR
R1("R1: start=2, end=3, mid=2")
R1_1("R1_L: start=2, end=2 (Base Case)")
R1_2("R1_R: start=3, end=3 (Base Case)")
R1_Merge("R1_Merge: start=2, end=3, mid=2 (Merge)")
R1 --calls--> R1_1
R1_1 --returns--> R1
R1 --calls--> R1_2
R1_2 --returns--> R1
R1 --calls--> R1_Merge
R1_Merge --returns--> R1
List Before (indices 2 to 3): 12, 20
List After (indices 2 to 3): 12, 20 (already in order, so nothing moves)
Now, back in B1, the merge() function gets called:
start = 0
end = 3
mid = 1
if (start >= end) return ;
int mid = (start + end) / 2;
mergeSort(list, start, mid);
mergeSort(list, mid + 1, end);
> merge(list, start, mid, end);
The whole range handled by B1 (indices 0 to 3) is now sorted.
graph LR
B1_Merge("B1_Merge: start=0, end=3, mid=1 (Merge)")
L1_Sorted("L1_Sorted: start=0, end=1 (Sorted: [10, 27])")
R1_Sorted("R1_Sorted: start=2, end=3 (Sorted: [12, 20])")
B1_Merge --> L1_Sorted
B1_Merge --> R1_Sorted
List Before (indices 0 to 3): 10, 27, 12, 20
List After (indices 0 to 3): 10, 12, 20, 27
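The call order walked through above can also be recorded by the program itself. The following is a minimal sketch, not the article's original code: tracedMergeSort and its depth and trace parameters are hypothetical additions, and merge() is repeated only so that the block compiles on its own.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Same merge step as in this article, repeated to keep the sketch self-contained.
void merge(std::vector<int> &list, int start, int mid, int end) {
    std::vector<int> left(list.begin() + start, list.begin() + mid + 1);
    std::vector<int> right(list.begin() + mid + 1, list.begin() + end + 1);
    std::size_t i = 0, j = 0;
    int k = start;
    while (i < left.size() && j < right.size()) {
        if (left[i] <= right[j])
            list[k++] = left[i++];
        else
            list[k++] = right[j++];
    }
    while (i < left.size()) list[k++] = left[i++];
    while (j < right.size()) list[k++] = right[j++];
}

// Hypothetical variant of mergeSort() that appends one line per event to `trace`.
// `depth` is the recursion depth and is used only to indent the trace lines.
void tracedMergeSort(std::vector<int> &list, int start, int end,
                     int depth, std::vector<std::string> &trace) {
    const std::string indent(static_cast<std::size_t>(depth) * 2, ' ');
    const std::string range =
        "[" + std::to_string(start) + ", " + std::to_string(end) + "]";
    if (start >= end) {
        trace.push_back(indent + "base case " + range);
        return;
    }
    int mid = start + (end - start) / 2;
    trace.push_back(indent + "split " + range + " at mid=" + std::to_string(mid));
    tracedMergeSort(list, start, mid, depth + 1, trace);    // left portion
    tracedMergeSort(list, mid + 1, end, depth + 1, trace);  // right portion
    merge(list, start, mid, end);
    trace.push_back(indent + "merge " + range);
}
```

Called on the four elements 27, 10, 12, 20 with start = 0, end = 3 and depth = 0, the trace should list the events in the same order as the walkthrough: split [0, 3], split [0, 1], two base cases, merge [0, 1], split [2, 3], two base cases, merge [2, 3], and finally merge [0, 3].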
How does the merge function work?
To understand how the merge() function works, let's take the list 10, 27, 12, 20, whose two halves are each sorted but which is not yet sorted as a whole, and merge it.
start = 0
end = 3
mid = 1
First, two vectors are created. The left vector contains the values 10 and 27, while the right vector contains the values 12 and 20.
std::vector<int> left(list.begin() + start, list.begin() + mid + 1);
std::vector<int> right(list.begin() + mid + 1, list.begin() + end + 1);
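The two constructor calls above use the iterator-range constructor of std::vector, which copies the half-open range [first, last). That is why each second argument is one position past the last index that should be copied. The sketch below isolates this step; splitHalves is a hypothetical helper used only for illustration.

```cpp
#include <utility>
#include <vector>

// Hypothetical helper: returns copies of list[start..mid] and list[mid+1..end].
std::pair<std::vector<int>, std::vector<int>>
splitHalves(const std::vector<int> &list, int start, int mid, int end) {
    // The element at the second iterator is NOT copied, hence the "+ 1".
    std::vector<int> left(list.begin() + start, list.begin() + mid + 1);
    std::vector<int> right(list.begin() + mid + 1, list.begin() + end + 1);
    return {left, right};
}
```

With the list 10, 27, 12, 20 and start = 0, mid = 1, end = 3, the first vector holds 10 and 27 and the second holds 12 and 20.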
- 10 gets compared with 12. As 10 <= 12, 10 becomes the first element of the list and i gets incremented.
- 27 gets compared with 12. As 27 > 12, the else block gets executed, inserting 12 into the list and incrementing j.
- Now, 27 gets compared with 20. As 27 > 20, the else block gets executed again, putting 20 into the list and incrementing j.
while (i < left.size() && j < right.size()) {
    if (left[i] <= right[j])
        list[k++] = left[i++];
    else
        list[k++] = right[j++];
}
As 27 still remains outside the final sorted list, we loop once more over each vector to make sure all remaining elements get inserted:
while (i < left.size())
    list[k++] = left[i++];
while (j < right.size())
    list[k++] = right[j++];
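To see the three loops work together, here is a small sketch that merges two already sorted vectors into a new one and records which side each element was taken from. The name mergeWithLog and the taken parameter are assumptions made for this illustration; the article's own merge() writes back into list instead.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: merges two sorted vectors into a fresh vector.
// Appends 'L' or 'R' to `taken` for every element, in output order.
std::vector<int> mergeWithLog(const std::vector<int> &left,
                              const std::vector<int> &right,
                              std::string &taken) {
    std::vector<int> out;
    out.reserve(left.size() + right.size());
    std::size_t i = 0, j = 0;

    // Main loop: runs while both vectors still have unread elements.
    while (i < left.size() && j < right.size()) {
        if (left[i] <= right[j]) {
            out.push_back(left[i++]);
            taken += 'L';
        } else {
            out.push_back(right[j++]);
            taken += 'R';
        }
    }
    // Leftovers: at most one of these two loops does any work.
    while (i < left.size()) { out.push_back(left[i++]); taken += 'L'; }
    while (j < right.size()) { out.push_back(right[j++]); taken += 'R'; }
    return out;
}
```

For left = 10, 27 and right = 12, 20 the result is 10, 12, 20, 27 and the log reads LRRL: 10 from the left, 12 and 20 from the right, and finally the leftover 27 from the left.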
Complete function:
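#include <cstddef>
#include <vector>

// Merges the sorted ranges list[start..mid] and list[mid+1..end] in place.
void merge(std::vector<int> &list, int start, int mid, int end) {
    // Copy both halves; the range constructor excludes its second iterator.
    std::vector<int> left(list.begin() + start, list.begin() + mid + 1);
    std::vector<int> right(list.begin() + mid + 1, list.begin() + end + 1);
    std::size_t i = 0, j = 0;  // read positions in left and right
    int k = start;             // write position in list
    // Take the smaller front element; <= prefers left on ties.
    while (i < left.size() && j < right.size()) {
        if (left[i] <= right[j])
            list[k++] = left[i++];
        else
            list[k++] = right[j++];
    }
    // Copy whatever remains; at most one of these loops does any work.
    while (i < left.size())
        list[k++] = left[i++];
    while (j < right.size())
        list[k++] = right[j++];
}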
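Putting both functions together, the following sketch sorts the eight-element list from the visual example. sortAll is a hypothetical convenience wrapper, not part of the original code; it is there so callers do not have to compute the end index themselves and so that an empty vector is handled explicitly.

```cpp
#include <cstddef>
#include <vector>

// Merge step, as in the complete function above.
void merge(std::vector<int> &list, int start, int mid, int end) {
    std::vector<int> left(list.begin() + start, list.begin() + mid + 1);
    std::vector<int> right(list.begin() + mid + 1, list.begin() + end + 1);
    std::size_t i = 0, j = 0;
    int k = start;
    while (i < left.size() && j < right.size()) {
        if (left[i] <= right[j])
            list[k++] = left[i++];
        else
            list[k++] = right[j++];
    }
    while (i < left.size()) list[k++] = left[i++];
    while (j < right.size()) list[k++] = right[j++];
}

// Recursive step, as in "The Algorithm" section.
void mergeSort(std::vector<int> &list, int start, int end) {
    if (start >= end) return;
    int mid = start + (end - start) / 2;
    mergeSort(list, start, mid);
    mergeSort(list, mid + 1, end);
    merge(list, start, mid, end);
}

// Hypothetical wrapper: sorts the whole vector in ascending order.
void sortAll(std::vector<int> &list) {
    if (list.size() < 2) return;  // nothing to do; also avoids size() - 1 on an empty vector
    mergeSort(list, 0, static_cast<int>(list.size()) - 1);
}
```

Calling sortAll() on 27, 10, 12, 20, 25, 13, 15, 22 leaves the vector as 10, 12, 13, 15, 20, 22, 25, 27, matching the last node of the visual example.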