(** * Qcert.NNRCMR.Optim.NNRCMRRewrite: rewrite and merge optimizations on NNRCMR map/reduce chains *)


Section NRewMR.

  (** Parameters shared by the definitions and lemmas of this section. *)

  Context {fruntime:foreign_runtime}.
  Context {fredop:foreign_reduce_op}.

  (** [h] is passed to every evaluator used below ([nnrc_core_eval],
      [mr_map_eval], [mr_reduce_eval], [mr_eval], [mr_chain_eval]).
      NOTE(review): presumably the brand relation — confirm. *)
  Context (h:list(string*string)).



  Definition is_id_function (f: var × nnrc) :=
    let (x, n) := f in
    match n with
    | NNRCVar yequiv_decb x y
    | NNRCUnop AIdOp (NNRCVar y) ⇒ equiv_decb x y
    | _false
    end.


  Definition is_coll_function (f: var × nnrc) :=
    let (x, n) := f in
    match n with
    | NNRCUnop AColl (NNRCVar y) ⇒ equiv_decb x y
    | _false
    end.


  Definition is_constant_function (f: var × nnrc) :=
    let (x, n) := f in
    match n with
    | NNRCConst _true
    | _false
    end.


  Definition is_flatten_function (f: var × nnrc) :=
    let (x, n) := f in
    match n with
    | NNRCUnop AFlatten (NNRCVar y) ⇒ equiv_decb x y
    | NNRCLet a (NNRCUnop AFlatten (NNRCVar y)) (NNRCVar b) ⇒ equiv_decb x y && equiv_decb a b
    | _false
    end.

  Lemma is_flatten_function_correct (x:var) (n:nnrc) (env:bindings) :
    is_flatten_function (x,n) = true
     d,
      lookup equiv_dec env x = Some d
      (nnrc_core_eval h env n) = lift_oncoll (fun l ⇒ (lift dcoll (rflatten l))) d.


  Definition is_uncoll_function_arg (f: var × nnrc) :=
    let (x, n) := f in
    match n with
    | NNRCLet a
             (NNRCEither (NNRCUnop ASingleton (NNRCVar y))
                        b (NNRCVar b')
                        c (NNRCConst dunit))
             n'
      equiv_decb x y && equiv_decb b b' && equiv_decb a x
    | _false
    end.



  Definition is_id_scalar_map map :=
    match map with
    | MapDist ffalse
    | MapDistFlatten ffalse
    | MapScalar fis_coll_function f
    end.

  Lemma id_scalar_map_correct mr_map d :
    is_id_scalar_map mr_map = true
    mr_map_eval h mr_map (Dlocal d) = Some (d::nil).

  Definition is_id_dist_map map :=
    match map with
    | MapDist fis_id_function f
    | MapDistFlatten fis_coll_function f
    | MapScalar _false
    end.

  Lemma id_dist_map_correct mr_map coll :
    is_id_dist_map mr_map = true
    mr_map_eval h mr_map (Ddistr coll) = Some coll.


  Definition is_dispatch_map map :=
    match map with
    | MapScalar fis_id_function f
    | _false
    end.

  Lemma dispatch_map_correct mr_map coll:
    is_dispatch_map mr_map = true
    mr_map_eval h mr_map (Dlocal (dcoll coll)) = Some coll.


  Definition is_scalar_map map :=
    match map with
    | MapScalar _true
    | _false
    end.


  Definition is_flatten_dist_map map :=
    match map with
    | MapDistFlatten fis_id_function f
    | _false
    end.

  Lemma flatten_dist_map_correct mr_map coll:
    is_flatten_dist_map mr_map = true
    mr_map_eval h mr_map (Ddistr coll) = rflatten coll.




  Definition is_flatten_collect red :=
    match red with
    | RedIdfalse
    | RedCollect reduceis_flatten_function reduce
    | RedOp opfalse
    | RedSingletonfalse
    end.


  Definition is_id_reduce red :=
    match red with
    | RedIdtrue
    | RedCollect reducefalse
    | RedOp opfalse
    | RedSingletonfalse
    end.

  Lemma id_reduce_correct red coll:
    is_id_reduce red = true
    mr_reduce_eval h red coll = Some (Ddistr coll).


  Definition is_id_collect red :=
    match red with
    | RedIdfalse
    | RedCollect reduceis_id_function reduce
    | RedOp opfalse
    | RedSingletonfalse
    end.

  Lemma id_collect_correct red coll:
    is_id_collect red = true
    mr_reduce_eval h red coll = Some (Dlocal (dcoll coll)).

  Definition is_singleton_reduce red :=
    match red with
    | RedIdfalse
    | RedCollect _false
    | RedOp _false
    | RedSingletontrue
    end.

  Lemma singleton_reduce_correct red d:
    is_singleton_reduce red = true
    mr_reduce_eval h red (d::nil) = Some (Dlocal d).

  Definition is_uncoll_collect red :=
    match red with
    | RedIdfalse
    | RedCollect reduceis_uncoll_function_arg reduce
    | RedOp opfalse
    | RedSingletonfalse
    end.

  Definition suppress_uncoll_in_collect_reduce red :=
    match red with
    | RedIdNone
    | RedCollect f
      if is_uncoll_function_arg f then
        let (x, n) := f in
        match n with
        | NNRCLet a _ n'Some (RedCollect (a, n'))
        | _None
        end
      else
        None
    | RedOp opNone
    | RedSingletonNone
    end.

  Lemma suppress_uncoll_in_collect_reduce_correct reduce coll:
     reduce',
      reduce_well_formed reduce
      is_uncoll_collect reduce = true
      suppress_uncoll_in_collect_reduce reduce = Some reduce'
      mr_reduce_eval h reduce (dcoll coll :: nil) =
      mr_reduce_eval h reduce' coll.



  Definition is_id_dist_mr mr :=
    match mr.(mr_reduce) with
    | RedIdis_id_dist_map mr.(mr_map)
    | RedCollect reducefalse
    | RedOp opfalse
    | RedSingletonfalse
    end.

  Lemma is_id_dist_mr_correct (mr:mr) :
    is_id_dist_mr mr = true
     loc_d,
      map_well_localized mr.(mr_map) loc_d
      match loc_d with
      | Ddistr coll
        mr_eval h mr loc_d = Some (Ddistr coll)
      | Dlocal d
        mr_eval h mr loc_d = Some (Ddistr (d::nil))
      end.

  Definition is_id_scalar_mr mr :=
    match mr.(mr_reduce) with
    | RedIdfalse
    | RedCollect reducefalse
    | RedOp opfalse
    | RedSingletonis_id_scalar_map mr.(mr_map)
    end.

  Lemma is_id_scalar_mr_correct (mr:mr) :
    is_id_scalar_mr mr = true
     loc_d,
      map_well_localized mr.(mr_map) loc_d
      match loc_d with
      | Ddistr collFalse
      | Dlocal d
        mr_eval h mr loc_d = Some (Dlocal d)
      end.


  Definition is_kindofflatten_mr mr :=
      is_id_dist_map mr.(mr_map) && is_flatten_collect mr.(mr_reduce).

  Lemma is_kindofflatten_mr_correct (mr:mr) (loc_d: ddata) :
    is_kindofflatten_mr mr = true
    map_well_localized mr.(mr_map) loc_d
    match loc_d with
    | Ddistr coll
      mr_eval h mr loc_d = lift (fun lDlocal (dcoll l)) (rflatten coll)
    | Dlocal d
      mr_eval h mr loc_d = lift (fun lDlocal (dcoll l)) (rflatten (d::nil))
    end.


  Definition is_collect_mr mr :=
    is_id_dist_map mr.(mr_map) && is_id_collect mr.(mr_reduce).

  Lemma mr_collect_collects (mr:mr) (loc_d:ddata) :
    is_collect_mr mr = true
    map_well_localized mr.(mr_map) loc_d
    match loc_d with
    | Ddistr coll
      mr_eval h mr loc_d = Some (Dlocal (dcoll coll))
    | Dlocal d
      mr_eval h mr loc_d = Some (Dlocal (dcoll (d::nil)))
    end.


  Definition is_dispatch_mr mr :=
    match mr.(mr_reduce) with
    | RedId
      is_dispatch_map mr.(mr_map)
    | RedCollect reducefalse
    | RedOp _false
    | RedSingletonfalse
    end.

  Lemma mr_dispatch_correct (mr:mr) (coll:list data) :
    is_dispatch_mr mr = true
    mr_eval h mr (Dlocal (dcoll coll)) = Some (Ddistr coll).


  Definition map_collect_flatten_to_map_flatten_collect mr :=
    match mr.(mr_map) with
    | MapDist f
      if is_flatten_collect mr.(mr_reduce) then
        let mr' :=
            mkMR
              mr.(mr_input)
              mr.(mr_output)
              (MapDistFlatten f)
              (RedCollect id_function)
        in
        Some (mr'::nil)
      else
        None
    | _None
    end.

  Lemma map_collect_flatten_to_map_flatten_collect_correct mr:
     mr_chain,
      map_collect_flatten_to_map_flatten_collect mr = Some mr_chain
       env,
      mr_chain_eval h env (mr::nil) = mr_chain_eval h env mr_chain.



  Definition merge_correct (mf:mr mr option mr) (m1 m2: mr) :=
     (m3:mr),
      m1.(mr_output) m1.(mr_input)
      m2.(mr_output) m2.(mr_input)
      m2.(mr_output) m1.(mr_input)
      mf m1 m2 = Some m3
       (loc_d: ddata),
        map_well_localized m1.(mr_map) loc_d
        get_mr_chain_result
          (mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m1::m2::nil)) =
        get_mr_chain_result
          (mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m3::nil)).

  Definition merge_correct_weak (mf:mr mr option mr) :=
     (m1 m2 m3:mr),
      m1.(mr_output) m1.(mr_input)
      m2.(mr_output) m2.(mr_input)
      m2.(mr_output) m1.(mr_input)
      mf m1 m2 = Some m3
       (loc_d: ddata),
         (result: data),
          map_well_localized m1.(mr_map) loc_d
          get_mr_chain_result
            (mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m1::m2::nil)) = Some result
          get_mr_chain_result
            (mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m3::nil)) = Some result.


  Definition merge_collect_dispatch mr1 mr2 :=
    if (equiv_decb mr1.(mr_output) mr2.(mr_input))
         && is_id_collect mr1.(mr_reduce)
         && is_dispatch_map mr2.(mr_map) then
      let mr :=
          mkMR
            mr1.(mr_input)
            mr2.(mr_output)
            mr1.(mr_map)
            mr2.(mr_reduce)
      in
      Some mr
    else
      None.

  Lemma merge_collect_dispatch_correct mr1 mr2:
    merge_correct merge_collect_dispatch mr1 mr2.


  Definition merge_mr_id_dist_l mr1 mr2 :=
    if equiv_decb mr1.(mr_output) mr2.(mr_input) && is_id_dist_mr mr1 then
      let mr :=
          mkMR
            mr1.(mr_input)
            mr2.(mr_output)
            mr2.(mr_map)
            mr2.(mr_reduce)
      in
      Some mr
    else
      None.

  Lemma merge_mr_id_dist_l_correct mr1 mr2:
    merge_correct merge_mr_id_dist_l mr1 mr2.

  Definition merge_mr_id_scalar_l mr1 mr2 :=
    if equiv_decb mr1.(mr_output) mr2.(mr_input) && is_id_scalar_mr mr1 then
      let mr :=
          mkMR
            mr1.(mr_input)
            mr2.(mr_output)
            mr2.(mr_map)
            mr2.(mr_reduce)
      in
      Some mr
    else
      None.

  Lemma merge_mr_id_scalar_l_correct mr1 mr2:
    merge_correct merge_mr_id_scalar_l mr1 mr2.






  Definition merge_id_reduce_id_dist_map mr1 mr2 :=
    if equiv_decb mr1.(mr_output) mr2.(mr_input)
       && is_id_reduce mr1.(mr_reduce)
       && is_id_dist_map mr2.(mr_map) then
      let mr :=
          mkMR
            mr1.(mr_input)
            mr2.(mr_output)
            mr1.(mr_map)
            mr2.(mr_reduce)
      in
      Some mr
    else
      None.

  Lemma merge_id_reduce_id_dist_map_correct mr1 mr2:
    merge_correct merge_id_reduce_id_dist_map mr1 mr2.


  Definition merge_singleton_reduce_id_scalar_map mr1 mr2 :=
    if equiv_decb mr1.(mr_output) mr2.(mr_input)
       && is_singleton_reduce mr1.(mr_reduce)
       && is_id_scalar_map mr2.(mr_map) then
      let mr :=
          mkMR
            mr1.(mr_input)
            mr2.(mr_output)
            mr1.(mr_map)
            mr2.(mr_reduce)
      in
      Some mr
    else
      None.

  Lemma merge_singleton_reduce_id_scalar_map_correct mr1 mr2:
    ( loc_d, d, mr_map_eval h mr1.(mr_map) loc_d = Some (d::nil))
    merge_correct merge_singleton_reduce_id_scalar_map mr1 mr2.


  Definition merge_reduce_id_flatten_map (mr1 mr2:mr) :=
    match mr1.(mr_map) with
    | MapDist map1
      if equiv_decb mr1.(mr_output) mr2.(mr_input)
          && is_id_reduce mr1.(mr_reduce)
          && is_flatten_dist_map mr2.(mr_map) then
       let mr :=
           mkMR
             mr1.(mr_input)
             mr2.(mr_output)
             (MapDistFlatten map1)
             mr2.(mr_reduce)
       in
       Some mr
      else
        None
    | _None
    end.

  Lemma merge_reduce_id_flatten_map_correct mr1 mr2:
    merge_correct merge_reduce_id_flatten_map mr1 mr2.


  Definition merge_scalar_singleton_scalar (mr1 mr2: mr) :=
    if equiv_decb mr1.(mr_output) mr2.(mr_input)
       && is_singleton_reduce mr1.(mr_reduce) then
      match mr1.(mr_map), mr2.(mr_map) with
      | MapScalar (x1, NNRCUnop AColl n1), MapScalar (x2, n2)
        let map :=
            MapScalar (x1, NNRCLet x2 n1 n2)
        in
        let mr :=
            mkMR
              mr1.(mr_input)
              mr2.(mr_output)
              map
              mr2.(mr_reduce)
        in
        Some mr
      | _, _None
      end
    else
      None.

  Lemma merge_scalar_singleton_scalar_correct mr1 mr2:
    mr_well_formed mr2
    merge_correct merge_scalar_singleton_scalar mr1 mr2.










  Definition merge_mr_last mr (last: ((list var × nnrc) × list (var × dlocalization)) ) :=
    let '((params, n), args) := last in
    match (params, args) with
    | (x::nil, (output, Vscalar)::nil)
      if equiv_decb output mr.(mr_output) && is_singleton_reduce mr.(mr_reduce) then
        match mr.(mr_map) with
        | MapScalar (x1, NNRCUnop AColl n1)
          Some ((mr.(mr_input)::nil, NNRCLet x
                                            (NNRCLet x1 (NNRCVar mr.(mr_input)) n1)
                                            n),
                (mr.(mr_input), Vscalar)::nil)
        | _None
        end
      else
        None
    | (_, _)None
    end.

  Definition merge_last mrl :=
    let '(chain, output, last) :=
        List.fold_right
          (fun mr chain_output_last
             match chain_output_last with
             | (nil, None, last)
               match merge_mr_last mr last with
               | Some last'(nil, None, last')
               | None(mr::nil, Some (mr_output mr), last)
               end
             | (acc, output, last)(mr::acc, output, last)
             end)
          (nil, None, mrl.(mr_last)) mrl.(mr_chain)
    in
    match output with
    | NoneNone
    | Some outputSome (mkMRChain mrl.(mr_inputs_loc) chain last)
    end.


  Definition mr_chain_apply_rewrite (rew: mr option (list mr)) l :=
    List.flat_map
      (fun mr
         match rew mr with
         | Nonemr::nil
         | Some mr_chainmr_chain
         end)
      l.

  Definition apply_rewrite (rew: mr option (list mr)) mrl :=
    mkMRChain
      mrl.(mr_inputs_loc)
      (mr_chain_apply_rewrite rew mrl.(mr_chain))
      mrl.(mr_last).

  Definition mr_chain_apply_merge (merge: mr mr option mr) l :=
    let output_vars : list var := List.fold_left (fun acc mrmr.(mr_output) :: acc) l nil in
    List.fold_right
      (fun mr1 acc
         if leb (mult output_vars mr1.(mr_output)) 1 then
           let l_optimized :=
               List.fold_right
                 (fun mr2 acc
                    match merge mr1 mr2 with
                    | Nonemr2 :: acc
                    | Some mr12mr12 :: acc
                    end)
                 nil acc
           in
           mr1 :: l_optimized
         else
           mr1 :: acc)
      nil l.

  Definition apply_merge (merge: mr mr option mr) mrl :=
    mkMRChain
      mrl.(mr_inputs_loc)
      (mr_chain_apply_merge merge mrl.(mr_chain))
      mrl.(mr_last).

  Fixpoint mr_chain_cleanup l (to_keep: list var) :=
    let (to_keep', res) :=
        List.fold_right
          (fun r (acc: list var × list mr) ⇒
             let (to_keep, res) := acc in
             if in_dec equiv_dec r.(mr_output) to_keep then
               (r.(mr_input)::to_keep, r::res)
             else
               (to_keep, res))
          (to_keep, nil) l
    in
    res.

  Definition mr_cleanup mrl to_keep :=
    mkMRChain
      mrl.(mr_inputs_loc)
      (mr_chain_cleanup mrl.(mr_chain) to_keep)
      mrl.(mr_last).


  Definition mr_optimize_step (l: nnrcmr): nnrcmr :=
    let to_keep := List.map fst (snd l.(mr_last)) in
    let l := apply_rewrite map_collect_flatten_to_map_flatten_collect l in
    let l := apply_merge merge_id_reduce_id_dist_map l in
    let l := mr_cleanup l to_keep in
    let l := apply_merge merge_singleton_reduce_id_scalar_map l in
    let l := mr_cleanup l to_keep in
    let l := apply_merge merge_reduce_id_flatten_map l in
    let l := mr_cleanup l to_keep in
    
    
    
    
    let l := apply_merge merge_collect_dispatch l in
    let l := mr_cleanup l to_keep in
    let l := apply_merge merge_mr_id_dist_l l in
    let l := mr_cleanup l to_keep in
    let l := apply_merge merge_mr_id_scalar_l l in
    let l := mr_cleanup l to_keep in
    
    
    let l := apply_merge merge_scalar_singleton_scalar l in
    let l := mr_cleanup l to_keep in
    let l :=
        match merge_last l with
        | Nonel
        | Some ll
        end
    in
    l.

  Fixpoint mr_optimize_loop n (l: nnrcmr) :=
    match n with
    | 0 ⇒ l
    | S n
      let l := mr_optimize_step l in
      mr_optimize_loop n l
    end.

  Definition mr_optimize (l: nnrcmr) :=
    mr_optimize_loop 10 l.

  Definition fresh_mr_var (prefix: string) (vars: list var) :=
    let x := fresh_var prefix vars in
    (x, x::vars).

  Definition get_mr_chain_vars mr_chain :=
    List.fold_left
      (fun acc mrmr.(mr_input) :: mr.(mr_output) :: acc)
      mr_chain nil.

  Definition get_nnrcmr_vars mrl :=
    get_mr_chain_vars mrl.(mr_chain).

(** Closing the section abstracts each definition above over the [Context]
    parameters ([fruntime], [fredop], [h]) that it actually uses. *)
End NRewMR.