The Sum Of Two Squares

January 5, 2010

Brute force search over all pairs (x, y) is impractical for large n. We follow a solution given by Edsger Dijkstra in his 1976 book A Discipline of Programming. Dijkstra finds all pairs in a single pass as x sweeps downward from the integer square root of n and y sweeps upward from zero. Consider the function B(x, y) that returns all solutions (u, v) with x ≥ u ≥ v ≥ y, guided by the following three recursive rules plus a base case:

  • If x² + y² < n, then B(x, y) = B(x, y+1), because there is no possible solution (u, v) with u ≤ x and v = y, since that would imply u² + v² < n.
  • If x² + y² = n, then the pair (x, y) is a solution, B(x, y) = {(x, y)} ∪ B(x-1, y+1), and the sweep continues.
  • If x² + y² > n, then B(x, y) = B(x-1, y), because there is no possible solution (u, v) with u = x and v ≥ y, since that would imply u² + v² > n.
  • Finally, if x < y, B(x, y) is the null set, and recursion stops.
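
The four rules can be transcribed almost literally. The following Python sketch is not Dijkstra's own code: the helper name b, the extra parameter n, and the use of math.isqrt (Python 3.8 or later) for the starting value of x are assumptions made for illustration.

```python
from math import isqrt

def b(x, y, n):
    """Return all pairs (u, v) with x >= u >= v >= y and u*u + v*v == n."""
    if x < y:                  # base case: the two sweeps have crossed
        return []
    s = x * x + y * y
    if s < n:                  # too small: no solution uses this y
        return b(x, y + 1, n)
    if s > n:                  # too large: no solution uses this x
        return b(x - 1, y, n)
    # exact hit: record the pair and continue the sweep on both sides
    return [(x, y)] + b(x - 1, y + 1, n)

def squares(n):
    return b(isqrt(n), 0, n)

# squares(50) -> [(7, 1), (5, 5)]
```

Python does not eliminate tail calls, so this direct form is only suitable for small n; the recursion depth grows roughly with the square root of n.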

This can all be coded in a single recursive function:

(define (squares n)
  (let loop ((x (isqrt n)) (y 0) (zs '()))
    (cond ((< x y) zs)
          ((< (+ (* x x) (* y y)) n) (loop x (+ y 1) zs))
          ((< n (+ (* x x) (* y y))) (loop (- x 1) y zs))
          (else (loop (- x 1) (+ y 1) (cons (list x y) zs))))))

The code exposes the simplicity of the recursion. X is decremented when x² + y² is greater than n. Y is incremented when x² + y² is less than n. Solutions are accumulated in zs and returned when x and y cross. Here are some examples:

> (squares 50)
((5 5) (7 1))
> (squares 48612265)
((5008 4851) (5139 4712) (5179 4668) (5243 4596) (5432 4371)
 (5613 4136) (5656 4077) (5691 4028) (5832 3821) (5907 3704)
 (6048 3469) (6124 3333) (6213 3164) (6259 3072) (6384 2803)
 (6404 2757) (6413 2736) (6556 2373) (6576 2317) (6637 2136)
 (6651 2092) (6756 1723) (6772 1659) (6789 1588) (6853 1284)
 (6899 1008) (6917 876) (6944 627) (6948 581) (6952 531)
 (6971 132) (6972 59))
> (squares 999)
()

The integer square root function isqrt comes from the Standard Prelude. You can run the program at http://programmingpraxis.codepad.org/kdQgtQ34.
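
The Standard Prelude's isqrt is not reproduced here. As a rough idea of what such a function does, here is a minimal Python sketch using integer Newton iteration; it is an assumed equivalent, not the Prelude's actual code, and it avoids floating point so that it stays exact for large n.

```python
def isqrt(n):
    """Largest integer r with r*r <= n, for n >= 0."""
    if n < 0:
        raise ValueError("isqrt is undefined for negative numbers")
    if n == 0:
        return 0
    # Start from a power of two that is at least sqrt(n).
    x = 1 << ((n.bit_length() + 1) // 2)
    while True:
        y = (x + n // x) // 2   # Newton step, rounded down
        if y >= x:              # no further decrease: x is the answer
            return x
        x = y
```

For the large example above, isqrt(48612265) is 6972, which matches the largest x in the list of solutions.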


19 Responses to “The Sum Of Two Squares”

  1. kernelbob said

    I am not going to code a solution. (It is five hours past my bedtime.)

    Instead, I’ll observe that x^2 + y^2 = n is the equation for a circle, and Bresenham’s circle algorithm is a pretty quick way to iterate through the possible solutions. It’s not entirely straightforward because n = r^2 and r will generally not be an integer. I think Bresenham’s assumes r is an integer.

  2. Jaime said

    I’ve implemented a somewhat brute force approach, but it’s surprisingly fast…

    import math
    
    def two_squares(number):
        '''Return all the tuples of two numbers with the
        following properties:
        x >= y >= 0
        x²+y²=number
        '''
        
        results = []
        
        #The Maximum number will be the square root of the number
        for y in xrange(int(math.sqrt(number))):
            #Check possible match
            x2 = number - y ** 2
            x = math.sqrt(x2)
            if y > x:
                break 
            if x == int(x):
                #Exact root
                results.append((int(x), y))
                
        return results
    
    
    if __name__ == '__main__':
        assert two_squares(50) == [(7, 1), (5, 5)]
        print two_squares(48612265)
        #print two_squares(4861226500)
        assert two_squares(999) == []
        assert two_squares(100) == [(10, 0), (8, 6)]
    
    
  3. […] Praxis – The Sum Of Two Squares By Remco Niemeijer In today’s Programming Praxis exercise we have to find all the ways a given number can be written as the sum […]

  4. Remco Niemeijer said

    My Haskell solution (see http://bonsaicode.wordpress.com/2010/01/05/programming-praxis-the-sum-of-two-squares/ for a version with comments):

    squareSum :: Integer -> [(Integer, Integer)]
    squareSum n = b (ceiling . sqrt $ fromIntegral n) 0 where
        b x y = if x < y then [] else case compare (x*x + y*y) n of
                    LT -> b x (y + 1)
                    EQ -> (x, y) : b (x - 1) (y + 1)
                    GT -> b (x - 1) y
    
  5. Lautaro Pecile said
    import itertools
    from math import sqrt
    
    def sum_of_squares(n):
        lim = int(sqrt(n)) + 1
        l = itertools.product(xrange(lim), xrange(lim))
        return ((a, b) for a, b in l if (a >= b) and (a**2 + b**2 == n))
    
    if __name__ == "__main__":
        for x in sum_of_squares(50):
            print x
    
  6. Here’s my simple naive Haskell solution using two list comprehensions:

    squares :: Integer -> [(Integer, Integer)]
    squares n = [(x,y) | x <- c, y <- c, ((x * x) + (y * y) == n), x <= y]   
      where c     = [i | i <- [1..(isqrt n)]]
            isqrt = floor . sqrt . fromIntegral
    
  7. Sorry it’s the first time I’ve posted here so I don’t know how to get the nice formatted source listing.

  8. Jamie Hope said

    Here’s one which constructs a list of pairs (x, sqrt(n-x*x)) for x <- [0..sqrt(n/2)] and then filters out the elements in which sqrt(n-x*x) is not an exact integer.

    #!r6rs
    
    (import (rnrs)
            (only (srfi :1) list-tabulate))
    
    (define (squares n)
      (let ((sqrt-n/2 (sqrt (/ n 2))))
        (filter (lambda (pair) (exact? (cdr pair)))
                (list-tabulate (+ 1 (exact (floor sqrt-n/2)))
                               (lambda (i) (cons i (sqrt (- n (* i i)))))))))
    
  9. Jamie Hope said

    Hey, why didn’t sourcecode lang=”css” work?

  10. David Humphreys said

    Andrew: see http://en.support.wordpress.com/code/posting-source-code/
    Jamie: I think it should be language, not lang.

    A solution in Clojure. It works fine (but slowly) for large values. I’m sure there are more elegant ways of doing it.

    (ns Sum-of-Squares)
    (defmacro add-squares [x y] `(+ (* ~x ~x) (* ~y ~y)))
    (defn test-single-square [x y n result]
     (if (< x y)
      result
      (let [sum (add-squares x y)]
       (cond
        (< sum n)	#(test-single-square x (inc y) n result)
        (> sum n)	#(test-single-square (dec x) y n result)
        :else	#(test-single-square (dec x) (inc y) n (conj result [x y]))))))
    (defn #^{:test (fn []
     (let [check-values (fn [test-value]
      (assert (every?
       #(= test-value
        (add-squares (first %) (last %)))
        (find-squares test-value))))]
     (map check-values [1790119876545 10000798002 48612265 999 50])))}
     find-squares [n]
      (let [x (int (java.lang.Math/sqrt n)) y 0]
     (trampoline (test-single-square x y n []))))
    (test #'find-squares)
    
  11. programmingpraxis said

    Andrew, Jaime: I fixed your comments so the source code is properly formatted.

  12. Kevin said

    Here’s mine in C++…it’s somewhat brute force. If there are no solutions it won’t give any output. Please, gimme some feedback! I’m a student that’s trying to practice/learn and get a better feel of C++. I wanna make it my goto language :D

    #include<iostream>
    #include<cmath>
    using namespace std;
    int main(void)
    {
    int n, i=0;
    cout << "Integer? ";
    cin >> n;
    double temp_frac, a, b;
    do
    {
    a = pow(i,2);
    b = sqrt(n-a);
    if( a > n )
    break;
    else if( i > b )
    break;
    else if( modf(b, &temp_frac) == 0 )
    cout << i << " , " << b << endl;
    ++i;
    }
    while(1);
    return 0;
    }

  13. Kevin said

    OH CRAP I’m so sorry for being a noob. I JUST read the “Posting Source Code” page. So sorry!!

    #include<iostream>
    #include<cmath>
    using namespace std;
    int main(void)
    {
    	int n, i=0;
    	cout << "Integer? ";
    	cin >> n;
    	double temp_frac, a, b;
    	do
    	{
    		a = pow(i,2);
    		b = sqrt(n-a);
    		if( a > n )
    			break;
    		else if( i > b )
    			break;
    		else if( modf(b, &temp_frac) == 0 )
    			cout << i << " , " << b << endl;
    		++i;
    	}
    	while(1);
    	return 0;
    }
    
  14. Jamie Hope said

    Oh, I see. It should be “language” and not “lang”. But the HOWTO page says “lang”:

    HOWTO: Posting Source Code

    Please fix, as I’ll probably forget this by the time I go to post source code again. Thanks!

  15. Jebb said

    I’ve tried to pay extra attention to the boundaries of both while loops, to avoid any unnecessary iteration. I’m not convinced it’s optimal, though: I don’t particularly like the “if (y > x) break;”.

    #include <stdio.h>
    #include <math.h>
    
    int main()
    {
        unsigned int n, x, y;
        printf("Enter the integer n:\n");
        scanf("%u", &n);
        /* start one above floor(sqrt(n)) so the post-decrement below still tests it */
        x = sqrt(n) + 1;
        while (x-- > sqrt(n / 2)) {
            y = sqrt(n - x * x); 
            while (x * x + y * y <= n) {
                if (y > x)
                    break;
                if (x * x + y * y == n)
                    printf("%u %u\n", x, y); 
                ++y;
            }   
        }   
        return 0;
    }
    
  16. Mike said

    Here is a solution that only uses adds and multiply by 2 (could use a shift).

    def sumofsquares(n):
        '''Generate (x,y) pairs, where x*x + y*y == n and x >= y >= 0.
    
        Find the smallest x such that x*x >= n.  Then loop, comparing y*y + x*x to n.
        If the sum is too small, increment y; if too big, decrement x.
        '''
        x, y, sos = 0, 0, 0
        while sos < n:
            sos += 2*x + 1
            x += 1
    
    
        while x >= y:
            if sos < n:
                sos += 2*y + 1
                y += 1
    
            elif sos > n:
                sos -= 2*x - 1
                x -= 1
    
            else:
                yield x,y
                sos += 2*( y -x + 1 )
                y += 1
                x -= 1
                
    
  17. Mike said

    Because x >= y, decrementing x always causes a bigger change in the sum than incrementing y. So the loop in my routine above can be shortened to:

    def sumofsquares(n):
        '''Generate (x,y) pairs, where x*x + y*y == n and x >= y >= 0.

        Find the smallest x such that x*x >= n.  Then loop, comparing y*y + x*x to n.
        If the sum is too small, increment y; if too big, decrement x.

        "sos" is the sum of the squares.
        '''

        x, y, sos = 0, 0, 0
        while sos < n:
            sos += 2*x + 1
            x += 1

        while x >= y:
            while sos < n:
                sos += 2*y + 1
                y += 1

            # y may have stepped past x in the inner loop, so recheck the ordering
            if sos == n and x >= y:
                yield x,y

            sos -= 2*x - 1
            x -= 1

  18. Hey guys. Was looking at this code for an example. Ended up moving on, but not before porting one of the functions over to PHP. Hope someone can use it eventually.

    <?php
    function squares($n) {
    	$pairs = array();
    	
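    	// integer start, one above floor(sqrt(n)) so the post-decrement still tests it
    	$x = (int) floor(sqrt($n)) + 1;
    	
    	while ( $x-- > sqrt($n / 2) ) {
    		$y = (int) floor(sqrt($n - $x * $x));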
    		
    		while ($x * $x + $y * $y <= $n) {
    			if ($y > $x)
    				break;
    			if ($x * $x + $y * $y == $n)
    				array_push($pairs, array($x, $y));
    			++$y;
    		}
    	}
    	
    	echo sprintf("%d -- Pairs: %d\r\n", $n, count($pairs));
    }
    
    squares(50);
    squares(48612265);
    ?>
    
  19. Drake said

    At the end of the chapter Dijkstra includes this: “Note. Obvious improvements, such as testing whether r mod 4 =3, and exploiting
    the recurrence relation (x+1)^2 = x^2 + (2x +1) are left as exercises.”
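
Both improvements in Dijkstra's note can be sketched together. The Python below is one possible reading, not Dijkstra's own solution: it assumes his r is the number called n in this exercise, and the function name squares_fast is made up for illustration. A square is always 0 or 1 modulo 4, so a sum of two squares can never be 3 modulo 4; the running sum s is updated with the recurrence instead of being recomputed with multiplications.

```python
from math import isqrt

def squares_fast(n):
    """Pairs (x, y) with x >= y >= 0 and x*x + y*y == n."""
    if n % 4 == 3:             # squares are 0 or 1 mod 4, so no sum is 3 mod 4
        return []
    x, y = isqrt(n), 0
    s = x * x                  # invariant: s == x*x + y*y
    found = []
    while x >= y:
        if s < n:
            s += 2 * y + 1     # (y+1)^2 = y^2 + (2y + 1)
            y += 1
        elif s > n:
            s -= 2 * x - 1     # (x-1)^2 = x^2 - (2x - 1)
            x -= 1
        else:
            found.append((x, y))
            s += 2 * (y - x + 1)   # both steps at once
            x, y = x - 1, y + 1
    return found

# squares_fast(50) -> [(7, 1), (5, 5)]
# squares_fast(999) -> []   (999 mod 4 is 3, so the loop never runs)
```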
