9022. Count the triplets

 

Given three arrays a, b and c of n integers each, find the number of triplets (ai, bj, ck) such that ai < bj < ck.

 

Input. The first line contains the size n of the arrays (n ≤ 10^5). The second line contains the elements of array a, the next line contains the elements of array b, and the last line contains the elements of array c.

 

Output. Print the number of triplets (ai, bj, ck) such that ai < bj < ck.

 

Explanation. In the first test case we have the triplets (a1, b1, c1), (a1, b2, c1) and (a1, b2, c2).
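For small inputs the count can be cross-checked by direct enumeration. The following is a minimal sketch of such a check, not part of the intended solution: the function name countTripletsBrute is hypothetical, and the cubic running time makes it unsuitable for n close to 10^5.

```cpp
#include <vector>

// Hypothetical helper: brute-force count of triplets with a[i] < b[j] < c[k].
// Runs in O(n^3), so it is only meant for verifying small test cases.
long long countTripletsBrute(const std::vector<int>& a,
                             const std::vector<int>& b,
                             const std::vector<int>& c) {
    long long count = 0;
    for (int x : a)
        for (int y : b)
            for (int z : c)
                if (x < y && y < z) ++count;  // both inequalities are strict
    return count;
}
```

Calling countTripletsBrute({1, 5}, {4, 2}, {6, 3}) enumerates the eight candidate triplets of the first test case and keeps the three listed in the explanation.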

 

Sample input 1

2

1 5

4 2

6 3

 

Sample output 1

3

 

 

Sample input 2

3

1 1 1

2 2 2

3 3 3

 

Sample output 2

27

 

 

SOLUTION

binary search

 

Algorithm analysis

Sort the arrays. For each value bj, use binary search to find the number x of elements of array a that are less than bj and the number y of elements of array c that are greater than bj. Then, for this fixed bj, there are x * y desired triplets (ai, bj, ck). Sorting and the n pairs of binary searches take O(n log n) time in total.

 

Example

Consider sorted arrays of 8 elements each, and let us count the required triplets in which b5 = 10. Suppose that ai < b5 holds exactly for i ≤ 5 and ck > b5 holds exactly for k ≥ 7. That is, the inequality ai < b5 < ck holds for 1 ≤ i ≤ 5 and 7 ≤ k ≤ 8. The number of triplets (ai, b5, ck) is 5 * 2 = 10.
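The count for a single middle value can be sketched as a small function. The arrays of this example are not reproduced in the text, so the usage note relies on hypothetical arrays chosen only to match the stated counts (five elements of a below 10, two elements of c above 10); the function name tripletsForMiddle is likewise an assumption.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: number of triplets (ai, v, ck) with ai < v < ck
// for one fixed middle value v. Both vectors must already be sorted.
long long tripletsForMiddle(const std::vector<int>& a,
                            const std::vector<int>& c, int v) {
    // lower_bound points at the first element >= v, so its index
    // equals the number of elements strictly less than v.
    long long x = std::lower_bound(a.begin(), a.end(), v) - a.begin();
    // upper_bound points at the first element > v, so the distance
    // to the end equals the number of elements strictly greater than v.
    long long y = c.end() - std::upper_bound(c.begin(), c.end(), v);
    return x * y;
}
```

With the assumed arrays a = {1, 3, 5, 7, 9, 11, 13, 15} and c = {2, 4, 6, 8, 10, 10, 12, 14}, the call tripletsForMiddle(a, c, 10) finds x = 5 and y = 2 and returns 10. The two copies of 10 in c are skipped because the comparison is strict.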

 

 

Algorithm realization

Declare the arrays.

 

#define MAX 100000

int a[MAX], b[MAX], c[MAX];

int n, i, j;

long long x, y, res;

 

Read the input arrays.

 

scanf("%d", &n);

for (i = 0; i < n; i++) scanf("%d", &a[i]);

for (i = 0; i < n; i++) scanf("%d", &b[i]);

for (i = 0; i < n; i++) scanf("%d", &c[i]);

 

Sort the arrays.

 

sort(a, a + n);

sort(b, b + n);

sort(c, c + n);

 

Count the number of required triplets in the variable res. Iterate over the values of bj.

 

res = 0;

for (j = 0; j < n; j++)

{

 

x is the number of elements of array a that are less than bj.

 

  x = lower_bound(a, a + n, b[j]) - a;

 

y is the number of elements of array c that are greater than bj.

 

  y = n - (upper_bound(c, c + n, b[j]) - c);

 

For this bj there are x * y required triplets.

 

  res += 1LL * x * y; // widen to 64 bits: x * y can reach 10^10

}

 

Print the answer.

 

printf("%lld\n", res);
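The fragments above can be assembled into one self-contained function. This is a minimal sketch rather than the reference program: the name countTriplets and the use of std::vector instead of global arrays are assumptions made for illustration. It sorts only a and c, since the order in which the values bj are visited does not affect the sum.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: counts triplets with a[i] < b[j] < c[k] in O(n log n).
// a and c are taken by value so that the caller's data is not reordered.
long long countTriplets(std::vector<int> a,
                        const std::vector<int>& b,
                        std::vector<int> c) {
    std::sort(a.begin(), a.end());
    std::sort(c.begin(), c.end());

    long long res = 0;
    for (int v : b) {
        // x: elements of a strictly less than v
        long long x = std::lower_bound(a.begin(), a.end(), v) - a.begin();
        // y: elements of c strictly greater than v
        long long y = c.end() - std::upper_bound(c.begin(), c.end(), v);
        res += x * y;  // 64-bit product: x * y can reach 10^10
    }
    return res;
}
```

For n = 10^5 the answer can be as large as 10^15 (for example, when every ai is 1, every bj is 2 and every ck is 3), which is why both the product and the accumulator are 64-bit.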

 

Java realization

 

import java.util.*;

 

public class Main

{

  static int lower_bound(int m[], int start, int end, int x)

  {

    while (start < end)

    {

      int mid = (start + end) / 2;

      if (x <= m[mid])

        end = mid;

      else

        start = mid + 1;

    }

    return start;

  }

 

  static int upper_bound(int m[], int start, int end, int x)

  {

    while (start < end)

    {

      int mid = (start + end) / 2;

      if (x >= m[mid])

        start = mid + 1;

      else

        end = mid;

    }

    return start;

  }

 

  public static void main(String[] args)

  {

    Scanner con = new Scanner(System.in);   

    int i, n = con.nextInt();

    int a[] = new int[n];

    for(i = 0; i < n; i++) a[i] = con.nextInt();

 

    int b[] = new int[n];

    for(i = 0; i < n; i++) b[i] = con.nextInt();

 

    int c[] = new int[n];

    for(i = 0; i < n; i++) c[i] = con.nextInt();

 

    Arrays.sort(a); Arrays.sort(b); Arrays.sort(c);

   

    long res = 0;

    for (i = 0; i < n; i++)

    {

      int x = lower_bound(a, 0, n, b[i]);

      int y = n - (upper_bound(c, 0, n, b[i]));

      res += 1L * x * y;

    }

   

    System.out.println(res);

    con.close();

  }

}