(** * Module FastModularInverse *)

(** Computation of the inverse of an odd integer modulo [2 ^ (2 ^ n)], using a fast method whose modulus is squared at every step. *)

Require Import Coqlib.


(** [Pos.succ (Pos.pred_double x)] equals [2 * x] for every positive [x].
    Used by [div2_2x_1] in the case of a negative even argument. *)
Lemma succ_pred_double : forall x,
    Pos.succ (Pos.pred_double x) = Pos.mul 2 x.
Proof.
  intros x.
  (* Structural induction on the binary digits of x; the cases come in
     constructor order x = p~1, x = p~0, x = 1. *)
  induction x.
  - reflexivity.
  - (* x = p~0: the only case that needs the induction hypothesis. *)
    simpl. rewrite IHx. reflexivity.
  - reflexivity.
Qed.

(** [Z.div2] undoes the map [fun x => 2*x + 1] on all of [Z],
    negative [x] included. *)
Lemma div2_2x_1 : forall x,
    Z.div2 (2*x + 1) = x.
Proof.
  intros x.
  destruct x as [| p | p].
  - (* x = 0 *)
    rewrite <- Zmult_0_r_reverse. reflexivity.
  - (* x = Zpos p: 2 * Zpos p + 1 is the literal Zpos p~1, on which
       Z.div2 computes directly. *)
    transitivity (Z.div2 (Z.pos p~1)). f_equal. reflexivity.
  - (* x = Zneg p: split on the last binary digit of p. *)
    destruct p as [p' | p' |].
    + (* p = p'~1: 2 * Zneg p'~1 + 1 is the literal Zneg p'~0~1. *)
      transitivity (Z.div2 (Z.neg p'~0~1)). f_equal. reflexivity.
    + (* p = p'~0: write the argument as a successor, simplify, and
         close the remaining positive fact with succ_pred_double. *)
      transitivity (Z.div2 (Z.succ (Z.neg p'~0~0))). f_equal. simpl.
      rewrite succ_pred_double. auto.
    + (* p = 1, i.e. x = -1 *)
      auto.
Qed.

(** Expansion of the square of a sum, with the cross term written as
    [a*b*2]; [fast_bezout_spec] folds this shape back into a square with
    [rewrite <- a_plus_b_squared].
    NOTE(review): [omega] was deprecated in Coq 8.12 and removed in
    Coq 8.14; this proof assumes an older toolchain. *)
Lemma a_plus_b_squared : forall a b,
    (a + b) * (a + b) = a*a + a*b*2 + b*b.
Proof.
  intros.
  (* Distribute both products out fully. *)
  rewrite Z.mul_add_distr_r.
  rewrite !Z.mul_add_distr_l.
  (* Make the two cross terms syntactically equal so that omega can
     combine them. *)
  replace (b*a) with (a*b) by apply Z.mul_comm.
  omega.
Qed.


(** ** Fast Bézout algorithm and specification *)

(** [fast_bezout a n] returns a pair [(u, v)] of Bezout coefficients
    for [a] and the modulus [2 ^ (2 ^ n)]: when [a] is odd,
    [a * u + 2 ^ (2 ^ n) * v = 1] (see [fast_bezout_spec]).
    Each recursive step squares the identity obtained at the previous
    step, so the modulus is squared as well. *)
Fixpoint fast_bezout (a: Z) (n: nat) : Z * Z :=
  match n with
  | O =>
      (* a * 1 + 2 * (- Z.div2 a) = 1 when a is odd. *)
      (1, -Z.div2 a)
  | S k =>
      (* Square the identity a*u + m*v = 1 for the modulus m = 2^(2^k):
         a*(a*u*u + u*m*v*2) + (m*m)*(v*v) = 1, and m*m = 2^(2^(k+1)). *)
      match fast_bezout a k with
      | (u, v) => (a*u*u + u*(2 ^ (2 ^ (Z.of_nat k)))*v*2, v*v)
      end
  end.


(** Correctness of [fast_bezout]: for odd [a], the pair [(u, v)] it
    returns satisfies the Bezout identity [a * u + 2 ^ (2 ^ n) * v = 1].
    NOTE(review): [omega] was deprecated in Coq 8.12 and removed in
    Coq 8.14; this proof assumes an older toolchain. *)
Theorem fast_bezout_spec : forall (a: Z) (n: nat),
    Z.odd a = true ->
    let (u, v) := (fast_bezout a n) in
      a * u + (2 ^ (2 ^ Z.of_nat n)) * v = 1.
Proof.
  intros a n.
  intros Hodd.

  (* Induction on n, following the recursion of fast_bezout. *)
  induction n.
  (* Base case: (u, v) = (1, - Z.div2 a) and the modulus is 2 ^ 2 ^ 0 = 2.
     Writing a as 2 * Z.div2 a + 1 (a is odd) makes the goal linear. *)
  - unfold fast_bezout.
    rewrite (Zdiv2_odd_eqn a).
    rewrite Hodd.
    rewrite div2_2x_1.
    replace (2 ^ 2 ^ Z.of_nat 0) with 2 by reflexivity.
    omega.
  (* Inductive step: b names the previous modulus 2 ^ 2 ^ n. *)
  - remember (2 ^ 2 ^ Z.of_nat n) as b.
    (* The new modulus is the square of the previous one. *)
    assert(Hbb: 2 ^ 2 ^ Z.of_nat (S n) = b * b).
    { rewrite Heqb.
      (* 2 ^ x * 2 ^ x = 2 ^ (x + x), then x + x = 2 ^ n * 2 = 2 ^ (n + 1). *)
      rewrite <- Z.pow_add_r.
      - f_equal.
        rewrite Zplus_diag_eq_mult_2.
        transitivity (2 ^ Z.of_nat n * 2 ^ 1).
        + rewrite <- Z.pow_add_r; try omega.
          f_equal. simpl.
          rewrite Zpos_P_of_succ_nat. auto.
        + auto.
      (* Side conditions of Z.pow_add_r: both exponents are nonnegative. *)
      - apply Z.pow_nonneg. omega.
      - apply Z.pow_nonneg. omega.
    }
    rewrite Hbb.

    (* Unfold one step of fast_bezout and name the previous pair (u, v). *)
    simpl. rewrite <- Heqb.
    destruct (fast_bezout a n) as (u, v).
    (* The goal is the square of a*u + b*v, expanded as in a_plus_b_squared. *)
    transitivity ((a * u) * (a * u) + (a * u) * (b * v) * 2 + (b * v) * (b * v)).
    (* First: the unfolded expression equals that expansion.  H1-H3
       reshape each product so that both sides share the same product
       terms and omega can finish. *)
    + rewrite Z.mul_add_distr_l.
      assert(H1: a * (a * u * u) = a * u * (a * u)).
      { rewrite !Z.mul_assoc. f_equal.
        rewrite <- !Z.mul_assoc. f_equal.
        apply Z.mul_comm.
      }
      assert(H2: a * (u * b * v * 2) = a * u * (b * v) * 2).
      { rewrite !Z.mul_assoc. reflexivity. }
      assert(H3: b * b * (v * v) = b * v * (b * v)).
      { rewrite !Z.mul_assoc. f_equal.
        rewrite <- !Z.mul_assoc. f_equal.
        apply Z.mul_comm.
      }
      omega.
    (* Then: fold the expansion back into a square and apply the
       induction hypothesis a*u + b*v = 1, leaving 1 * 1 = 1. *)
    + rewrite <- a_plus_b_squared.
      rewrite IHn.
      reflexivity.
Qed.





(** ** Fast modular inverse and specification *)

(** [fast_mod_inv a n] is the first Bezout coefficient computed by
    [fast_bezout a n], reduced modulo [2 ^ (2 ^ n)].  The result
    therefore lies between 0 (inclusive) and [2 ^ (2 ^ n)] (exclusive).
    For odd [a] it is the inverse of [a] modulo [2 ^ (2 ^ n)]
    (see [fast_mod_inv_spec]). *)
Definition fast_mod_inv (a: Z) (n: nat) : Z :=
  Z.modulo (fst (fast_bezout a n)) (2 ^ (2 ^ Z.of_nat n)).

(** For odd [a], [fast_mod_inv a n] is a multiplicative inverse of [a]
    modulo [2 ^ (2 ^ n)].
    NOTE(review): [omega] was deprecated in Coq 8.12 and removed in
    Coq 8.14; this proof assumes an older toolchain. *)
Theorem fast_mod_inv_spec : forall (a: Z) (n: nat),
    Z.odd a = true ->
    ((a * fast_mod_inv a n) mod (2 ^ (2 ^ Z.of_nat n))) = 1.
Proof.
  intros a n.
  intros H.
  unfold fast_mod_inv.
  (* Name the pair returned by fast_bezout; H2 records the equation. *)
  destruct (fast_bezout a n) as (u, v) eqn:H2.
  (* Replace the oddness hypothesis H by the Bezout identity of
     fast_bezout_spec. *)
  apply fast_bezout_spec with a n in H.
  remember (2 ^ 2 ^ Z.of_nat n) as b.

  (* The modulus exceeds 1: this discharges the nonzero side conditions
     of the mod lemmas and gives 1 mod b = 1 at the end. *)
  assert(Hb: 1 < b).
  { rewrite Heqb.
    apply Z.pow_gt_1. omega.
    apply Z.pow_pos_nonneg; omega.
  }

  (* Specialise the Bezout identity to the named pair (u, v). *)
  rewrite H2 in H.
  simpl.
  (* Chain: a * (u mod b) mod b = a * u mod b = (1 - b * v) mod b
            = (1 + - v * b) mod b = 1 mod b = 1 *)
  rewrite Z.mul_mod_idemp_r; try omega.
  replace (a * u) with (1 - b * v) by omega.
  rewrite Z.mul_comm.
  rewrite <- Z.add_opp_r.
  rewrite Zopp_mult_distr_l.
  rewrite Z_mod_plus_full.
  apply Zmod_1_l; omega.
Qed.




(** ** Fast identity insertion and specification *)

(** [fast_identity a b n x] computes [(a*x + b) * c + (- b) * c], where
    [c = fast_mod_inv a n].  For odd [a] this agrees with [x] modulo
    [2 ^ (2 ^ n)] for every [b] (see [fast_identity_spec]); it is an
    expression congruent to the identity, presumably meant to be
    inserted in place of [x] (per the section title). *)
Definition fast_identity (a b: Z) (n: nat) (x: Z) :=
  let inv := fast_mod_inv a n in
  let offset := - b * inv in
  (a * x + b) * inv + offset.

(** For odd [a], [fast_identity a b n x] agrees with [x] modulo
    [2 ^ (2 ^ n)], whatever the value of [b].
    NOTE(review): [omega] was deprecated in Coq 8.12 and removed in
    Coq 8.14; this proof assumes an older toolchain. *)
Theorem fast_identity_spec : forall (a b: Z) (n: nat) (x: Z),
    Z.odd a = true ->
    (fast_identity a b n) x mod (2 ^ (2 ^ Z.of_nat n)) = x mod (2 ^ (2 ^ Z.of_nat n)).
Proof.
  intros.
  unfold fast_identity.

  (* The modulus exceeds 1; in particular it is nonzero, as the mod
     lemmas below require. *)
  assert(Hpos: 1 < 2 ^ 2 ^ Z.of_nat n).
  { apply Z.pow_gt_1. omega.
    apply Z.pow_pos_nonneg; omega.
  }

  (* Cancel the offset: (a*x + b) * c + - b * c becomes a * x * c,
     where c = fast_mod_inv a n. *)
  rewrite Z.mul_add_distr_r.
  rewrite <- Zopp_mult_distr_l.
  rewrite <- Z.add_assoc.
  rewrite Z.add_opp_r.
  rewrite Z.sub_diag.
  rewrite Z.add_0_r.

  (* Regroup as x * (a * c), reduce both factors modulo 2 ^ 2 ^ n, and
     use fast_mod_inv_spec to turn the second factor into 1. *)
  rewrite (Z.mul_comm a x).
  rewrite <- Z.mul_assoc.
  rewrite Z.mul_mod; try omega.
  rewrite fast_mod_inv_spec; auto.
  rewrite Z.mul_1_r.
  apply Z.mod_mod; omega.
Qed.