9606. Modular division

 

Three positive integers a, b and n are given. Find the value of a / b mod n. In other words, you must find x such that b * x ≡ a (mod n).
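For small moduli the definition can be checked directly by trying every candidate x. The sketch below is only an illustration: the function name `brute_div` is hypothetical and not part of the original solution. With n up to 2 * 10^9 it is far too slow for the real limits, but it is handy for cross-checking the fast methods.

```cpp
// Hypothetical helper (not from the original solution): returns the
// x in [0, n) with (b * x) mod n == a mod n, or -1 if there is none.
// Exhaustive O(n) search, so it is only practical for small n.
long long brute_div(long long a, long long b, long long n)
{
    for (long long x = 0; x < n; ++x)
        if ((b * x) % n == a % n)
            return x;
    return -1;
}
```

For the two samples, brute_div(3, 4, 7) returns 6 and brute_div(4, 8, 13) returns 7.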

 

Input. Three positive integers a, b, n ($n \le 2 \cdot 10^9$, $1 \le a, b < n$). It is known that n is prime.

 

Output. Print the value of a / b mod n.

 

Sample input 1

3 4 7

Sample output 1

6

 

 

Sample input 2

4 8 13

Sample output 2

7

 

 

SOLUTION

exponentiation

 

Algorithm analysis

Since n is prime, Fermat's little theorem gives $b^{n-1} \bmod n = 1$ for every $1 \le b < n$. This equality can be rewritten as $(b \cdot b^{n-2}) \bmod n = 1$, so the inverse of b is $y = b^{n-2} \bmod n$.

Hence $a / b \bmod n = a \cdot b^{-1} \bmod n = a \cdot y \bmod n$.
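The two formulas above can be combined into a single function. This is a sketch rather than the original code: the names `mod_pow` and `div_mod_fermat` are hypothetical, exponentiation is done iteratively (square-and-multiply) instead of recursively, and it assumes n is prime and 1 ≤ b < n, as the statement guarantees.

```cpp
// Iterative square-and-multiply: returns base^exp mod m.
// All operands stay below m <= 2 * 10^9, so every product is
// below 4 * 10^18 and fits in a signed 64-bit long long.
long long mod_pow(long long base, long long exp, long long m)
{
    long long result = 1 % m;
    base %= m;
    while (exp > 0) {
        if (exp & 1)
            result = result * base % m;  // multiply in base^(2^k) for this bit
        base = base * base % m;          // base^(2^k) -> base^(2^(k+1))
        exp >>= 1;
    }
    return result;
}

// a / b mod n for prime n (assumption), using y = b^(n-2) mod n
// as the inverse of b by Fermat's little theorem.
long long div_mod_fermat(long long a, long long b, long long n)
{
    long long y = mod_pow(b, n - 2, n);
    return a % n * y % n;
}
```

For example, div_mod_fermat(4, 8, 13) returns 7, and div_mod_fermat(1, 2, 1000000007) returns 500000004.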

 

Alternatively, the inverse can be found with the extended Euclidean algorithm; this approach only requires $\gcd(a, n) = 1$, not that n be prime. Suppose we need to solve the congruence $ax \equiv 1 \pmod{n}$. Consider the equation

ax + ny = 1

and find a particular solution $(x_0, y_0)$ with the extended Euclidean algorithm. Reducing $ax_0 + ny_0 = 1$ modulo n gives $ax_0 \equiv 1 \pmod{n}$. If $x_0$ is negative, add n to it. Then $x_0 \equiv a^{-1} \pmod{n}$ is the inverse of a.
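As a worked illustration, here is the extended Euclidean algorithm for a = 8, n = 13 from the second sample: first the division steps, then back-substitution from the last nonzero remainder.

$$
\begin{aligned}
13 &= 1 \cdot 8 + 5, \\
8 &= 1 \cdot 5 + 3, \\
5 &= 1 \cdot 3 + 2, \\
3 &= 1 \cdot 2 + 1, \\
2 &= 2 \cdot 1 + 0.
\end{aligned}
$$

$$
\begin{aligned}
1 &= 3 - 2 = 3 - (5 - 3) = 2 \cdot 3 - 5 \\
  &= 2 \cdot (8 - 5) - 5 = 2 \cdot 8 - 3 \cdot 5 \\
  &= 2 \cdot 8 - 3 \cdot (13 - 8) = 5 \cdot 8 - 3 \cdot 13.
\end{aligned}
$$

So $x_0 = 5$, $y_0 = -3$, and $8^{-1} \equiv 5 \pmod{13}$.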

 

Example

Consider the second sample: compute 4 / 8 mod 13. To do this, solve the congruence $8x \equiv 4 \pmod{13}$, whence $x = (4 \cdot 8^{-1}) \bmod 13$.

Since 13 is prime, Fermat's little theorem implies $8^{12} \bmod 13 = 1$, that is, $(8 \cdot 8^{11}) \bmod 13 = 1$. Therefore $8^{-1} \bmod 13 = 8^{11} \bmod 13 = 5$.
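The value $8^{11} \bmod 13 = 5$ can be checked by hand with repeated squaring, writing $11 = 8 + 2 + 1$:

$$
\begin{aligned}
8^2 &= 64 \equiv 12 \pmod{13}, \\
8^4 &\equiv 12^2 = 144 \equiv 1 \pmod{13}, \\
8^8 &\equiv 1^2 = 1 \pmod{13}, \\
8^{11} &= 8^8 \cdot 8^2 \cdot 8 \equiv 1 \cdot 12 \cdot 8 = 96 \equiv 5 \pmod{13}.
\end{aligned}
$$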

Compute the answer: $x = (4 \cdot 8^{-1}) \bmod 13 = (4 \cdot 5) \bmod 13 = 20 \bmod 13 = 7$.

 

Algorithm implementation

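Function powmod computes $x^n \bmod m$ by binary exponentiation: for even n it squares the base and halves the exponent, for odd n it splits off one factor of x.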

 

long long powmod(long long x, long long n, long long m)
{
  if (n == 0) return 1;                                  // x^0 = 1
  if (n % 2 == 0) return powmod((x * x) % m, n / 2, m);  // x^n = (x^2)^(n/2)
  return (x * powmod(x, n - 1, m)) % m;                  // x^n = x * x^(n-1)
}
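Why long long is wide enough here: every factor in a multiplication inside powmod (and in a * y below) is at most $m - 1 < 2 \cdot 10^9$. This assumes the initial x is already below m, which holds because b < n. The largest possible product is therefore bounded by

$$
(m - 1)^2 < 4 \cdot 10^{18} < 2^{63} - 1 \approx 9.22 \cdot 10^{18}.
$$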

 

The main part of the program. Read the input data.

 

scanf("%lld %lld %lld", &a, &b, &n);

 

Calculate $y = b^{n-2} \bmod n$ and $x = a \cdot y \bmod n$.

 

y = powmod(b, n - 2, n);

x = (a * y) % n;

 

Print the answer.

 

printf("%lld\n", x);

 

Algorithm implementation – extended Euclidean algorithm

 

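#include <stdio.h>

long long a, b, n, d, x, y, inv, res;

// Extended Euclidean algorithm: computes d = gcd(a, b) and integers
// x, y such that a * x + b * y = d.
// The reference parameters make this C++, not C.
void gcdext(long long a, long long b,
            long long &d, long long &x, long long &y)
{
  if (b == 0)
  {
    d = a; x = 1; y = 0;  // a * 1 + 0 * 0 = a
    return;
  }

  gcdext(b, a % b, d, x, y);  // now b * x + (a % b) * y = d

  // Substitute a % b = a - (a / b) * b to get coefficients for (a, b)
  long long s = y;
  y = x - (a / b) * y;
  x = s;
}

int main(void)
{
  scanf("%lld %lld %lld", &a, &b, &n);
  // b * inv + n * y = 1
  gcdext(b, n, d, inv, y);
  if (inv < 0) inv += n;  // bring the inverse into [0, n)
  res = (a * inv) % n;    // a / b mod n = a * b^(-1) mod n
  printf("%lld\n", res);
  return 0;
}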
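The recursion in gcdext can also be unrolled into a loop. The sketch below is an alternative, not the original program: the name `inverse_ext` is hypothetical, it uses the common iterative form of the extended Euclidean algorithm that tracks only the coefficient of b, and it returns -1 when gcd(b, n) ≠ 1, in which case no inverse exists.

```cpp
// Hypothetical helper: modular inverse of b modulo n (n > 0) via the
// iterative extended Euclidean algorithm. Loop invariant:
//   old_r ≡ old_s * b (mod n)  and  r ≡ s * b (mod n).
// Returns -1 if gcd(b, n) != 1.
long long inverse_ext(long long b, long long n)
{
    long long old_r = b % n, r = n;
    long long old_s = 1, s = 0;
    while (r != 0) {
        long long q = old_r / r;
        long long t = old_r - q * r;  // next remainder
        old_r = r; r = t;
        t = old_s - q * s;            // matching coefficient of b
        old_s = s; s = t;
    }
    if (old_r != 1) return -1;        // b and n are not coprime
    return ((old_s % n) + n) % n;     // normalise into [0, n)
}
```

With this helper the answer is (a * inverse_ext(b, n)) % n; for the second sample inverse_ext(8, 13) returns 5, the same coefficient as in the worked Euclid trace for 8 and 13.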