-- |
-- Module      : Crypto.Number.Basic
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good

{-# LANGUAGE BangPatterns #-}
module Crypto.Number.Basic
    ( sqrti
    , gcde
    , areEven
    , log2
    , numBits
    , numBytes
    , asPowerOf2AndOdd
    ) where

import Data.Bits

import Crypto.Number.Compat

-- | @sqrti i@ returns a pair @(lo, hi)@ of integers bracketing the square
-- root of @i@, i.e. @lo <= sqrt i <= hi@.
--
-- A rough starting guess is derived from the number of decimal digits of
-- @i@. The guess is doubled (or halved) until it brackets the root. The
-- bracket is then narrowed by bisection until @lo@ and @hi@ are equal or
-- adjacent. The bounds are not always tight; for example @sqrti 4@ is
-- @(1,2)@.
--
-- Calls 'error' when @i@ is negative.
sqrti :: Integer -> (Integer, Integer)
sqrti i
    | i < 0     = error "cannot compute negative square root"
    | i == 0    = (0,0)
    | i == 1    = (1,1)
    | i == 2    = (1,2)
    | otherwise = refine guess
  where
    square :: Integer -> Integer
    square v = v * v

    digits :: Int
    digits = length (show i)

    -- Starting approximation: 2*10^e for an even digit count, 6*10^e for
    -- an odd digit count.
    guess :: Integer
    guess
        | even digits = 2 * 10 ^ ((digits - 2) `div` 2)
        | otherwise   = 6 * 10 ^ ((digits - 1) `div` 2)

    -- Decide in which direction the guess has to move.
    refine :: Integer -> (Integer, Integer)
    refine g
        | s < i     = growFrom g
        | s > i     = shrinkFrom g
        | otherwise = (g, g)
      where
        s = square g

    -- @lo@ is too small: double it until the doubled value squared
    -- reaches @i@.
    growFrom :: Integer -> (Integer, Integer)
    growFrom lo
        | square hi >= i = bisect lo hi
        | otherwise      = growFrom hi
      where
        hi = lo * 2

    -- @hi@ is too large: halve it until the halved value squared drops
    -- below @i@.
    shrinkFrom :: Integer -> (Integer, Integer)
    shrinkFrom hi
        | square lo >= i = shrinkFrom lo
        | otherwise      = bisect lo hi
      where
        lo = hi `div` 2

    -- Narrow the bracket @[lo, hi]@ around the root by halving its width.
    bisect :: Integer -> Integer -> (Integer, Integer)
    bisect lo hi
        | lo == hi || lo + 1 == hi = (lo, hi)
        | square mid >= i          = bisect lo (hi - half)
        | otherwise                = bisect mid hi
      where
        half = (hi - lo) `div` 2
        mid  = lo + half

-- | Extended greatest common divisor of two integers.
--
-- @gcde a b@ returns a triple @(x, y, d)@ such that @a*x + b*y == d@, where
-- @d@ is the gcd of @a@ and @b@. The GMP implementation is used when it is
-- available. Otherwise a pure extended Euclid based on 'divMod' is used,
-- and it always returns a non-negative @d@.
gcde :: Integer -> Integer -> (Integer, Integer, Integer)
gcde a b = onGmpUnsupported (gmpGcde a b) (normalise (euclid (a, 1, 0) (b, 0, 1)))
  where
    -- The remainder chain may end on a negative value. If it does, negate
    -- all three components so that the gcd is non-negative; the identity
    -- a*x + b*y == d still holds.
    normalise (g, s, t)
        | g < 0     = (negate s, negate t, negate g)
        | otherwise = (s, t, g)

    -- Every triple (r, s, t) carried here satisfies r == a*s + b*t.
    -- Stop when the newer remainder is 0; the older triple then holds
    -- the gcd together with its Bezout coefficients.
    euclid prev (0, _, _) = prev
    euclid (r0, s0, t0) cur@(r1, s1, t1) =
        let (q, r2) = r0 `divMod` r1
        in euclid cur (r2, s0 - q * s1, t0 - q * t1)

-- | Check whether every integer in a list is even.
--
-- Returns 'True' for the empty list and stops at the first odd element.
areEven :: [Integer] -> Bool
areEven = all even

-- | Compute the binary logarithm of an integer, rounded down.
--
-- The GMP primitive is used when it is available. The pure fallback
-- returns 0 for any argument below 2, which includes zero and negative
-- numbers.
log2 :: Integer -> Int
log2 n = onGmpUnsupported (gmpLog2 n) (intLog 2 n)
  where
    -- Integer logarithm of @v@ in base @base@. It recurses on the squared
    -- base to get a coarse exponent, then refines it by repeated division.
    -- http://www.haskell.org/pipermail/haskell-cafe/2008-February/039465.html
    intLog base v
        | v < base  = 0
        | otherwise = countDown (v `div` (base ^ k)) k
      where
        k = 2 * intLog (base * base) v
        -- Divide by the base until the value drops below it, counting
        -- one step per division.
        countDown w acc
            | w < base  = acc
            | otherwise = countDown (w `div` base) (acc + 1)
{-# INLINE log2 #-}

-- | Compute the number of bits needed to represent an integer.
--
-- Zero is reported as needing 1 bit. The GMP primitive is used when it is
-- available. Otherwise a pure fallback counts whole bytes, then adds the
-- significant bits of the most significant byte. The fallback measures a
-- negative number by its absolute value.
-- (NOTE(review): presumably this matches the GMP path; confirm.)
numBits :: Integer -> Int
numBits n = gmpSizeInBits n `onGmpUnsupported` fallback
  where
    fallback :: Int
    fallback
        | n == 0    = 1
        -- 'divMod' rounds towards negative infinity, so for a negative
        -- argument the quotient never reaches 0 ((-1) `divMod` 256 is
        -- (-1, 255)), and 'computeBits' would loop forever. Work on the
        -- magnitude instead.
        | otherwise = computeBits 0 (abs n)

    -- Strip off whole bytes (8 bits each) until only the top byte is
    -- left, then add the bit length of that byte.
    computeBits :: Int -> Integer -> Int
    computeBits !acc i
        | q == 0    = acc + bitsInByte r
        | otherwise = computeBits (acc + 8) q
      where
        (q, r) = i `divMod` 256

    -- Bit length of a value in the range [0, 255].
    bitsInByte :: Integer -> Int
    bitsInByte r
        | r >= 0x80 = 8
        | r >= 0x40 = 7
        | r >= 0x20 = 6
        | r >= 0x10 = 5
        | r >= 0x08 = 4
        | r >= 0x04 = 3
        | r >= 0x02 = 2
        | r >= 0x01 = 1
        -- Unreachable: computeBits only receives non-zero values, so the
        -- top byte is never 0.
        | otherwise = 0

-- | Compute the number of bytes needed to represent an integer.
--
-- The GMP primitive is used when it is available. Otherwise the bit count
-- from 'numBits' is rounded up to a whole number of bytes.
numBytes :: Integer -> Int
numBytes n = onGmpUnsupported (gmpSizeInBytes n) byteCount
  where
    byteCount :: Int
    byteCount = (numBits n + 7) `div` 8

-- | Express an integer as a power of two times an odd number.
--
-- For non-zero @a@, @asPowerOf2AndOdd a@ returns @(e, m)@ with
-- @a == 2^e * m@ and @m@ odd. The sign of @a@ is carried by @m@. Zero has
-- no such decomposition and is mapped to @(0, 0)@.
asPowerOf2AndOdd :: Integer -> (Int, Integer)
asPowerOf2AndOdd a
    | a == 0       = (0, 0)
    | odd a        = (0, a)
    | a < 0        = let r = asPowerOf2AndOdd (abs a)
                     in (fst r, negate (snd r))
    | isPowerOf2 a = (log2 a, 1)
    | otherwise    = stripTwos a 0
  where
    -- A non-zero value is a power of two exactly when clearing its lowest
    -- set bit leaves nothing.
    isPowerOf2 v = v /= 0 && v .&. (v - 1) == 0

    -- Halve @v@ while it is even, counting the halvings in @e@.
    stripTwos v e
        | testBit v 0 = (e, v)
        | otherwise   = stripTwos (v `shiftR` 1) (e + 1)