How can I optimize my C++ code for a sum problem on CodeChef?

  • C/C++
  • Thread starter Saitama
In summary: Actually, it will have a significant impact. That is because if $n$ is large, $m$ is likely very small in the test case provided. If that is really the case, we indeed need a smarter outer loop.
  • #1
Saitama
I am trying this problem on CodeChef: Just a simple sum

My task is to evaluate:
$$\sum_{i=1}^n i^i \pmod m$$

Following is the code I have written:
Code:
#include <iostream>
using namespace std;

typedef long long ll;

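ll modularPower(ll base, ll exponent, ll M)
{
    ll res = 1 % M;   // also correct when M == 1
    base %= M;        // reduce first: i can be as large as N, so base * base could overflow otherwise
    while (exponent)
    {
        if (exponent & 1)
            res = (res * base) % M;
        exponent >>= 1;
        base = (base * base) % M;
    }
    return res;
}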

int main() {
	ll T, N, M, i, res;
	cin>>T;
	while(T--) {
	    cin>>N>>M;
	    res=0;
	    for(i=1; i<=N; ++i)
	    {
	            res+=modularPower(i, i, M);
	            res%=M;
	    }
	    cout<<res<<endl;
	}
	return 0;
}

But this code exceeds the time limit.

How do I optimise it?

Thanks!
 
  • #2
Hey Pranav! (Wave)

A standard way to speed up an algorithm is to store intermediate results in a table.

The first thing I can think of is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)
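
A minimal sketch of the lookup-table idea, for illustration only. The names buildSquareTable and modularPowerWithTable are hypothetical, and the sketch assumes $m \ge 1$ is small enough (around 200000) that the table fits in memory and products of two residues fit in 64 bits.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: sq[b] = b*b mod m for every residue b in [0, m).
// Assumes m >= 1 and m*m fits in 64 bits.
std::vector<std::uint64_t> buildSquareTable(std::uint64_t m) {
    std::vector<std::uint64_t> sq(m);
    for (std::uint64_t b = 0; b < m; ++b) sq[b] = (b * b) % m;
    return sq;
}

// Square-and-multiply where the squaring step is a table lookup.
// The modulus is taken to be the size of the table.
std::uint64_t modularPowerWithTable(std::uint64_t base, std::uint64_t exponent,
                                    const std::vector<std::uint64_t>& sq) {
    const std::uint64_t m = sq.size();
    std::uint64_t res = 1 % m;
    base %= m;
    while (exponent) {
        if (exponent & 1) res = (res * base) % m;
        exponent >>= 1;
        base = sq[base];  // replaces (base * base) % m
    }
    return res;
}
```

The table is built once per value of $m$ and reused for every term of the sum; the outer loop over $i$ is unchanged.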
 
  • #3
I like Serena said:
Hey Pranav! (Wave)

A standard way to speed up an algorithm is to store intermediate results in a table.

The first thing I can think of is to calculate the square of each number up to $m$ and store it $\bmod m$.
So instead of calculating (base * base) % M, we can simply look it up. (Thinking)

That won't really do much. A 64-bit modular multiplication is quite fast compared to a cache miss, and the precomputed table is going to be quite large and not accessed in a cache-friendly way. Besides, you still end up looping from 1 to $n$, and since $n$ can be up to $10^{18} \approx 2^{60}$, you would still be doing that many iterations; $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor $m$ into its constituent prime factors, and use that information together with the Chinese remainder theorem (CRT) to try to simplify the sum.
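
For readers who have not used the CRT before, here is a minimal sketch of its recombination step: given a result modulo $m_1$ and a result modulo $m_2$ with $\gcd(m_1, m_2) = 1$, it recovers the result modulo $m_1 m_2$. The names extendedGcd and crtCombine are hypothetical, and the sketch assumes the inputs are already reduced and that $m_1 m_2$ fits comfortably in 64 bits.

```cpp
#include <cstdint>

// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a*x + b*y == g.
std::int64_t extendedGcd(std::int64_t a, std::int64_t b,
                         std::int64_t& x, std::int64_t& y) {
    if (b == 0) { x = 1; y = 0; return a; }
    std::int64_t x1 = 0, y1 = 0;
    std::int64_t g = extendedGcd(b, a % b, x1, y1);
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Combines x = a1 (mod m1) and x = a2 (mod m2) for coprime m1, m2,
// with 0 <= a1 < m1 and 0 <= a2 < m2. Returns the unique x in [0, m1*m2).
std::int64_t crtCombine(std::int64_t a1, std::int64_t m1,
                        std::int64_t a2, std::int64_t m2) {
    std::int64_t inv = 0, unused = 0;
    extendedGcd(m1 % m2, m2, inv, unused);        // inv * m1 = 1 (mod m2)
    inv = ((inv % m2) + m2) % m2;                 // normalise to [0, m2)
    std::int64_t k = ((a2 - a1) % m2 + m2) % m2;  // need a1 + k*m1 = a2 (mod m2)
    k = (k * inv) % m2;
    return a1 + k * m1;
}
```

For example, crtCombine(2, 3, 3, 5) yields 8, the only value below 15 that is 2 modulo 3 and 3 modulo 5. In this problem the two inputs would be the sum computed modulo two coprime prime-power factors of $m$.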
 
  • #4
I think I can factorise $m$ but I have never used CRT before. I looked it up online but could not understand how I would apply it here.

Can you please elaborate on the approach? Thanks!
 
  • #5
Bacterius said:
That won't really do much. A 64-bit modular multiplication is quite fast compared to a cache miss, and the precomputed table is going to be quite large and not accessed in a cache-friendly way. Besides, you still end up looping from 1 to $n$, and since $n$ can be up to $10^{18} \approx 2^{60}$, you would still be doing that many iterations; $2^{60}$ iterations of anything is not going to happen in under 80 seconds. So the outer loop must go. My approach would be to factor $m$ into its constituent prime factors, and use that information together with the Chinese remainder theorem (CRT) to try to simplify the sum.

Actually, it will have a significant impact.
That is because if $n$ is large, $m$ is likely very small in the test case provided.
So it will hit the cache successfully every time.
It may be enough to get the problem accepted, and the effort involved is trivial. (Nod)

I do agree that $n = 10^{18}$ is way too large to complete in a reasonable amount of time.
If that is really the case, we indeed need a smarter outer loop.

I see a couple of ways that might split it up:
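First, working modulo $m$, we have:
$$\begin{aligned}
&1^1 + 2^2 + \dots + m^m + (m+1)^{m+1} + \dots + n^n \\
&\equiv 1^1 + 2^2 + \dots + 0^m + 1^1\cdot 1^m + \dots + (n \bmod m)^{n \bmod m}\cdot \left((n \bmod m)^m\right)^{\lfloor n / m \rfloor} \pmod m
\end{aligned}$$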
I'd have to think some more on it, but perhaps we could merge the bases to a separate summation that we can evaluate.
It appears there is a geometric series in there. (Thinking)
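
One way to make the geometric series concrete, as a sketch rather than a tuned solution: every $i \equiv j \pmod m$ can be written $i = j + km$, so $i^i \equiv j^j \cdot (j^m)^k \pmod m$, and each residue class contributes $j^j$ times a geometric series with ratio $j^m$. The function names below are hypothetical, the series is summed by halving so that no modular inverse is needed, and the code assumes $m \ge 1$ with $m^2$ fitting in 64 bits.

```cpp
#include <cstdint>
#include <utility>

using u64 = std::uint64_t;

// Plain square-and-multiply; assumes m >= 1 and m*m fits in 64 bits.
u64 powMod(u64 base, u64 exponent, u64 m) {
    u64 res = 1 % m;
    base %= m;
    while (exponent) {
        if (exponent & 1) res = res * base % m;
        base = base * base % m;
        exponent >>= 1;
    }
    return res;
}

// Returns {x^c mod m, (1 + x + ... + x^(c-1)) mod m} for x already reduced mod m.
// Uses G(2h) = G(h) * (1 + x^h) and G(2h+1) = 1 + x * G(2h).
std::pair<u64, u64> powAndGeometricSum(u64 x, u64 c, u64 m) {
    if (c == 0) return {1 % m, 0};
    std::pair<u64, u64> half = powAndGeometricSum(x, c / 2, m);
    u64 p = half.first, g = half.second;
    u64 p2 = p * p % m;
    u64 g2 = g * (1 + p) % m;
    if (c & 1) {
        g2 = (1 + x * g2) % m;
        p2 = p2 * x % m;
    }
    return {p2, g2};
}

// Sum of i^i for i = 1..n, modulo m, grouping i by its residue j modulo m.
u64 sumSelfPowers(u64 n, u64 m) {
    u64 total = 0;
    for (u64 j = 1; j <= m && j <= n; ++j) {
        u64 count = (n - j) / m + 1;  // number of i <= n congruent to j mod m
        u64 first = powMod(j, j, m);  // j^j
        u64 ratio = powMod(j, m, m);  // j^m, the common ratio
        total = (total + first * powAndGeometricSum(ratio, count, m).second) % m;
    }
    return total;
}
```

The loop now runs over at most $m$ residues instead of $n$ terms, with roughly $\log n$ multiplications per residue. Whether that is fast enough for the judge depends on the actual limits on $m$ and on the number of test cases.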

Or we can run a sieve similar to the sieve of Eratosthenes.
If we know the prime factorization of $m$ and also the greatest common divisor $d=\gcd(m,i)$, we can use:
$$\left(\frac i d\right)^{\phi(m/d)} \equiv 1 \pmod {\frac m d} \Rightarrow i^{\phi(m/d)} \equiv d^{\phi(m/d)} \pmod m$$
That's as far as I got. (Thinking)
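
A small numeric check of the congruence above, with values chosen purely for illustration: take $m = 12$ and the base $8$, so $d = \gcd(12, 8) = 4$, $m/d = 3$ and $\phi(3) = 2$.
$$\left(\frac 8 4\right)^{2} = 4 \equiv 1 \pmod 3, \qquad 8^{2} = 64 \equiv 4 \pmod{12}, \qquad 4^{2} = 16 \equiv 4 \pmod{12}$$
Both sides of the second congruence reduce to $4$ modulo $12$, as the identity predicts.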
 
  • #6
Pranav,
I read your problem just a few days ago. I found it to be interesting and had fun finding a solution. First, a discussion of my solution. Then a formal implementation in Java, not C++. More on this choice of language later.

[Two attached images, 15dxssj.png and 2u73fhi.png, contained the discussion of the solution.]


For $n$ reasonably small, you can use a C++ long long type or a Java long. However, if $n>2^{63}-1$, you need the ability to calculate with such large integers. This is mainly why I chose Java, since Java comes prepackaged with a BigInteger class.
Code:
package bigsummod;

import java.math.BigInteger;

/**
 *
 * @author John on 7/6/2015
 */
public class BigSumMod {

   public static void main(String[] args) {
      SumModM test = new SumModM();
      long n = 1;
      int i;
      for (i = 1; i <= 18; i++) {
         n *= 10;
      }
      long time1 = System.currentTimeMillis(), time2;
      test.set(104729, n);
      System.out.println("The 10,000th prime is p = 104729");
      System.out.println("For m = p and n=i*10^18 for i=1 to 6:");
      for (i = 1; i <= 6; i++) {
         test.set(i * n);
         System.out.println(test.computeSum());
      }
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");
      BigInteger N = BigInteger.TEN;
      N = N.pow(100); // N is a googol
      System.out.println("For m=p and n=10^100, a googol:");
      time1 = System.currentTimeMillis();
      test.set(N);
      System.out.println(test.computeSum());
      System.out.println("time " + (System.currentTimeMillis() - time1)+" milliseconds.");      
   }
}

package bigsummod;

/**
 * Class to compute sum(i = 1 to n)i^i mod m.  Here m<=200000 and n is "large", say 10^18
 * or even bigger than a long (2^63-1), say the "big" integer 10^100.
 */
import java.math.BigInteger;

public class SumModM {

   int m;
   long n;
   BigInteger N, N_j, R;
   boolean nIsBig;
   int[] primes; // array of prime divisors of m
   int[] primePowers; // m = product of components; each component a prime power
   int[] primePowerExponents; // array of exponents of the primes
   int primeCount; // number of prime divisors of m

   SumModM() {
      primes = new int[10]; // 10 is plenty since m<=200000
      primePowers = new int[10];
      primePowerExponents = new int[10];
   }
   
   public void set(int m,long n) {
      this.m=m;
      this.n=n;
      primeCount=0;
      factorM();
      nIsBig = false;
   }
   
   public void set(long n) {
      this.n = n;
      nIsBig = false;
   }
   
   public void set(BigInteger n) {
      N = n;
      nIsBig = true;
   }
   
/* crude factoring method of m, but fine for "small" m */   
   private void factorM() {
      int n = m, d;
      if ((n & 1) == 0) { // 2 divides n
         primePowers[0] = primes[0] = 2;
         primePowerExponents[0]=1;
         n >>= 1;
         while ((n & 1) == 0) {
            primePowers[0] *= 2;
            primePowerExponents[0]++;
            n >>= 1;
         }
         primeCount = 1;
      }
      d = 3;
      while (d <= n / d) {
         if (n % d == 0) {
            primes[primeCount] = primePowers[primeCount] = d;
            primePowerExponents[primeCount] = 1;
            n /= d;
            while (n % d == 0) {
               primePowers[primeCount] *= d;
               primePowerExponents[primeCount]++;
               n /= d;
            }
            primeCount++;
         }
         d += 2;
      }
      if (n != 1) {
         primes[primeCount] = primePowers[primeCount] = n;
         primePowerExponents[primeCount] = 1;
         primeCount++;
      }
   }
   
/* Computes x^m mod n for m >= 1.  Notice the method requires n*n <= max value for long, namely 2^63-1,
 * so it's perfectly safe for n an int.  This recursive version is essentially just as efficient as a
 * non-recursive version.
 */
   private long powMod(long x, long m, long n) {
      long result;
      if (m == 1) {
         return (x);
      }
      result = powMod(x, m >> 1, n);
      result = (result * result) % n;
      if ((m & 1) == 1) {
         result = (x * result) % n;
      }
      return (result);
   }
   
   int computeSum() {
      int m1 = computeSum(0), modulus1 = primePowers[0];
      int i, m2;
      for (i = 1; i < primeCount; i++) {
         m2 = computeSum(i);
         m1 = chineseRemainder(m1, modulus1, m2, primePowers[i]);
         modulus1 *= primePowers[i];
      }
      return (m1);
   }

   int computeSum(int i) {
      int p = primes[i];
      int m = primePowers[i]; // local modulus p^e; do not overwrite the field m
      int j, sum = 0;
      long term;
      for (j = 1; j < m; j++) {
         if (j%p==0) {  // try to save a few milliseconds
            if (primePowerExponents[i]<=j) {
               continue;
            }
         }
         term = powMod(j, j, m);
         term = (term * computeSum_j(j, p, m)) % m;
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }

   int computeSum_j(int j, int p, int m) {
      if (j % p == 0) {
         return (1);
      }
      int x;
      long result, r, inverse;
      if (!nIsBig) {
         long n_j = (n - j) / m;
         if (p == 2) {
            return ((int) ((n_j + 1) % m));
         }
         x = (int) powMod(j, m / p, m);
         if (x == 1) {
            return ((int) ((n_j + 1) % m));
         }
         result = 0;
         r = (n_j + 1) % (p - 1);
         if (r != 0) {
            inverse = aInverseModB(x - 1, m);
            result = ((powMod(x, r, m) - 1) * inverse) % m;
         }
         return ((int) result);
      }
      N_j = N.subtract(BigInteger.valueOf(j));
      N_j = N_j.divide(BigInteger.valueOf(m));
      N_j = N_j.add(BigInteger.ONE);
      if (p == 2) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      x = (int) powMod(j, m / p, m);
      if (x == 1) {
         R = N_j.mod(BigInteger.valueOf(m));
         return (R.intValue());
      }
      result = 0;
      R = N_j.mod(BigInteger.valueOf(p - 1));
      if (!R.equals(BigInteger.ZERO)) {
         inverse = aInverseModB(x - 1, m);
         r = R.intValue();
         result = ((powMod(x, r, m) - 1) * inverse) % m;
      }
      return ((int) result);
   }
/* Upon entry m and n are positive coprime ints and a, b are non-negative ints reduced mod m and mod n
 * respectively.  The return is x with 0<=x and x<m*n  where x is congruent to a mod m and x is congruent
 * to b mod n.  Here's how it is done:
 * Clearly a+km is congruent to a for any int k;  so just need a k with a+km congruent to b mod n.  That is,
 * k should be congruent to (b-a)*(inverse of m mod n).  Then easily with 0<=k<n, x is of the correct size.
 */
   int chineseRemainder(int a, int m, int b, int n) {
      // use long arithmetic: (b - a) * inverse can exceed the int range
      long k = ((long) (b - a) * aInverseModB(m, n)) % n;
      if (k < 0) {
         k += n;
      }
      return ((int) (a + k * m));
   }

   /* Computation for small n's.  Check for the result of the above algorithm. */
   int feasible() {
      int j;
      int sum = 1;
      long term;
      for (j = 2; j <= n; j++) {
         term = powMod(j, j, m);
         sum = (int) ((sum + term) % m);
      }
      return (sum);
   }
   /* Standard computation of a inverse mod b where a is prime to b. */
   int aInverseModB(int a, int b) {
      int a1 = a;
      int b1 = b;
      int m0 = 1,  m1 = 0,  r, q, m2;
      while (b1 != 0) {
         r = a1 % b1;
         q = a1 / b1;
         m2 = m0 - m1 * q;
         m0 = m1;
         m1 = m2;
         a1 = b1;
         b1 = r;
      }
      if (m0 < 0) {
         m0 = b + m0;
      }
      return (m0);
   }
}
Finally, here are some timings for a couple of runs of the program:
For m=10^5 and n=i*10^18 for i=1 to 6:
62500
62500
62500
62500
62500
62500
time 16 milliseconds.
For m=10^5 and n=10^100, a googol:
62500
time 16 milliseconds.

You might want to verify for yourself that for m=10^5 and n any value with m^2 dividing n, you always get an answer of 62500. So m=100000 is an easy case. Notice the time for a googol is still very fast.

The 10,000th prime is p = 104729
For m = p and n=i*10^18 for i=1 to 6:
52444
39642
378
64098
7399
11003
time 936 milliseconds.
For m=p and n=10^100, a googol:
23255
time 312 milliseconds.
 
  • #7
Thanks a lot johng! That is exactly what I was looking for, great approach and explanation! :)
 
