Session 5
Divide and Conquer Technique
In the divide-and-conquer approach, the original problem is recursively divided into two or more subproblems until each one is small enough to be solved directly. Each subproblem is a fraction of the original problem. The solutions to the subproblems are then combined to produce the solution to the original problem.
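The three phases (divide, conquer, combine) can be seen in a very small example. The sketch below is illustrative only and is not one of the lab questions; the function name `dc_max` is an assumption made for this example. It finds the largest element of a list by splitting the index range in half, solving each half recursively, and combining the two partial answers.

```python
def dc_max(values, low, high):
    """Return the largest element of values[low..high] by divide and conquer."""
    # Base case: a single element is its own maximum.
    if low == high:
        return values[low]
    # Divide: split the index range into two halves.
    mid = low + (high - low) // 2
    # Conquer: solve each half recursively.
    left_max = dc_max(values, low, mid)
    right_max = dc_max(values, mid + 1, high)
    # Combine: the answer is the larger of the two partial answers.
    return left_max if left_max >= right_max else right_max


data = [12, 20, 22, 16, 25, 18, 8, 10, 6, 15]
print(dc_max(data, 0, len(data) - 1))  # prints 25
```

The same skeleton (base case, split, recursive calls, combine) reappears in binary search, merge sort, quick sort, and Strassen's multiplication; only the split rule and the combine step change.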
Objectives
- Understand how divide and conquer splits a problem into smaller independent subproblems.
- Implement recursive binary search, merge sort, quick sort, and Strassen's matrix multiplication.
- Show dry runs, recursive call trees, loop counts, comparisons, exchanges, and operation counts where the lab manual asks for them.
Questions Covered
| Question | Requirement | Completion Notes |
|---|---|---|
| Q1 | Recursive binary search for 100 in a sorted array | Add implementation, trace, and recursive call tree. |
| Q2 | Maximum searches among 512 million items using binary search | Add logarithmic calculation. |
| Q3 | Merge sort for 200 150 50 100 75 25 10 5 | Add split/merge steps and recursion tree. |
| Q4 | Quick sort for 12 20 22 16 25 18 8 10 6 15 | Add partition steps and output. |
| Q5 | Quick sort performance for sorted list 6 8 10 12 15 16 18 20 22 25 | Add comparisons, exchanges, and loop iterations. |
| Q6 | Strassen's multiplication for n x n matrices where n is power of 2 | Add multiple instances and compare multiplications/additions. |
Question 1
Problem Statement
Implement a recursive binary search algorithm on your system to search for the number 100 in the following array of integers. Show the process step by step and draw the recursive calls.
10 35 40 45 50 55 60 65 70 100
Approach
Binary search compares the target with the middle element. If the target is larger, the search continues in the right half; if smaller, it continues in the left half. Because the array is sorted, every recursive call reduces the search range by half.
Step-by-Step Trace
| Call | Low | High | Mid | A[mid] | Decision |
|---|---|---|---|---|---|
| 1 | 0 | 9 | 4 | 50 | 100 > 50, search right half |
| 2 | 5 | 9 | 7 | 65 | 100 > 65, search right half |
| 3 | 8 | 9 | 8 | 70 | 100 > 70, search right half |
| 4 | 9 | 9 | 9 | 100 | Found |
Recursive Call Tree
binarySearch(0, 9)
-> binarySearch(5, 9)
-> binarySearch(8, 9)
-> binarySearch(9, 9)
-> found at index 9
Complexity
- Time complexity: O(log n)
- Space complexity: O(log n), due to the recursive call stack
Implementation
Python
def binary_search(arr, low, high, target):
    """
    Recursive Binary Search in Python
    :param arr: Sorted list of integers
    :param low: Starting index
    :param high: Ending index
    :param target: Value to find
    :return: Index of target, or -1 if not found
    """
    # Base case: search range is invalid
    if low > high:
        return -1
    # Integer division for mid calculation
    mid = low + (high - low) // 2
    # Step-by-step trace for lab submission
    print(f"[Call] low={low}, high={high}, mid={mid}, arr[mid]={arr[mid]}")
    # Base case: target found
    if arr[mid] == target:
        return mid
    # Target is in the right half
    elif arr[mid] < target:
        return binary_search(arr, mid + 1, high, target)
    # Target is in the left half
    else:
        return binary_search(arr, low, mid - 1, target)


# Driver code
if __name__ == "__main__":
    arr = [10, 35, 40, 45, 50, 55, 60, 65, 70, 100]
    target = 100
    print("=== Recursive Binary Search in Python ===")
    result = binary_search(arr, 0, len(arr) - 1, target)
    if result != -1:
        print(f"\n✅ Target {target} found at index {result}")
    else:
        print(f"\n❌ Target {target} not found in array")
C Language
#include <stdio.h>
/**
* Recursive Binary Search
* @param arr Sorted integer array
* @param low Starting index of current search range
* @param high Ending index of current search range
* @param target Value to search for
* @return Index of target, or -1 if not found
*/
int binarySearch(int arr[], int low, int high, int target) {
// Base case: invalid range means target is not in array
if (low > high) {
return -1;
}
// Calculate mid to avoid integer overflow
int mid = low + (high - low) / 2;
// Print step-by-step trace for lab documentation
printf("[Call] low=%d, high=%d, mid=%d, arr[mid]=%d\n", low, high, mid, arr[mid]);
// Check if target is present at mid
if (arr[mid] == target) {
return mid;
}
// If target is greater, ignore left half and search right
if (arr[mid] < target) {
return binarySearch(arr, mid + 1, high, target);
}
// If target is smaller, ignore right half and search left
return binarySearch(arr, low, mid - 1, target);
}
int main() {
int arr[] = {10, 35, 40, 45, 50, 55, 60, 65, 70, 100};
int n = sizeof(arr) / sizeof(arr[0]);
int target = 100;
printf("=== Recursive Binary Search in C ===\n");
int result = binarySearch(arr, 0, n - 1, target);
if (result != -1)
printf("\n✅ Target %d found at index %d\n", target, result);
else
printf("\n❌ Target %d not found in array\n", target);
return 0;
}
Rust
/// Recursive Binary Search in Rust
/// Returns Some(index) if found, None otherwise
fn binary_search(arr: &[i32], low: isize, high: isize, target: i32) -> Option<isize> {
// Base case: invalid range
if low > high {
return None;
}
// Safe mid calculation using isize to prevent underflow
let mid = low + (high - low) / 2;
// Trace output for lab documentation
println!(
"[Call] low={}, high={}, mid={}, arr[mid]={}",
low, high, mid, arr[mid as usize]
);
// Check if target is at mid
if arr[mid as usize] == target {
return Some(mid);
}
// Recurse right or left based on comparison
if arr[mid as usize] < target {
return binary_search(arr, mid + 1, high, target);
}
binary_search(arr, low, mid - 1, target)
}
fn main() {
let arr = [10, 35, 40, 45, 50, 55, 60, 65, 70, 100];
let target = 100;
println!("=== Recursive Binary Search in Rust ===");
// Pass initial bounds: 0 to length-1
match binary_search(&arr, 0, arr.len() as isize - 1, target) {
Some(idx) => println!("\n✅ Target {} found at index {}", target, idx),
None => println!("\n❌ Target {} not found in array", target),
}
}
Question 2
Problem Statement
Suppose we are required to search among 512 million items in a list using binary search. What is the maximum number of searches needed to find a given item or conclude that it is not present?
Solution
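Binary search needs at most floor(log2 n) + 1 comparisons in the worst case: each comparison halves the remaining range, and one last comparison is needed when a single element is left.
n = 512,000,000
2^28 = 268,435,456 < 512,000,000 < 2^29 = 536,870,912, so log2(512,000,000) ≈ 28.93
Maximum searches = floor(28.93) + 1 = 29
If "512 million" is interpreted as exactly 512 * 2^20 = 536,870,912 = 2^29 items, then:
log2(536,870,912) = 29
Maximum searches = 29 + 1 = 30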
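The bound can be checked numerically. The sketch below is an illustration, and the helper name `max_binary_search_probes` is an assumption made for this example. It relies on the fact that, for a positive integer n, Python's `int.bit_length()` equals floor(log2 n) + 1, which is the worst-case number of middle-element comparisons binary search makes on n sorted items. A list whose size is an exact power of two, such as 2^29, needs one more probe than ceil(log2 n) suggests, because a range of one element still has to be compared.

```python
def max_binary_search_probes(n):
    """Worst-case number of middle-element comparisons for n sorted items."""
    if n <= 0:
        return 0
    # For n > 0, int.bit_length() equals floor(log2 n) + 1.
    return n.bit_length()


for n in (10, 512_000_000, 512 * 2**20):
    print(n, max_binary_search_probes(n))
# prints:
# 10 4
# 512000000 29
# 536870912 30
```

The value 4 for n = 10 agrees with the four calls in the Question 1 trace, where the target sits at the last index.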
Question 3
Problem Statement
Implement the Merge Sort algorithm to sort the following list and show the process step by step. Draw a tree of the recursive calls.
200 150 50 100 75 25 10 5
Approach
Merge sort divides the list into two halves until each sublist has one element, then merges the sorted sublists.
Recursive Split Tree
[200,150,50,100,75,25,10,5]
├─ [200,150,50,100]
│ ├─ [200,150]
│ │ ├─ [200]
│ │ └─ [150]
│ └─ [50,100]
│ ├─ [50]
│ └─ [100]
└─ [75,25,10,5]
├─ [75,25]
│ ├─ [75]
│ └─ [25]
└─ [10,5]
├─ [10]
└─ [5]
Implementation
Python
def merge_sort(arr, left=0, right=None, depth=0):
    """
    Recursive Merge Sort Implementation
    :param arr: List to sort
    :param left: Left index
    :param right: Right index
    :param depth: Recursion depth for indentation
    :return: Sorted list
    """
    if right is None:
        right = len(arr) - 1
    # Print current state with indentation
    indent = " " * depth
    print(f"{indent}mergeSort({left},{right}) → {arr[left:right+1]}")
    if left < right:
        # Calculate mid point (avoid overflow)
        mid = left + (right - left) // 2
        # Recursively sort first half
        merge_sort(arr, left, mid, depth + 1)
        # Recursively sort second half
        merge_sort(arr, mid + 1, right, depth + 1)
        # Merge the sorted halves
        merge(arr, left, mid, right, depth)
    else:
        print(f"{indent} Base case: single element")
    return arr


def merge(arr, left, mid, right, depth):
    """
    Merge two sorted subarrays
    :param arr: Original array
    :param left: Start of left subarray
    :param mid: End of left subarray
    :param right: End of right subarray
    :param depth: Indentation level
    """
    indent = " " * depth
    # Create temporary arrays
    left_arr = arr[left:mid + 1]
    right_arr = arr[mid + 1:right + 1]
    print(f"{indent} Merge: Left={left_arr} Right={right_arr}")
    i = j = 0  # Initial indices for left and right subarrays
    k = left   # Initial index of merged subarray
    # Merge the temp arrays back
    while i < len(left_arr) and j < len(right_arr):
        if left_arr[i] <= right_arr[j]:
            arr[k] = left_arr[i]
            print(f"{indent} Compare {left_arr[i]} <= {right_arr[j]} → Take {left_arr[i]}")
            i += 1
        else:
            arr[k] = right_arr[j]
            print(f"{indent} Compare {left_arr[i]} > {right_arr[j]} → Take {right_arr[j]}")
            j += 1
        k += 1
    # Copy remaining elements
    while i < len(left_arr):
        arr[k] = left_arr[i]
        print(f"{indent} Append remaining left: {left_arr[i]}")
        i += 1
        k += 1
    while j < len(right_arr):
        arr[k] = right_arr[j]
        print(f"{indent} Append remaining right: {right_arr[j]}")
        j += 1
        k += 1
    print(f"{indent} Result: {arr[left:right+1]}\n")


# Driver code
if __name__ == "__main__":
    arr = [200, 150, 50, 100, 75, 25, 10, 5]
    print("=== Merge Sort Implementation in Python ===")
    print(f"Original array: {arr}\n")
    merge_sort(arr)
    print(f"Sorted array: {arr}")
C Language
#include <stdio.h>
/**
* Merge Sort - Divide and Conquer Algorithm
* Time Complexity: O(n log n) in all cases
* Space Complexity: O(n) for temporary arrays
*/
/**
* Merge two sorted subarrays arr[left...mid] and arr[mid+1...right]
* @param arr Original array
* @param left Starting index of left subarray
* @param mid Ending index of left subarray
* @param right Ending index of right subarray
*/
void merge(int arr[], int left, int mid, int right) {
int n1 = mid - left + 1; // Size of left subarray
int n2 = right - mid; // Size of right subarray
// Create temporary arrays
int L[n1], R[n2];
// Copy data to temporary arrays
printf(" Merge: Left=[");
for (int i = 0; i < n1; i++) {
L[i] = arr[left + i];
printf("%d ", L[i]);
}
printf("] Right=[");
for (int j = 0; j < n2; j++) {
R[j] = arr[mid + 1 + j];
printf("%d ", R[j]);
}
printf("]\n");
// Merge the temp arrays back into arr[left..right]
int i = 0, j = 0, k = left;
while (i < n1 && j < n2) {
if (L[i] <= R[j]) {
arr[k] = L[i];
printf(" Compare %d <= %d → Take %d\n", L[i], R[j], L[i]);
i++;
} else {
arr[k] = R[j];
printf(" Compare %d > %d → Take %d\n", L[i], R[j], R[j]);
j++;
}
k++;
}
// Copy remaining elements of L[] (if any)
while (i < n1) {
arr[k] = L[i];
printf(" Append remaining left: %d\n", L[i]);
i++;
k++;
}
// Copy remaining elements of R[] (if any)
while (j < n2) {
arr[k] = R[j];
printf(" Append remaining right: %d\n", R[j]);
j++;
k++;
}
printf(" Result: [");
for (int idx = left; idx <= right; idx++) {
printf("%d ", arr[idx]);
}
printf("]\n\n");
}
/**
* Recursive merge sort function
* @param arr Array to sort
* @param left Left index
* @param right Right index
*/
void mergeSort(int arr[], int left, int right, int depth) {
// Print indentation for recursion depth
for (int i = 0; i < depth; i++) printf(" ");
printf("mergeSort(%d,%d) → [", left, right);
for (int i = left; i <= right; i++) printf("%d ", arr[i]);
printf("]\n");
if (left < right) {
// Find the middle point (avoid overflow)
int mid = left + (right - left) / 2;
// Sort first half
mergeSort(arr, left, mid, depth + 1);
// Sort second half
mergeSort(arr, mid + 1, right, depth + 1);
// Merge the sorted halves
merge(arr, left, mid, right);
} else {
for (int i = 0; i < depth; i++) printf(" ");
printf(" Base case: single element\n");
}
}
/**
* Print array elements
*/
void printArray(int arr[], int size) {
printf("[");
for (int i = 0; i < size; i++) {
printf("%d", arr[i]);
if (i < size - 1) printf(", ");
}
printf("]\n");
}
int main() {
int arr[] = {200, 150, 50, 100, 75, 25, 10, 5};
int n = sizeof(arr) / sizeof(arr[0]);
printf("=== Merge Sort Implementation in C ===\n");
printf("Original array: ");
printArray(arr, n);
printf("\n");
mergeSort(arr, 0, n - 1, 0);
printf("Sorted array: ");
printArray(arr, n);
return 0;
}
Rust
/// Merge Sort Implementation in Rust
/// Time Complexity: O(n log n)
/// Space Complexity: O(n)
/// Merge two sorted subarrays
///
/// # Arguments
/// * `arr` - Mutable slice to sort
/// * `left` - Starting index of left subarray
/// * `mid` - Ending index of left subarray
/// * `right` - Ending index of right subarray
/// * `depth` - Recursion depth for indentation
fn merge(arr: &mut [i32], left: usize, mid: usize, right: usize, depth: usize) {
let indent = " ".repeat(depth);
// Create temporary vectors
let left_arr: Vec<i32> = arr[left..=mid].to_vec();
let right_arr: Vec<i32> = arr[mid + 1..=right].to_vec();
println!(
"{} Merge: Left={:?} Right={:?}",
indent, left_arr, right_arr
);
let mut i = 0; // Index for left_arr
let mut j = 0; // Index for right_arr
let mut k = left; // Index for merged array
// Merge the temp arrays
while i < left_arr.len() && j < right_arr.len() {
if left_arr[i] <= right_arr[j] {
arr[k] = left_arr[i];
println!(
"{} Compare {} <= {} → Take {}",
indent, left_arr[i], right_arr[j], left_arr[i]
);
i += 1;
} else {
arr[k] = right_arr[j];
println!(
"{} Compare {} > {} → Take {}",
indent, left_arr[i], right_arr[j], right_arr[j]
);
j += 1;
}
k += 1;
}
// Copy remaining elements
while i < left_arr.len() {
arr[k] = left_arr[i];
println!("{} Append remaining left: {}", indent, left_arr[i]);
i += 1;
k += 1;
}
while j < right_arr.len() {
arr[k] = right_arr[j];
println!("{} Append remaining right: {}", indent, right_arr[j]);
j += 1;
k += 1;
}
println!("{} Result: {:?}\n", indent, &arr[left..=right]);
}
/// Recursive merge sort function
///
/// # Arguments
/// * `arr` - Mutable slice to sort
/// * `left` - Left index
/// * `right` - Right index
/// * `depth` - Recursion depth for indentation
fn merge_sort(arr: &mut [i32], left: usize, right: usize, depth: usize) {
let indent = " ".repeat(depth);
// Print current subarray
print!(
"{}mergeSort({},{}) → {:?}\n",
indent,
left,
right,
&arr[left..=right]
);
if left < right {
// Calculate mid point (avoid overflow)
let mid = left + (right - left) / 2;
// Sort first half
merge_sort(arr, left, mid, depth + 1);
// Sort second half
merge_sort(arr, mid + 1, right, depth + 1);
// Merge sorted halves
merge(arr, left, mid, right, depth);
} else {
println!("{} Base case: single element", indent);
}
}
fn main() {
let mut arr = [200, 150, 50, 100, 75, 25, 10, 5];
let n = arr.len();
println!("=== Merge Sort Implementation in Rust ===");
println!("Original array: {:?}\n", arr);
merge_sort(&mut arr, 0, n - 1, 0);
println!("Sorted array: {:?}", arr);
}
Final Merge Result
5 10 25 50 75 100 150 200
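To complement the trace, the number of element comparisons made during the merges can be counted. The sketch below is a compact, non-tracing variant written for this purpose; the function name `merge_sort_count` is an assumption made for this example, and it returns a new list rather than sorting in place. It splits at the same midpoints as the implementations above, so the counts correspond to the same recursion tree.

```python
def merge_sort_count(values):
    """Return (sorted_list, comparisons) for a top-down merge sort."""
    if len(values) <= 1:
        return list(values), 0
    mid = len(values) // 2
    left, left_cmp = merge_sort_count(values[:mid])
    right, right_cmp = merge_sort_count(values[mid:])
    merged = []
    comparisons = left_cmp + right_cmp
    i = j = 0
    # Each loop iteration performs exactly one element comparison.
    while i < len(left) and j < len(right):
        comparisons += 1
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # Whatever remains in one half is appended without further comparisons.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, comparisons


result, comparisons = merge_sort_count([200, 150, 50, 100, 75, 25, 10, 5])
print(result)       # prints [5, 10, 25, 50, 75, 100, 150, 200]
print(comparisons)  # prints 13
```

The 13 comparisons break down as 1 + 1 + 2 for the left half, 1 + 1 + 2 for the right half, and 5 for the final merge. Using `<=` when the two front elements are equal takes the left element first, which is what keeps merge sort stable.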
Question 4
Problem Statement
Implement the Quick Sort algorithm to sort the following list and show the process step by step.
12 20 22 16 25 18 8 10 6 15
Approach
Choose a pivot, partition the list so smaller elements move before the pivot and larger elements move after it, then recursively sort the two partitions.
Recursive Split Tree
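Each call below is labelled with the array index range (low, high) it sorts. The pivot shown is the last element of that range, as in the Lomuto partition used by the implementations in this question.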
quickSort(0,9) pivot=15
│
├─ quickSort(0,3) pivot=6
│ ├─ quickSort(0,-1) → base case
│ └─ quickSort(1,3) pivot=12
│ ├─ quickSort(1,2) pivot=10
│ │ ├─ quickSort(1,1) → base case
│ │ └─ quickSort(3,2) → base case
│ └─ quickSort(4,3) → base case
│
└─ quickSort(5,9) pivot=25
├─ quickSort(5,8) pivot=16
│ ├─ quickSort(5,4) → base case
│ └─ quickSort(6,8) pivot=18
│ ├─ quickSort(6,5) → base case
│ └─ quickSort(7,8) pivot=20
│ ├─ quickSort(7,6) → base case
│ └─ quickSort(8,7) → base case
└─ quickSort(10,9) → base case
Implementation
Python
def partition(arr, low, high):
    """
    Lomuto Partition: Rearranges array so elements <= pivot are left,
    and elements > pivot are right. Returns pivot's final index.
    """
    pivot = arr[high]
    i = low - 1  # Boundary for smaller elements
    print(f" Partitioning [{low}..{high}] with pivot={pivot}")
    for j in range(low, high):
        print(f" Compare arr[{j}]={arr[j]} with pivot={pivot} → ", end="")
        if arr[j] <= pivot:
            i += 1
            if i != j:
                arr[i], arr[j] = arr[j], arr[i]
                # Values are printed after the exchange has been made
                print(f"SWAP {arr[i]} ↔ {arr[j]}")
            else:
                print("No swap needed")
        else:
            print("Skip (greater)")
    # Place pivot in correct sorted position
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    print(f" Final swap: place pivot at index {i + 1}")
    return i + 1


def quick_sort(arr, low, high):
    """Recursive Quick Sort"""
    if low < high:
        print(f"quickSort({low}, {high})")
        pi = partition(arr, low, high)
        quick_sort(arr, low, pi - 1)   # Sort left partition
        quick_sort(arr, pi + 1, high)  # Sort right partition


if __name__ == "__main__":
    arr = [12, 20, 22, 16, 25, 18, 8, 10, 6, 15]
    print("=== Quick Sort Implementation in Python ===")
    print(f"Original: {arr}\n")
    quick_sort(arr, 0, len(arr) - 1)
    print(f"\nSorted: {arr}")
C Language
#include <stdio.h>
/**
* Swap two integers
*/
void swap(int* a, int* b) {
int temp = *a;
*a = *b;
*b = temp;
}
/**
* Lomuto Partition Scheme
* @param arr Array to partition
* @param low Starting index
* @param high Ending index (pivot is at high)
* @return Final pivot index
*/
int partition(int arr[], int low, int high) {
int pivot = arr[high]; // Select last element as pivot
int i = low - 1; // Index of smaller element
printf(" Partitioning [%d..%d] with pivot=%d\n", low, high, pivot);
for (int j = low; j < high; j++) {
printf(" Compare arr[%d]=%d with pivot=%d → ", j, arr[j], pivot);
if (arr[j] <= pivot) {
i++;
if (i != j) {
swap(&arr[i], &arr[j]);
printf("SWAP %d ↔ %d\n", arr[i], arr[j]);
} else {
printf("No swap needed\n");
}
} else {
printf("Skip (greater)\n");
}
}
// Place pivot in correct position
swap(&arr[i + 1], &arr[high]);
printf(" Final swap: place pivot at index %d\n", i + 1);
return i + 1;
}
/**
* Recursive Quick Sort
*/
void quickSort(int arr[], int low, int high) {
if (low < high) {
printf("quickSort(%d, %d)\n", low, high);
int pi = partition(arr, low, high);
// Recursively sort elements before and after partition
quickSort(arr, low, pi - 1);
quickSort(arr, pi + 1, high);
}
}
/**
* Print array
*/
void printArray(int arr[], int size) {
printf("[");
for (int i = 0; i < size; i++) {
printf("%d", arr[i]);
if (i < size - 1) printf(", ");
}
printf("]\n");
}
int main() {
int arr[] = {12, 20, 22, 16, 25, 18, 8, 10, 6, 15};
int n = sizeof(arr) / sizeof(arr[0]);
printf("=== Quick Sort Implementation in C ===\n");
printf("Original: "); printArray(arr, n);
printf("\n");
quickSort(arr, 0, n - 1);
printf("\nSorted: "); printArray(arr, n);
return 0;
}
Rust
// Quick Sort Implementation in Rust
// Uses Lomuto partition scheme with last element as pivot

/// Swaps two elements in a mutable slice
fn swap(arr: &mut [i32], i: usize, j: usize) {
    arr.swap(i, j);
}

/// Partition function: returns the pivot's final index
fn partition(arr: &mut [i32], low: usize, high: usize) -> usize {
    let pivot = arr[high];
    // `i` is the next slot for an element <= pivot. Starting at `low`
    // (rather than low - 1) avoids usize underflow when low == 0.
    let mut i = low;
    println!(" Partitioning [{low}..{high}] with pivot={pivot}");
    for j in low..high {
        print!(" Compare arr[{j}]={} with pivot={pivot} → ", arr[j]);
        if arr[j] <= pivot {
            if i != j {
                swap(arr, i, j);
                println!("SWAP {} ↔ {}", arr[i], arr[j]);
            } else {
                println!("No swap needed");
            }
            i += 1;
        } else {
            println!("Skip (greater)");
        }
    }
    // Place pivot in correct position
    swap(arr, i, high);
    println!(" Final swap: place pivot at index {}", i);
    i
}

/// Recursive quick sort
fn quick_sort(arr: &mut [i32], low: usize, high: usize) {
    if low < high {
        println!("quickSort({}, {})", low, high);
        let pi = partition(arr, low, high);
        // Skip the left call when the pivot lands at `low`:
        // the left part is empty and pi - 1 could underflow usize.
        if pi > low {
            quick_sort(arr, low, pi - 1);
        }
        // When pi == high the right part is empty; skip the call.
        if pi < high {
            quick_sort(arr, pi + 1, high);
        }
    }
}

fn main() {
    let mut arr = [12, 20, 22, 16, 25, 18, 8, 10, 6, 15];
    let n = arr.len();
    println!("=== Quick Sort Implementation in Rust ===");
    println!("Original: {:?}\n", arr);
    quick_sort(&mut arr, 0, n - 1);
    println!("\nSorted: {:?}", arr);
}
Final Sorted Output
6 8 10 12 15 16 18 20 22 25
Complexity
| Metric | Value | Explanation |
|---|---|---|
| Best Case Time | O(n log n) | Pivot divides array evenly every time |
| Average Case Time | O(n log n) | Random or typical pivot choices yield reasonably balanced partitions |
| Worst Case Time | O(n²) | Already sorted or reverse-sorted input with the first or last element as pivot |
| Space Complexity | O(log n) average, O(n) worst case | Recursion stack depth |
| Stable Sort | No | Relative order of equal elements may change |
| In-Place | Yes | Sorts within original array (no extra arrays) |
Question 5
Problem Statement
Examine the performance of Quick Sort for the following list in terms of comparisons, exchange operations, and loop iterations.
6 8 10 12 15 16 18 20 22 25
Note
This input is already sorted. If the implementation always chooses the first or last element as pivot, this becomes the worst-case pattern for Quick Sort.
| Metric | Worst-Case Pattern |
|---|---|
| Time complexity | O(n^2) |
| Recursion depth | n |
| Comparisons | Approximately n(n-1)/2 |
For n = 10, the worst-case comparison count is:
10 * 9 / 2 = 45
Performance Summary
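| Metric | Count | Explanation |
|---|---|---|
| Total Comparisons | 45 | 9 + 8 + ... + 1 comparisons of arr[j] with the pivot |
| Meaningful Swaps (i≠j) | 0 | Every element is already on the correct side of the pivot |
| Skipped Self-Swaps (i=j) | 45 | Every comparison finds arr[j] ≤ pivot with i = j; the Question 4 implementations skip the exchange and print "No swap needed" |
| Final Pivot Placements | 9 | One per partition; the pivot is already in place, so each is an exchange of an element with itself |
| Total Loop Iterations | 45 | One loop iteration per comparison |
| Recursive Calls | 10 | quickSort(0,9) through quickSort(0,0), not counting the 9 empty right-hand calls such as quickSort(10,9) |
| Recursion Depth | 10 | The chain quickSort(0,9), quickSort(0,8), ..., quickSort(0,0) |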
Mathematical Verification
For n = 10, sorted array, Lomuto partition:
Comparisons = (n-1) + (n-2) + ... + 1 + 0
= Σ(k) for k=0 to n-1
= n(n-1)/2
= 10×9/2 = 45 ✓
Swaps (meaningful) = 0 (since array is already partitioned)
Time Complexity = O(n²) ← Worst Case
Space Complexity = O(n) ← Recursion stack depth
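These counts can be reproduced by instrumenting the partition loop. The sketch below is one way to do it, assuming the same Lomuto scheme with the last element as pivot; the function name `quick_sort_stats` and the dictionary keys are assumptions made for this example. It sorts a copy of the input so the original list is left untouched.

```python
def quick_sort_stats(values):
    """Sort a copy with Lomuto quick sort (last element as pivot) and count work."""
    arr = list(values)
    stats = {
        "comparisons": 0,        # arr[j] <= pivot tests (one per loop iteration)
        "exchanges": 0,          # swaps inside the loop with i != j
        "self_swaps_skipped": 0, # loop hits with i == j (no exchange made)
        "pivot_placements": 0,   # final swap at the end of each partition
        "max_depth": 0,          # deepest recursion level reached
    }

    def partition(low, high):
        pivot = arr[high]
        i = low - 1
        for j in range(low, high):
            stats["comparisons"] += 1
            if arr[j] <= pivot:
                i += 1
                if i != j:
                    arr[i], arr[j] = arr[j], arr[i]
                    stats["exchanges"] += 1
                else:
                    stats["self_swaps_skipped"] += 1
        arr[i + 1], arr[high] = arr[high], arr[i + 1]
        stats["pivot_placements"] += 1
        return i + 1

    def sort(low, high, depth):
        stats["max_depth"] = max(stats["max_depth"], depth)
        if low < high:
            pi = partition(low, high)
            sort(low, pi - 1, depth + 1)
            sort(pi + 1, high, depth + 1)

    sort(0, len(arr) - 1, 1)
    return arr, stats


sorted_list, stats = quick_sort_stats([6, 8, 10, 12, 15, 16, 18, 20, 22, 25])
print(sorted_list)
print(stats)
```

For the sorted list this reports 45 comparisons, 0 exchanges, 45 skipped self-swaps, 9 pivot placements, and a maximum depth of 10. Running the same function on the Question 4 list (12 20 22 16 25 18 8 10 6 15) reports 25 comparisons, which shows how much the already-sorted order costs with this pivot rule.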
Question 6
Problem Statement
Implement Strassen's multiplication algorithm for two n x n matrices, where n is a power of 2, on different problem instances and compare all instances in terms of number of multiplications and additions required.
Strassen’s Algorithm
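Classical multiplication of two 2 x 2 matrices performs 8 multiplications and 4 additions. Strassen's method reduces this to 7 multiplications at the cost of 18 additions and subtractions. When the entries A11, ..., B22 are themselves n/2 x n/2 blocks, the 7 products become recursive calls: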
M1 = (A11 + A22) × (B11 + B22)
M2 = (A21 + A22) × B11
M3 = A11 × (B12 - B22)
M4 = A22 × (B21 - B11)
M5 = (A11 + A12) × B22
M6 = (A21 - A11) × (B11 + B12)
M7 = (A12 - A22) × (B21 + B22)
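The seven products are combined into the four quadrants of the result as C11 = M1 + M4 - M5 + M7, C12 = M3 + M5, C21 = M2 + M4, and C22 = M1 - M2 + M3 + M6, which is the combination used by the implementations in this question. The sketch below checks these formulas on plain numbers; the function names `strassen_2x2` and `classical_2x2` are assumptions made for this example.

```python
def strassen_2x2(A, B):
    """Multiply two 2 x 2 matrices of numbers using 7 scalar multiplications."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B
    # The seven Strassen products
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)
    # Combination step: 8 further additions/subtractions
    c11 = m1 + m4 - m5 + m7
    c12 = m3 + m5
    c21 = m2 + m4
    c22 = m1 - m2 + m3 + m6
    return [[c11, c12], [c21, c22]]


def classical_2x2(A, B):
    """Row-by-column product with 8 scalar multiplications, for comparison."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(strassen_2x2(A, B))   # prints [[19, 22], [43, 50]]
print(classical_2x2(A, B))  # prints [[19, 22], [43, 50]]
```

For this instance the intermediate values are M1 = 65, M2 = 35, M3 = -2, M4 = 8, M5 = 24, M6 = 22, M7 = -30, which can be used as a hand-checked dry run.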
Comparison
| Method | Multiplications | Additions/Subtractions | Time Complexity |
|---|---|---|---|
| Classical multiplication | 8 subproblems of size n/2 | 4 block additions | O(n^3) |
| Strassen multiplication | 7 subproblems of size n/2 | 18 block additions/subtractions | O(n^log2 7) ≈ O(n^2.807) |
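The scalar operation counts for different instance sizes follow from two recurrences. The sketch below assumes the recursion runs all the way down to 1 x 1 matrices and that each level performs 18 element-wise additions or subtractions on (n/2) x (n/2) blocks, as the implementations in this question do; the function names `strassen_counts` and `classical_counts` are assumptions made for this example.

```python
def strassen_counts(n):
    """(multiplications, additions) for n x n Strassen, n a power of 2,
    recursing down to 1 x 1 matrices."""
    if n == 1:
        return 1, 0
    half = n // 2
    mults, adds = strassen_counts(half)
    # 7 recursive products, plus 18 block additions/subtractions,
    # each touching half * half elements.
    return 7 * mults, 7 * adds + 18 * half * half


def classical_counts(n):
    """(multiplications, additions) for the triple-loop method:
    n multiplications and n - 1 additions per output entry."""
    return n ** 3, n * n * (n - 1)


print("n  | Strassen (mults, adds) | Classical (mults, adds)")
for n in (2, 4, 8, 16, 32, 64):
    print(n, strassen_counts(n), classical_counts(n))
```

For n = 2, 4, and 8 this gives (7, 18), (49, 198), and (343, 1674) for Strassen against (8, 4), (64, 48), and (512, 448) for the classical method. Strassen always saves multiplications, but with recursion down to 1 x 1 it performs many more additions, which is why practical implementations usually switch to the classical method below some cutoff size.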
Implementation
Python
# 6.py
# Strassen's Matrix Multiplication with Deterministic Operation Counting
# Language: Python 3.x | Complexity: O(n^2.807) mults, O(n^2.807) adds
# Run: python3 6.py

class OpCounter:
    """Tracks scalar multiplications and additions/subtractions."""
    def __init__(self):
        self.mults = 0
        self.adds = 0


def add_mat(A, B, n, ctr):
    """Element-wise addition: C = A + B. n is the actual size of A and B."""
    C = [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]
    ctr.adds += n * n
    return C


def sub_mat(A, B, n, ctr):
    """Element-wise subtraction: C = A - B. n is the actual size of A and B."""
    C = [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]
    ctr.adds += n * n
    return C


def strassen(A, B, n, ctr):
    """Recursive Strassen multiplication. Base case: n=1."""
    if n == 1:
        ctr.mults += 1
        return [[A[0][0] * B[0][0]]]
    mid = n // 2
    # Extract quadrants (all are mid x mid)
    A11 = [row[:mid] for row in A[:mid]]
    A12 = [row[mid:] for row in A[:mid]]
    A21 = [row[:mid] for row in A[mid:]]
    A22 = [row[mid:] for row in A[mid:]]
    B11 = [row[:mid] for row in B[:mid]]
    B12 = [row[mid:] for row in B[:mid]]
    B21 = [row[:mid] for row in B[mid:]]
    B22 = [row[mid:] for row in B[mid:]]
    # 7 recursive products
    # Every add_mat/sub_mat call works on mid x mid blocks, so it is passed `mid`, not `n`
    M1 = strassen(add_mat(A11, A22, mid, ctr), add_mat(B11, B22, mid, ctr), mid, ctr)
    M2 = strassen(add_mat(A21, A22, mid, ctr), B11, mid, ctr)
    M3 = strassen(A11, sub_mat(B12, B22, mid, ctr), mid, ctr)
    M4 = strassen(A22, sub_mat(B21, B11, mid, ctr), mid, ctr)
    M5 = strassen(add_mat(A11, A12, mid, ctr), B22, mid, ctr)
    M6 = strassen(sub_mat(A21, A11, mid, ctr), add_mat(B11, B12, mid, ctr), mid, ctr)
    M7 = strassen(sub_mat(A12, A22, mid, ctr), add_mat(B21, B22, mid, ctr), mid, ctr)
    # Combine the products into the four result quadrants (also mid x mid)
    C11 = add_mat(sub_mat(add_mat(M1, M4, mid, ctr), M5, mid, ctr), M7, mid, ctr)
    C12 = add_mat(M3, M5, mid, ctr)
    C21 = add_mat(M2, M4, mid, ctr)
    C22 = add_mat(sub_mat(add_mat(M1, M3, mid, ctr), M2, mid, ctr), M6, mid, ctr)
    # Merge into full n x n matrix
    C = [[0] * n for _ in range(n)]
    for i in range(mid):
        for j in range(mid):
            C[i][j] = C11[i][j]
            C[i][j + mid] = C12[i][j]
            C[i + mid][j] = C21[i][j]
            C[i + mid][j + mid] = C22[i][j]
    return C


def naive_multiply(A, B, n, ctr):
    """Standard O(n^3) multiplication for baseline comparison."""
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                ctr.mults += 1
                # The first term of each sum needs no addition
                if k > 0:
                    ctr.adds += 1
                C[i][j] += A[i][k] * B[k][j]
    return C


def init_matrix(n):
    """Deterministic initialization for reproducible grading."""
    return [[(i * n + j + 1) for j in range(n)] for i in range(n)]


if __name__ == "__main__":
    print(f"{'n':<4} | {'Strassen Mults':<14} | {'Strassen Adds':<13} | {'Naive Mults':<11} | {'Naive Adds':<10}")
    print("-" * 85)
    sizes = [2, 4, 8, 16, 32, 64]
    for n in sizes:
        A = init_matrix(n)
        B = init_matrix(n)
        ctr_str = OpCounter()
        C_str = strassen(A, B, n, ctr_str)
        ctr_naive = OpCounter()
        C_naive = naive_multiply(A, B, n, ctr_naive)
        print(f"{n:<4} | {ctr_str.mults:<14} | {ctr_str.adds:<13} | {ctr_naive.mults:<11} | {ctr_naive.adds:<10}")
C Language
/*
* strassen_complete.c
* Strassen's Matrix Multiplication with Deterministic Operation Counting
* Language: C11 | Layout: 1D Row-Major | Architecture: arm64/x86_64 compatible
*/
#include <stdio.h>
#include <stdlib.h>
/* Operation counter: tracks scalar arithmetic operations */
typedef struct { long mults; long adds; } OpCounter;
/* Row-major index helper: maps 2D (i,j) to 1D offset */
static inline size_t IDX(size_t i, size_t j, size_t n) { return i * n + j; }
/* Allocate zero-initialized n×n matrix (1D array) */
static double* alloc_mat(size_t n) {
double* m = (double*)calloc(n * n, sizeof(double));
if (!m) { fprintf(stderr, "Memory allocation failed\n"); exit(EXIT_FAILURE); }
return m;
}
/* Element-wise addition with operation counting */
static void mat_add(const double* a, const double* b, double* c, size_t n, OpCounter* ctr) {
for (size_t i = 0; i < n * n; i++) c[i] = a[i] + b[i];
ctr->adds += (long)(n * n);
}
/* Element-wise subtraction with operation counting */
static void mat_sub(const double* a, const double* b, double* c, size_t n, OpCounter* ctr) {
for (size_t i = 0; i < n * n; i++) c[i] = a[i] - b[i];
ctr->adds += (long)(n * n);
}
/* Extract quadrant from parent matrix */
static void copy_quad(const double* src, double* dst, size_t src_n, size_t dst_n, size_t r, size_t c) {
for (size_t i = 0; i < dst_n; i++)
for (size_t j = 0; j < dst_n; j++)
dst[IDX(i, j, dst_n)] = src[IDX(i + r, j + c, src_n)];
}
/* Merge 4 quadrants into full matrix */
static void merge_quads(double* C, const double* c11, const double* c12,
const double* c21, const double* c22, size_t n) {
size_t m = n / 2;
for (size_t i = 0; i < m; i++)
for (size_t j = 0; j < m; j++) {
C[IDX(i, j, n)] = c11[IDX(i, j, m)];
C[IDX(i, j + m, n)] = c12[IDX(i, j, m)];
C[IDX(i + m, j, n)] = c21[IDX(i, j, m)];
C[IDX(i + m, j + m, n)] = c22[IDX(i, j, m)];
}
}
/* Recursive Strassen implementation */
void strassen(const double* A, const double* B, double* C, size_t n, OpCounter* ctr) {
if (n == 1) { C[0] = A[0] * B[0]; ctr->mults++; return; }
size_t m = n / 2;
/* Allocate quadrants & temporaries */
double *a11=alloc_mat(m),*a12=alloc_mat(m),*a21=alloc_mat(m),*a22=alloc_mat(m);
double *b11=alloc_mat(m),*b12=alloc_mat(m),*b21=alloc_mat(m),*b22=alloc_mat(m);
double *m1=alloc_mat(m),*m2=alloc_mat(m),*m3=alloc_mat(m),*m4=alloc_mat(m);
double *m5=alloc_mat(m),*m6=alloc_mat(m),*m7=alloc_mat(m);
double *s1=alloc_mat(m),*s2=alloc_mat(m),*s3=alloc_mat(m),*s4=alloc_mat(m);
double *s5=alloc_mat(m),*s6=alloc_mat(m),*s7=alloc_mat(m),*s8=alloc_mat(m);
double *s9=alloc_mat(m),*s10=alloc_mat(m),*tmp=alloc_mat(m);
/* Split inputs into quadrants */
copy_quad(A,a11,n,m,0,0); copy_quad(A,a12,n,m,0,m);
copy_quad(A,a21,n,m,m,0); copy_quad(A,a22,n,m,m,m);
copy_quad(B,b11,n,m,0,0); copy_quad(B,b12,n,m,0,m);
copy_quad(B,b21,n,m,m,0); copy_quad(B,b22,n,m,m,m);
/* Strassen's 7 products */
mat_add(a11,a22,s1,m,ctr); mat_add(b11,b22,s2,m,ctr); strassen(s1,s2,m1,m,ctr);
mat_add(a21,a22,s3,m,ctr); strassen(s3,b11,m2,m,ctr);
mat_sub(b12,b22,s4,m,ctr); strassen(a11,s4,m3,m,ctr);
mat_sub(b21,b11,s5,m,ctr); strassen(a22,s5,m4,m,ctr);
mat_add(a11,a12,s6,m,ctr); strassen(s6,b22,m5,m,ctr);
mat_sub(a21,a11,s7,m,ctr); mat_add(b11,b12,s8,m,ctr); strassen(s7,s8,m6,m,ctr);
mat_sub(a12,a22,s9,m,ctr); mat_add(b21,b22,s10,m,ctr); strassen(s9,s10,m7,m,ctr);
/* Combine results into C quadrants */
double *c11=alloc_mat(m),*c12=alloc_mat(m),*c21=alloc_mat(m),*c22=alloc_mat(m);
mat_add(m1,m4,c11,m,ctr); mat_sub(c11,m5,tmp,m,ctr); mat_add(tmp,m7,c11,m,ctr);
mat_add(m3,m5,c12,m,ctr); mat_add(m2,m4,c21,m,ctr);
mat_sub(m1,m2,c22,m,ctr); mat_add(c22,m3,tmp,m,ctr); mat_add(tmp,m6,c22,m,ctr);
merge_quads(C,c11,c12,c21,c22,n);
/* Clean up all temporary allocations */
free(a11);free(a12);free(a21);free(a22);free(b11);free(b12);free(b21);free(b22);
free(m1);free(m2);free(m3);free(m4);free(m5);free(m6);free(m7);
free(s1);free(s2);free(s3);free(s4);free(s5);free(s6);free(s7);free(s8);free(s9);free(s10);
free(tmp);free(c11);free(c12);free(c21);free(c22);
}
/* Standard O(n^3) multiplication for baseline comparison */
void naive_mul(const double* A, const double* B, double* C, size_t n, OpCounter* ctr) {
for (size_t i = 0; i < n; i++) {
for (size_t j = 0; j < n; j++) {
double sum = 0.0;
for (size_t k = 0; k < n; k++) {
sum += A[IDX(i, k, n)] * B[IDX(k, j, n)];
ctr->mults += 1;
if (k > 0) ctr->adds += 1;
}
C[IDX(i, j, n)] = sum;
}
}
}
/* Initialize matrix with deterministic values */
void init_matrix(double* m, size_t n) {
for (size_t i = 0; i < n * n; i++) m[i] = (double)(i + 1);
}
/* ================= ENTRY POINT ================= */
int main(void) {
printf("%-4s | %-14s | %-13s | %-11s | %-10s\n", "n", "Strassen Mults", "Strassen Adds", "Naive Mults", "Naive Adds");
printf("%-70s\n", "----------------------------------------------------------------------");
size_t sizes[] = {2, 4, 8, 16, 32, 64};
size_t num_sizes = sizeof(sizes) / sizeof(sizes[0]);
for (size_t idx = 0; idx < num_sizes; idx++) {
size_t n = sizes[idx];
double* A = alloc_mat(n); init_matrix(A, n);
double* B = alloc_mat(n); init_matrix(B, n);
double* C_str = alloc_mat(n);
double* C_naive = alloc_mat(n);
OpCounter ctr_str = {0, 0};
strassen(A, B, C_str, n, &ctr_str);
OpCounter ctr_naive = {0, 0};
naive_mul(A, B, C_naive, n, &ctr_naive);
printf("%-4zu | %-14ld | %-13ld | %-11ld | %-10ld\n",
n, ctr_str.mults, ctr_str.adds, ctr_naive.mults, ctr_naive.adds);
free(A); free(B); free(C_str); free(C_naive);
}
return 0;
}
Rust
// strassen_complete.rs
// Strassen's Matrix Multiplication with Deterministic Operation Counting
// Language: Rust 2021 | Layout: Vec<f64> (Row-Major) | Complexity: O(n^2.807)
// Compile: rustc -O strassen_complete.rs -o strassen
// Run: ./strassen
struct OpCounter {
mults: u64,
adds: u64,
}
#[inline]
fn idx(i: usize, j: usize, n: usize) -> usize {
i * n + j
}
fn mat_add(a: &[f64], b: &[f64], n: usize, ctr: &mut OpCounter) -> Vec<f64> {
let mut c = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
c[idx(i, j, n)] = a[idx(i, j, n)] + b[idx(i, j, n)];
}
}
ctr.adds += (n * n) as u64;
c
}
fn mat_sub(a: &[f64], b: &[f64], n: usize, ctr: &mut OpCounter) -> Vec<f64> {
let mut c = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
c[idx(i, j, n)] = a[idx(i, j, n)] - b[idx(i, j, n)];
}
}
ctr.adds += (n * n) as u64;
c
}
fn extract_quad(src: &[f64], src_n: usize, dst_n: usize, r: usize, c: usize) -> Vec<f64> {
let mut dst = vec![0.0; dst_n * dst_n];
for i in 0..dst_n {
for j in 0..dst_n {
dst[idx(i, j, dst_n)] = src[idx(i + r, j + c, src_n)];
}
}
dst
}
fn merge_quads(c11: &[f64], c12: &[f64], c21: &[f64], c22: &[f64], n: usize) -> Vec<f64> {
let m = n / 2;
let mut c = vec![0.0; n * n];
for i in 0..m {
for j in 0..m {
c[idx(i, j, n)] = c11[idx(i, j, m)];
c[idx(i, j + m, n)] = c12[idx(i, j, m)];
c[idx(i + m, j, n)] = c21[idx(i, j, m)];
c[idx(i + m, j + m, n)] = c22[idx(i, j, m)];
}
}
c
}
fn strassen(a: &[f64], b: &[f64], n: usize, ctr: &mut OpCounter) -> Vec<f64> {
if n == 1 {
ctr.mults += 1;
return vec![a[0] * b[0]];
}
let m = n / 2;
let a11 = extract_quad(a, n, m, 0, 0);
let a12 = extract_quad(a, n, m, 0, m);
let a21 = extract_quad(a, n, m, m, 0);
let a22 = extract_quad(a, n, m, m, m);
let b11 = extract_quad(b, n, m, 0, 0);
let b12 = extract_quad(b, n, m, 0, m);
let b21 = extract_quad(b, n, m, m, 0);
let b22 = extract_quad(b, n, m, m, m);
let m1 = strassen(
&mat_add(&a11, &a22, m, ctr),
&mat_add(&b11, &b22, m, ctr),
m,
ctr,
);
let m2 = strassen(&mat_add(&a21, &a22, m, ctr), &b11, m, ctr);
let m3 = strassen(&a11, &mat_sub(&b12, &b22, m, ctr), m, ctr);
let m4 = strassen(&a22, &mat_sub(&b21, &b11, m, ctr), m, ctr);
let m5 = strassen(&mat_add(&a11, &a12, m, ctr), &b22, m, ctr);
let m6 = strassen(
&mat_sub(&a21, &a11, m, ctr),
&mat_add(&b11, &b12, m, ctr),
m,
ctr,
);
let m7 = strassen(
&mat_sub(&a12, &a22, m, ctr),
&mat_add(&b21, &b22, m, ctr),
m,
ctr,
);
let c11 = {
let t = mat_add(&m1, &m4, m, ctr);
let t = mat_sub(&t, &m5, m, ctr);
mat_add(&t, &m7, m, ctr)
};
let c12 = mat_add(&m3, &m5, m, ctr);
let c21 = mat_add(&m2, &m4, m, ctr);
let c22 = {
let t = mat_sub(&m1, &m2, m, ctr);
let t = mat_add(&t, &m3, m, ctr);
mat_add(&t, &m6, m, ctr)
};
merge_quads(&c11, &c12, &c21, &c22, n)
}
fn naive_mul(a: &[f64], b: &[f64], n: usize, ctr: &mut OpCounter) -> Vec<f64> {
let mut c = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
let mut sum = 0.0;
for k in 0..n {
sum += a[idx(i, k, n)] * b[idx(k, j, n)];
ctr.mults += 1;
if k > 0 {
ctr.adds += 1;
}
}
c[idx(i, j, n)] = sum;
}
}
c
}
fn init_matrix(n: usize) -> Vec<f64> {
(0..n * n).map(|i| (i + 1) as f64).collect()
}
fn main() {
println!(
"{:<4} | {:<14} | {:<13} | {:<11} | {:<10}",
"n", "Strassen Mults", "Strassen Adds", "Naive Mults", "Naive Adds"
);
println!("{:-<70}", "");
let sizes = [2, 4, 8, 16, 32, 64];
for n in sizes {
let a = init_matrix(n);
let b = init_matrix(n);
let mut ctr_str = OpCounter { mults: 0, adds: 0 };
strassen(&a, &b, n, &mut ctr_str);
let mut ctr_naive = OpCounter { mults: 0, adds: 0 };
naive_mul(&a, &b, n, &mut ctr_naive);
println!(
"{:<4} | {:<14} | {:<13} | {:<11} | {:<10}",
n, ctr_str.mults, ctr_str.adds, ctr_naive.mults, ctr_naive.adds
);
}
}
Viva Questions
- Why must the input array be sorted for binary search?
- Why is merge sort stable?
- What causes Quick Sort's worst case?
- Why does Strassen's method reduce multiplication count?