Kth Smallest Element in a Row-wise and Column-wise Sorted Matrix

Problem Link: LeetCode

We are given a matrix whose rows and columns are both sorted, and we have to find the kth smallest element in it. There are many approaches, and binary search is one of them. With the row-wise counting shown at the end of this article it runs in O(n log n · log(max - min)) time, which compares well with, for example, the O(n² log n) cost of sorting all n² elements. In this article I explain the binary search approach.

Observing Test Case 1

Suppose we have to find the k = 8th smallest element of the given matrix.

The search range runs from the element at (0,0), which is the smallest, to the element at (n-1,n-1), which is the largest, because the matrix is sorted both row-wise and column-wise. Here that range is 1 to 15.

Now we can calculate mid from the range: (1 + 15) / 2 = 8.

Now it's time to think logically. We have a mid value of 8, and we can count how many elements in the matrix are less than or equal to it. Here the count is 2.

Since 2 < k (8), the 8th smallest element cannot be among the values less than or equal to mid (8). For the 8th one we have to look at the values greater than 8.

So we move start to mid + 1 = 9, which reduces the search space. The new mid is (9 + 15) / 2 = 12.

Let's count how many numbers in the matrix are less than or equal to mid (12): there are 6.

That is still less than k (8), so we have to search to the right of mid (12). start becomes 13 and the new mid is (13 + 15) / 2 = 14.

Now the count is 8, which means there are 8 elements less than or equal to the current mid (14).

Is 14 the answer? Not quite: 14 is not present in the given matrix, and we have to return an element that is. What we can be sure of is that the 8th smallest element is 14 or less. Can we set end = mid - 1? No, because at this point mid itself could still be the answer. So the search space should shrink only as far as mid, and the next search space is 13 to 14.

What is the count of elements less than or equal to 13? It is still 8.

Since count (8) = k (8), we again shrink the search space by moving end to mid:

end = mid
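
To see why the loop cannot simply stop as soon as the count equals k, here is a minimal sketch of a deliberately wrong variant. It assumes the matrix is the 3×3 sample from the LeetCode problem, {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}}, which is consistent with the counts quoted above; the class and method names are made up for this illustration.

```java
final class EarlyStopPitfall {
    // Counts the elements of the matrix that are <= value (plain linear scan).
    static int countLessOrEqual(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= value) {
                    count++;
                }
            }
        }
        return count;
    }

    // Deliberately wrong: returns mid as soon as count == k.
    static int kthSmallestStoppingEarly(int[][] matrix, int k) {
        int n = matrix.length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            int count = countLessOrEqual(matrix, mid);
            if (count == k) {
                return mid; // mid may be a value that is not in the matrix
            } else if (count < k) {
                start = mid + 1;
            } else {
                end = mid;
            }
        }
        return start;
    }

    public static void main(String[] args) {
        // Assumed sample matrix (the original figure is not available here).
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallestStoppingEarly(matrix, 8)); // prints 14
    }
}
```

The variant returns 14, which has the right count but is not an element of the matrix. Continuing with end = mid until start equals end is what pushes the result down to 13.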

We can build a table for better understanding.

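| Start | End | Mid | Values less than or equal to mid (count) |
|-------|-----|-----|------------------------------------------|
| 1     | 15  | 8   | 2                                        |
| 9     | 15  | 12  | 6                                        |
| 13    | 15  | 14  | 8                                        |
| 13    | 14  | 13  | 8                                        |
| 13    | 13  | -   | -                                        |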
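
The rows of this table can be reproduced with a short trace. This is only a sketch: it assumes test case 1 uses the 3×3 LeetCode sample matrix (which matches the counts in the table), and `SearchTrace` and its methods are names chosen here for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class SearchTrace {
    // Counts the elements of the matrix that are <= value.
    static int countLessOrEqual(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= value) {
                    count++;
                }
            }
        }
        return count;
    }

    // Records one {start, end, mid, count} row per loop iteration.
    static List<int[]> trace(int[][] matrix, int k) {
        List<int[]> rows = new ArrayList<>();
        int n = matrix.length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            int count = countLessOrEqual(matrix, mid);
            rows.add(new int[] {start, end, mid, count});
            if (count < k) {
                start = mid + 1;
            } else {
                end = mid;
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        // Assumed sample matrix (the original figure is not available here).
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        for (int[] row : trace(matrix, 8)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Each printed row is {start, end, mid, count}. The last line of the table, where start equals end, has no row in the trace because the loop has already ended by then.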

Observing Test Case 2

So far the count has never been greater than k, so we haven't covered that case yet. For that, let's see another example.

Here we have to find the k = 3rd smallest element in the matrix. With start = 10 and end = 50, mid is 30. The values less than or equal to 30 are {10, 15, 20, 24, 25, 29, 30}, so the count is 7, which is greater than k (3).

So we have to reduce the search space and bring end down to mid. Why not mid - 1? We'll see that later.

Let's follow the table for the rest.

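| Start | End | Mid | Values less than or equal to mid (count) |
|-------|-----|-----|------------------------------------------|
| 10    | 50  | 30  | 7 |
| 10    | 30  | 20  | 3 (this is the answer, but we can't stop here; we have to continue until start becomes equal to end) |
| 10    | 20  | 15  | 2 |
| 16    | 20  | 18  | 2 |
| 19    | 20  | 19  | 2 |
| 20    | 20  | -   | - |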

Let's write the code for what we have so far.

import java.util.Scanner;

public class Kth_smallest_element {
    // we'll improve this method getNumbersLessOrEqual() later. 
    public static int getNumbersLessOrEqual(int[][] matrix, int mid) {
        int n = matrix[0].length;
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (matrix[i][j] <= mid) {
                    count++;
                }
            }
        }

        return count;
    }

    public static int kthSmallest(int[][] matrix, int k) {
        int n = matrix[0].length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            int count = getNumbersLessOrEqual(matrix, mid);
            if (count < k) {
                start = mid + 1;
            } else {
                end = mid;
            }
        }

        return start;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int t = in.nextInt();
        while (t-- > 0) {
            int n = in.nextInt();
            int k = in.nextInt();
            int[][] mat = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    mat[i][j] = in.nextInt();
                }
            }
            System.out.println(kthSmallest(mat, k));
        }

    }
}
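
One way to gain some confidence in this logic is to compare it with a brute-force answer. The sketch below flattens the matrix, sorts it, and picks the element at index k - 1. The class name, the helper `kthSmallestBySorting`, and the sample matrix are assumptions made for this illustration; the counting and search methods are repeated only so that the block compiles on its own.

```java
import java.util.Arrays;

final class BruteForceCheck {
    // Reference answer: copy every value into one array, sort it, take index k - 1.
    static int kthSmallestBySorting(int[][] matrix, int k) {
        int n = matrix.length;
        int[] all = new int[n * n];
        int idx = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                all[idx++] = x;
            }
        }
        Arrays.sort(all);
        return all[k - 1];
    }

    // Counts the elements of the matrix that are <= value.
    static int countLessOrEqual(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= value) {
                    count++;
                }
            }
        }
        return count;
    }

    // Binary search over the value range, as in the program above.
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            if (countLessOrEqual(matrix, mid) < k) {
                start = mid + 1;
            } else {
                end = mid;
            }
        }
        return start;
    }

    public static void main(String[] args) {
        // Assumed sample matrix for the comparison.
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        for (int k = 1; k <= 9; k++) {
            System.out.println("k=" + k
                    + " sorted=" + kthSmallestBySorting(matrix, k)
                    + " binarySearch=" + kthSmallest(matrix, k));
        }
    }
}
```

For every k from 1 to 9 the two columns should match, duplicates such as the two 13s included.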

Now, what would happen if we set end = mid - 1 when the count is greater than or equal to k?

Let's see another test case for this.

We can directly write down the table for this case, with k = 13.

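| Start | End          | Mid | Values less than or equal to mid (count) |
|-------|--------------|-----|------------------------------------------|
| 1     | 29           | 15  | 12                                       |
| 16    | 29           | 22  | 19                                       |
| 16    | 21 (mid - 1) | 18  | 15                                       |
| 16    | 17 (mid - 1) | 16  | 12                                       |
| 17    | 17           | -   | -                                        |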

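But 17 is not in the matrix. When mid was 18 the count (15) was already at least k, yet end = mid - 1 threw 18 out of the search space. If we take end = mid - 1, we can leave the search space we need to be in. So it should be end = mid.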
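
The same failure can be shown on a much smaller input. The sketch below uses a hypothetical 2×2 matrix {{1, 2}, {3, 5}} (chosen for illustration, not the matrix behind the table above) and compares the two ways of updating end for k = 3. The class name and the boolean flag are assumptions of this example.

```java
final class EndUpdateComparison {
    // Counts the elements of the matrix that are <= value.
    static int countLessOrEqual(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= value) {
                    count++;
                }
            }
        }
        return count;
    }

    // skipMid == true  uses end = mid - 1 (the buggy update)
    // skipMid == false uses end = mid     (the update used in this article)
    static int kthSmallest(int[][] matrix, int k, boolean skipMid) {
        int n = matrix.length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            if (countLessOrEqual(matrix, mid) < k) {
                start = mid + 1;
            } else {
                end = skipMid ? mid - 1 : mid;
            }
        }
        return start;
    }

    public static void main(String[] args) {
        // Hypothetical matrix; its 3rd smallest element is 3.
        int[][] matrix = {{1, 2}, {3, 5}};
        System.out.println(kthSmallest(matrix, 3, true));  // prints 2
        System.out.println(kthSmallest(matrix, 3, false)); // prints 3
    }
}
```

With end = mid - 1 the first mid is 3, its count is 3, and 3 is discarded even though it is the answer, so the loop settles on 2. With end = mid the value 3 stays in range and is returned.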

Improving the counting method

We were scanning every row and column linearly. Since each row is sorted, we can apply binary search to each row instead. If we find the index of the first element greater than the target (the upper bound), that index is the count of numbers in the row that are less than or equal to the target (mid). Adding up these counts over all rows gives the count for the whole matrix.

import java.util.Scanner;

public class Kth_smallest_element {
    public static int upperBound(int[] arr, int t) {
        int start = 0;
        int end = arr.length - 1;

        while (start <= end) {
            int mid = start + (end - start) / 2;
            if (t >= arr[mid]) {
                start = mid + 1;
            } else {
                end = mid - 1;
            }
        }
        return start;
    }

    public static int getNumbersLessOrEqual(int[][] matrix, int mid) {
        int n = matrix[0].length;
        int count = 0;
        for (int i = 0; i < n; i++) {
            count += (upperBound(matrix[i], mid));
        }

        return count;
    }

    public static int kthSmallest(int[][] matrix, int k) {
        int n = matrix[0].length;
        int start = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            int count = getNumbersLessOrEqual(matrix, mid);
            if (count < k) {
                start = mid + 1;
            } else {
                end = mid;
            }
        }

        return start;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int t = in.nextInt();
        while (t-- > 0) {
            int n = in.nextInt();
            int k = in.nextInt();
            int[][] mat = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    mat[i][j] = in.nextInt();
                }
            }
            System.out.println(kthSmallest(mat, k));
        }

    }
}
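
As a quick sanity check, the binary-search count should agree with the nested-loop count for every candidate value. The sketch below compares the two, again assuming the 3×3 LeetCode sample matrix; the class name and the methods `countByRows` and `countLinear` are illustrative names, while `upperBound` follows the version above.

```java
final class CountAgreementCheck {
    // Number of elements in the sorted array that are <= t,
    // i.e. the index of the first element greater than t.
    static int upperBound(int[] arr, int t) {
        int start = 0;
        int end = arr.length - 1;
        while (start <= end) {
            int mid = start + (end - start) / 2;
            if (t >= arr[mid]) {
                start = mid + 1;
            } else {
                end = mid - 1;
            }
        }
        return start;
    }

    // One binary search per row.
    static int countByRows(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            count += upperBound(row, value);
        }
        return count;
    }

    // Visits every element.
    static int countLinear(int[][] matrix, int value) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= value) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Assumed sample matrix for the comparison.
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        boolean allAgree = true;
        for (int value = 0; value <= 16; value++) {
            if (countByRows(matrix, value) != countLinear(matrix, value)) {
                allAgree = false;
            }
        }
        System.out.println(allAgree ? "all counts agree" : "mismatch found");
    }
}
```

For an n × n matrix, the row-wise version does n binary searches of O(log n) each per count, instead of visiting all n² elements.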