/*
 * erf.cpp: Error function and its relatives.
 */

#include <stdio.h>
#include <stdlib.h>

#include "spigot.h"
#include "funcs.h"
#include "cr.h"

class ErfBaseRational : public Source {
    /*
     * This class computes the simplest erf-like function: just the
     * integral of exp(-x^2) from 0 to x, without any scaling or factors
     * of sqrt(pi) anywhere. We put those on later.
     *
     * We compute this for rationals by translating its obvious
     * power series into a spigot description. This converges like
     * an absolute dog for large |x|, but I don't know of any more
     * efficient way to compute erf at those values. (There isn't
     * any sensible range reduction method, for instance.)
     *
     * Let x = n/d. Then we have
     *
     *                  n      n^3        n^5        n^7
     *   erfbase(n/d) = - - -------- + -------- - -------- + ...
     *                  d   1! 3 d^3   2! 5 d^5   3! 7 d^7
     *
     *              n   1     n^2   1     n^2   1     n^2
     *            = - ( - - ----- ( - - ----- ( - - ----- ( ... ) ) )
     *              d   1   1.d^2   3   2.d^2   5   3.d^2
     *
     * so our matrices go
     *
     *   ( n 0 ) ( -1.n^2   1.d^2 ) ( -3.n^2   2.d^2 ) ( -5.n^2   3.d^2 ) ...
     *   ( 0 d ) (    0   1.1.d^2 ) (    0   2.3.d^2 ) (    0   3.5.d^2 )
     *
     * i.e. the k-th regular matrix (k = 1, 2, ...) is
     *
     *   ( -(2k-1).n^2        k.d^2 )
     *   (      0      k.(2k-1).d^2 )
     *
     * When integrating exp(-f x^2) instead (see the constructor), with
     * f = fn/fd, every n^2 above becomes fn.n^2 and every d^2 becomes
     * fd.d^2.
     */

    // n/d: the rational argument x.  fn/fd: the rational factor f.
    // n2 = fn*n^2 and d2 = fd*d^2: the per-matrix numerator/denominator
    // pieces, computed once in gen_matrix.
    // k: 1-based index of the current regular matrix; kodd = 2k-1.
    bigint n, d, fn, fd, n2, d2, k, kodd;
    // Resume point for the cr.h coroutine macros used by gen_matrix.
    int crState;

  public:
    /*
     * For convenience of Phi, we actually compute the integral of
     * exp(-f x^2) for a rational factor f = afn/afd, rather than just
     * exp(-x^2). (Multiplying a nice rational such as 1/2 into each
     * matrix here is much faster than scaling the input by _root_ 2.)
     *
     * an/ad is the argument x. crState starts at -1, presumably the
     * cr.h "not yet started" value, so the first gen_matrix call begins
     * at crBegin.
     */
    ErfBaseRational(const bigint &an, const bigint &ad,
                    const bigint &afn, const bigint &afd)
        : n(an), d(ad), fn(afn), fd(afd)
    {
        crState = -1;
    }

    // A fresh copy with the same x and f. Because it goes through the
    // constructor, the copy's matrix stream restarts from the beginning.
    virtual ErfBaseRational *clone()
    {
        return new ErfBaseRational(n, d, fn, fd);
    }

    /*
     * Reports a fixed interval [0, 2] to the Source machinery.
     * NOTE(review): presumably this bounds the value of the not-yet-
     * consumed tail of the matrix product. By the author's own
     * admission (below) it is empirical, not derived.
     */
    bool gen_interval(bigint *low, bigint *high)
    {
        /* I totally made these numbers up, but they seem to work. Ahem. */
        *low = 0;
        *high = 2;
        return true;
    }

    /*
     * Emits one 2x2 matrix per call into matrix[0..3], laid out as
     * ( matrix[0] matrix[1] ; matrix[2] matrix[3] ). The first call
     * yields diag(n, d). Subsequent calls yield the k-th series matrix
     * for k = 1, 2, ... forever. crReturn records where to resume on
     * the next call. Every call returns false; the meaning of that flag
     * is defined by Source (NOTE(review): confirm).
     */
    bool gen_matrix(bigint *matrix)
    {
        crBegin;

        /*
         * The initial anomalous matrix.
         */
        matrix[1] = matrix[2] = 0;
        matrix[0] = n;
        matrix[3] = d;
        crReturn(false);

        /*
         * Then the regular series.
         */
        k = 1;
        kodd = 1;
        n2 = fn*n*n;            // f * n^2, numerator part
        d2 = fd*d*d;            // f * d^2, denominator part
        while (1) {
            matrix[0] = -kodd*n2;
            matrix[1] = k*d2;
            matrix[2] = 0;
            matrix[3] = k*kodd*d2;
            crReturn(false);
            ++k;
            kodd += 2;          // keep kodd == 2k-1
        }

        crEnd;
    }
};

struct ErfBaseConstructor : MonotoneConstructor {
    /*
     * Factory handed to spigot_monotone / spigot_monotone_invert: for a
     * rational n/d it builds the bare integral of exp(-x^2) from 0 to
     * n/d (scale factor f = 1/1), with no sqrt(pi) normalisation.
     */
    MonotoneConstructor *clone()
    {
        // Stateless, so a copy is just a new instance.
        return new ErfBaseConstructor();
    }

    Spigot *construct(const bigint &n, const bigint &d)
    {
        const bigint factor_num = 1, factor_den = 1;
        return new ErfBaseRational(n, d, factor_num, factor_den);
    }
};
struct PhiBaseConstructor : MonotoneConstructor {
    /*
     * Factory handed to spigot_monotone: for a rational n/d it builds
     * the integral of exp(-x^2/2) from 0 to n/d (scale factor f = 1/2),
     * which is what the normal CDF Phi needs before normalisation.
     */
    MonotoneConstructor *clone()
    {
        // Stateless, so a copy is just a new instance.
        return new PhiBaseConstructor();
    }

    Spigot *construct(const bigint &n, const bigint &d)
    {
        const bigint factor_num = 1, factor_den = 2;
        return new ErfBaseRational(n, d, factor_num, factor_den);
    }
};

/*
 * Error function: erf(a) = 2/sqrt(pi) * integral_0^a exp(-t^2) dt.
 *
 * ErfBaseConstructor supplies the bare integral. Dividing it by
 * sqrt(pi/4) = sqrt(pi)/2 applies the 2/sqrt(pi) normalisation.
 *
 * Ownership of 'a' passes to this function. The non-zero path hands it
 * to spigot_monotone, and spigot_erfinv passes freshly built spigots
 * here without keeping them. The exact-zero shortcut therefore has to
 * free 'a' itself, or it leaks.
 */
Spigot *spigot_erf(Spigot *a)
{
    bigint n, d;
    if (a->is_rational(&n, &d) && n == 0) {
        // erf(0) is exactly 0: no need to build the series at all.
        delete a;
        return spigot_integer(0);
    }
    return spigot_div(spigot_monotone(new ErfBaseConstructor, a),
                      spigot_sqrt(spigot_rational_mul(spigot_pi(), 1, 4)));
}

/*
 * Complementary error function: erfc(a) = 1 - erf(a).
 *
 * Ownership of 'a' passes to this function. On the exact-zero shortcut
 * it is freed here; otherwise it is handed on to spigot_erf.
 *
 * NOTE: computed literally as 1 - erf(a), so for large positive a the
 * leading digits cancel and many digits of erf are needed before
 * erfc's own digits appear.
 */
Spigot *spigot_erfc(Spigot *a)
{
    bigint n, d;
    if (a->is_rational(&n, &d) && n == 0) {
        // erfc(0) is exactly 1.
        delete a;
        return spigot_integer(1);
    }
    // Reuse spigot_erf rather than repeating its normalisation here.
    return spigot_sub(spigot_integer(1), spigot_erf(a));
}

/*
 * Standard normal CDF:
 *   Phi(a) = 1/2 + 1/sqrt(2 pi) * integral_0^a exp(-t^2/2) dt.
 *
 * PhiBaseConstructor supplies the integral (exp(-x^2/2) integrand),
 * which is then divided by sqrt(2 pi).
 *
 * Ownership of 'a' passes to this function. On the exact-zero shortcut
 * it is freed here; otherwise it is handed to spigot_monotone.
 */
Spigot *spigot_Phi(Spigot *a)
{
    bigint n, d;
    if (a->is_rational(&n, &d) && n == 0) {
        // Phi(0) is exactly 1/2.
        delete a;
        return spigot_rational(1, 2);
    }
    return spigot_add(spigot_rational(1, 2),
                      spigot_div(spigot_monotone(new PhiBaseConstructor, a),
                                 spigot_sqrt(spigot_rational_mul(spigot_pi(),
                                                                 2, 1))));
}

/*
 * Inverse error function: returns a spigot for the x with erf(x) = t.
 *
 * Strategy: bracket the answer between rationals nlo/d and nhi/d (see
 * the derivation below). Then hand the un-normalised erf integral
 * (ErfBaseConstructor) and those bounds to spigot_monotone_invert.
 *
 * NOTE(review): 't' appears to be consumed by this function. It is
 * passed to spigot_neg / spigot_mul / spigot_monotone_invert without
 * being cloned, and only clones of it are used elsewhere, so callers
 * should not use it afterwards. Confirm against the project's ownership
 * convention.
 */
Spigot *spigot_erfinv(Spigot *t)
{
    /*
     * To compute inverse erf using spigot_monotone_invert, we must
     * start by finding an interval of numbers which reliably bracket
     * the right answer. That is, we need a number x with erf(x) < t,
     * and one - ideally not much bigger - with erf(x) > t.
     *
     * Lemma 1: for all x,t, exp(-t^2) <= exp(x^2 - 2xt), with
     * equality only at x=t.
     *
     * Proof: exp is strictly increasing, so this is true iff it's
     * still true with the exps stripped off, i.e. we need to show
     * -t^2 <= x^2 - 2xt, which rearranges to 0 <= (x-t)^2, which is
     * clearly non-negative everywhere, and zero iff x-t = 0. []
     *
     * Lemma 2: for x > 0, erfc(x) < 1/sqrt(pi) exp(-x^2)/x.
     *
     * Proof: erfc(x) = 2/sqrt(pi) int_x^inf exp(-t^2) dt
     *               <= 2/sqrt(pi) int_x^inf exp(x^2-2xt) dt  (by Lemma 1)
     *                = 2/sqrt(pi) exp(x^2) int_x^inf exp(-2xt) dt
     *                = 2/sqrt(pi) exp(x^2) [0 - exp(-2x^2)/-2x]
     *                = 1/sqrt(pi) exp(-x^2)/x.
     *
     * And the <= is easily seen to be an < by further observing that
     * equality in the Lemma 1 inequality only holds at one single
     * point, so we cannot still have equality once we integrate. []
     *
     * Lemma 3: for all x>0, erfc(x) < exp(-x^2).
     *
     * Proof: if x >= 1/sqrt(pi), then 1/(x sqrt(pi)) <= 1, so by Lemma 2,
     *     erfc(x) < 1/(x sqrt(pi)) exp(-x^2) <= exp(-x^2).
     *
     * On the remaining interval [0, 1/sqrt(pi)] you can convince
     * yourself by plotting the two graphs. More rigorously, let
     * g(x) = exp(-x^2) - erfc(x). Then:
     *  - g(0) = 1 - 1 = 0;
     *  - g(1/sqrt(pi)) > 0, by the case just proved, at the endpoint;
     *  - g'' < 0 throughout the interval. exp(-x^2) has second
     *    derivative (4x^2-2) exp(-x^2), which doesn't go positive until
     *    x reaches 1/sqrt(2). erfc has second derivative
     *    4x/sqrt(pi) exp(-x^2), which is positive for x > 0.
     * A concave function that is 0 at the left end of an interval and
     * positive at the right end is positive everywhere in between. []
     *
     * Theorem: for 0 < t < 1, erf(sqrt(-log(1-t))) > t.
     *
     * Proof: let x = sqrt(-log(1-t)). Then we have
     *
     *    erf(sqrt(-log(1-t))) = 1 - erfc(x)
     *                         > 1 - exp(-x^2)  (by Lemma 3)
     *                         = 1 - exp(log(1-t))
     *                         = 1 - (1-t) = t. []
     *
     * Corollary: if we're looking for inverse-erf of some t > 0, then
     * sqrt(-log(1-t)) is an upper bound on the answer. (For t < 0,
     * just flip all the signs, of course.)
     *
     * To find a lower bound, we just do one iteration of
     * Newton-Raphson, starting from the upper bound. Since erf has a
     * negative 2nd derivative (for x>0), its tangent at the upper bound
     * lies above the curve. So the point where that tangent reaches t
     * has erf at most t, i.e. it is on the far side of the root. (If
     * the step overshoots below 0, erf there is negative, which is
     * still < t.)
     */

    /*
     * Start by checking the input number's sign. To avoid an
     * exactness hazard at erfinv(0), we do this only approximately,
     * leaving a small interval around 0 where we aren't sure of the
     * sign. In that interval it's safe to choose very simple upper
     * and lower bounds anyway.
     */
    int sign;
    {
        StaticGenerator test(t->clone());
        bigint approx = test.get_approximate_approximant(64);
        if (approx < -2)
            sign = -1;
        else if (approx > +2)
            sign = 1;
        else
            sign = 0;
    }

    // The final bracket is [nlo/d, nhi/d].
    bigint nlo, nhi, d;

    if (sign == 0) {
        /*
         * The input number is known to lie in [-1/16, +1/16].
         * (get_approximate_approximant returned at least -2 and at
         * most +2, and guarantees to be within 2 of the real answer,
         * so the real answer is in the range [-4/64, +4/64].)
         *
         * erfinv(1/16) is just under 1/18, so +-1/18 will do as our
         * limits.
         */
        nlo = -1;
        nhi = +1;
        d = 18;
    } else {
        bigint dlo, dhi;

        // From here on, work with |t| and restore the sign at the end.
        if (sign < 0)
            t = spigot_neg(t);

        // Find an upper bound.
        //
        // Refine brackets (n1/dhi, n2/dhi) around the theoretical bound
        // sqrt(-log(1-t)) until one end is proven to have erf > t.
        // NOTE(review): parallel_sign_test appears to return +-1 or +-2,
        // naming which of its two arguments had its sign decided first
        // and what that sign was; confirm against its definition.
        Spigot *upperbound;
        {
            BracketingGenerator boundgen
                (spigot_sqrt
                 (spigot_neg(spigot_log(spigot_sub(spigot_integer(1),
                                                   t->clone())))));
            while (1) {
                bigint n1, n2;
                boundgen.get_bracket(&n1, &n2, &dhi);

                // dprint("upper bound: trying (%b, %b) / %b", &n1, &n2, &dhi);

                int s = parallel_sign_test
                    (spigot_sub(spigot_erf(spigot_rational(n1, dhi)),
                                t->clone()),
                     spigot_sub(spigot_erf(spigot_rational(n2, dhi)),
                                t->clone()));

                // dprint("s = %d", s);

                if (s == +1) {
                    nhi = n1;
                    break;
                } else if (s == +2) {
                    nhi = n2;
                    break;
                }
            }

            // Having found a nice rational upper bound, replace the
            // theoretical one with that, for speed and simplicity.
            upperbound = spigot_rational(nhi, dhi);
        }

        // dprint("upper bound (%b/%b)", &nhi, &dhi);

        // Now do a Newton-Raphson iteration to find a lower bound:
        //   x1 = x0 - (erf(x0) - t) / erf'(x0),
        // where erf'(x0) = 2/sqrt(pi) exp(-x0^2) is written below as
        // exp(-x0^2) / (sqrt(pi)/2). Then refine brackets around x1
        // until one end is proven to have erf < t.
        {
            BracketingGenerator boundgen
                (spigot_sub(upperbound->clone(),
                            spigot_div(spigot_sub(spigot_erf(upperbound->clone()),
                                                  t->clone()),
                                       spigot_div(spigot_exp(spigot_neg(spigot_mul(upperbound->clone(),
                                                                                   upperbound->clone()))),
                                                  spigot_rational_mul(spigot_sqrt(spigot_pi()), 1, 2)))));
            while (1) {
                bigint n1, n2;
                boundgen.get_bracket(&n1, &n2, &dlo);

                // dprint("lower bound: trying (%b, %b) / %b", &n1, &n2, &dlo);

                int s = parallel_sign_test
                    (spigot_sub(spigot_erf(spigot_rational(n1, dlo)),
                                t->clone()),
                     spigot_sub(spigot_erf(spigot_rational(n2, dlo)),
                                t->clone()));

                // dprint("s = %d", s);

                if (s == -1) {
                    nlo = n1;
                    break;
                } else if (s == -2) {
                    nlo = n2;
                    break;
                }
            }
        }
        delete upperbound;

        // dprint("bounds (%b/%b, %b/%b)", &nlo, &dlo, &nhi, &dhi);

        // Put both bounds over the common denominator dhi*dlo.
        nlo *= dhi;
        nhi *= dlo;
        d = dhi * dlo;
    }

    // Now go and do the full root-finding step with those bounds.
    // First we scale t by sqrt(pi)/2, to compensate for the fact that
    // ErfBaseConstructor doesn't do that for us.
    // (The 'true' argument presumably marks the function as increasing.)
    t = spigot_mul(t, spigot_sqrt(spigot_rational_mul(spigot_pi(),1,4)));
    Spigot *ret = spigot_monotone_invert(new ErfBaseConstructor, true,
                                         nlo, nhi, d, t);
    if (sign < 0)
        ret = spigot_neg(ret);
    return ret;
}

/*
 * Inverse complementary error function: erfcinv(a) = erfinv(1 - a).
 * 'a' is wrapped in the subtraction, which is then passed on to
 * spigot_erfinv.
 */
Spigot *spigot_erfcinv(Spigot *a)
{
    Spigot *one_minus_a = spigot_sub(spigot_integer(1), a);
    return spigot_erfinv(one_minus_a);
}

/*
 * Inverse standard normal CDF (probit).
 *
 * Since Phi(x) = erfc(-x/sqrt(2)) / 2, solving Phi(x) = a gives
 *   x = -sqrt(2) * erfcinv(2a).
 */
Spigot *spigot_Phiinv(Spigot *a)
{
    Spigot *doubled = spigot_rational_mul(a, 2, 1);
    Spigot *neg_root_two = spigot_neg(spigot_sqrt(spigot_integer(2)));
    return spigot_mul(neg_root_two, spigot_erfcinv(doubled));
}
