(* Qcert.NNRCMR.Optim.NNRCMRRewrite *)
Section NRewMR.
(** Syntactic rewrites and merges over NNRCMR map/reduce chains.
    The section is parameterized by the foreign runtime, the foreign
    reduce operators, and [h] (a list of string pairs, presumably the
    brand relation) which is passed to every evaluator used below. *)
Context {fruntime:foreign_runtime}.
Context {fredop:foreign_reduce_op}.
Context (h:list(string×string)).

(** * Recognizers for simple NNRC functions
    A function is a pair [(x, n)] made of a bound variable and a body.
    Every recognizer is a purely syntactic, conservative check: [false]
    means that the shape was not recognized. *)

(** [is_id_function (x, n)] holds when [n] is the variable [x], possibly
    wrapped in the identity operator [AIdOp]. *)
Definition is_id_function (f: var × nnrc) :=
let (x, n) := f in
match n with
| NNRCVar y ⇒ equiv_decb x y
| NNRCUnop AIdOp (NNRCVar y) ⇒ equiv_decb x y
| _ ⇒ false
end.

(** [is_coll_function (x, n)] holds when [n] wraps the variable [x] in a
    collection with [AColl]. *)
Definition is_coll_function (f: var × nnrc) :=
let (x, n) := f in
match n with
| NNRCUnop AColl (NNRCVar y) ⇒ equiv_decb x y
| _ ⇒ false
end.

(** [is_constant_function (_, n)] holds when [n] is any constant; the
    bound variable is not inspected. *)
Definition is_constant_function (f: var × nnrc) :=
let (_, n) := f in
match n with
| NNRCConst _ ⇒ true
| _ ⇒ false
end.

(** [is_flatten_function (x, n)] holds when [n] flattens the variable [x],
    either directly or through [let a := flatten x in a]. *)
Definition is_flatten_function (f: var × nnrc) :=
let (x, n) := f in
match n with
| NNRCUnop AFlatten (NNRCVar y) ⇒ equiv_decb x y
| NNRCLet a (NNRCUnop AFlatten (NNRCVar y)) (NNRCVar b) ⇒ equiv_decb x y && equiv_decb a b
| _ ⇒ false
end.

(** A recognized flatten function evaluates like [rflatten] on the
    collection bound to its variable. *)
Lemma is_flatten_function_correct (x:var) (n:nnrc) (env:bindings) :
is_flatten_function (x,n) = true →
∀ d,
lookup equiv_dec env x = Some d →
(nnrc_core_eval h env n) = lift_oncoll (fun l ⇒ (lift dcoll (rflatten l))) d.

(** [is_uncoll_function_arg (x, n)] recognizes a body of the form
    [let x := either (singleton x) (b => b) (c => unit) in n'].
    The element extracted by [singleton] is rebound to the same name [x]
    (checked by [equiv_decb a x]) before evaluating [n'].
    NOTE(review): presumably the left branch is taken when [x] holds
    exactly one element; confirm against the [ASingleton] semantics. *)
Definition is_uncoll_function_arg (f: var × nnrc) :=
let (x, n) := f in
match n with
| NNRCLet a
(NNRCEither (NNRCUnop ASingleton (NNRCVar y))
b (NNRCVar b')
c (NNRCConst dunit))
n' ⇒
equiv_decb x y && equiv_decb b b' && equiv_decb a x
| _ ⇒ false
end.

(** * Recognizers for map stages *)

(** [is_id_scalar_map] recognizes a scalar map whose function wraps its
    argument in a collection. By [id_scalar_map_correct], such a map sends
    a local value [d] to the one-element result [d::nil]. *)
Definition is_id_scalar_map map :=
match map with
| MapDist f ⇒ false
| MapDistFlatten f ⇒ false
| MapScalar f ⇒ is_coll_function f
end.
Lemma id_scalar_map_correct mr_map d :
is_id_scalar_map mr_map = true →
mr_map_eval h mr_map (Dlocal d) = Some (d::nil).

(** [is_id_dist_map] recognizes one of two distributed maps that return
    their input unchanged. The first is [MapDist] with an identity
    function. The second is [MapDistFlatten] with a function that wraps
    each element in a collection. See [id_dist_map_correct]. *)
Definition is_id_dist_map map :=
match map with
| MapDist f ⇒ is_id_function f
| MapDistFlatten f ⇒ is_coll_function f
| MapScalar _ ⇒ false
end.
Lemma id_dist_map_correct mr_map coll :
is_id_dist_map mr_map = true →
mr_map_eval h mr_map (Ddistr coll) = Some coll.

(** [is_dispatch_map] recognizes a scalar identity map. By
    [dispatch_map_correct], it turns a local collection into the list of
    its elements. *)
Definition is_dispatch_map map :=
match map with
| MapScalar f ⇒ is_id_function f
| _ ⇒ false
end.
Lemma dispatch_map_correct mr_map coll:
is_dispatch_map mr_map = true →
mr_map_eval h mr_map (Dlocal (dcoll coll)) = Some coll.

(** [is_scalar_map] holds for every [MapScalar] map. *)
Definition is_scalar_map map :=
match map with
| MapScalar _ ⇒ true
| _ ⇒ false
end.

(** [is_flatten_dist_map] recognizes [MapDistFlatten] with an identity
    function, i.e. a map that only flattens its distributed input (see
    [flatten_dist_map_correct]). *)
Definition is_flatten_dist_map map :=
match map with
| MapDistFlatten f ⇒ is_id_function f
| _ ⇒ false
end.
Lemma flatten_dist_map_correct mr_map coll:
is_flatten_dist_map mr_map = true →
mr_map_eval h mr_map (Ddistr coll) = rflatten coll.

(** * Recognizers for reduce stages *)

(** [is_flatten_collect] recognizes a collecting reduce whose function is
    accepted by [is_flatten_function]. *)
Definition is_flatten_collect red :=
match red with
| RedId ⇒ false
| RedCollect reduce ⇒ is_flatten_function reduce
| RedOp op ⇒ false
| RedSingleton ⇒ false
end.

(** [is_id_reduce] recognizes [RedId]. By [id_reduce_correct], the mapped
    values stay distributed. *)
Definition is_id_reduce red :=
match red with
| RedId ⇒ true
| RedCollect reduce ⇒ false
| RedOp op ⇒ false
| RedSingleton ⇒ false
end.
Lemma id_reduce_correct red coll:
is_id_reduce red = true →
mr_reduce_eval h red coll = Some (Ddistr coll).

(** [is_id_collect] recognizes a collecting reduce with an identity
    function. By [id_collect_correct], the mapped values are gathered into
    one local collection. *)
Definition is_id_collect red :=
match red with
| RedId ⇒ false
| RedCollect reduce ⇒ is_id_function reduce
| RedOp op ⇒ false
| RedSingleton ⇒ false
end.
Lemma id_collect_correct red coll:
is_id_collect red = true →
mr_reduce_eval h red coll = Some (Dlocal (dcoll coll)).

(** [is_singleton_reduce] recognizes [RedSingleton]. By
    [singleton_reduce_correct], a single mapped value becomes a local
    value. *)
Definition is_singleton_reduce red :=
match red with
| RedId ⇒ false
| RedCollect _ ⇒ false
| RedOp _ ⇒ false
| RedSingleton ⇒ true
end.
Lemma singleton_reduce_correct red d:
is_singleton_reduce red = true →
mr_reduce_eval h red (d::nil) = Some (Dlocal d).

(** [is_uncoll_collect] recognizes a collecting reduce whose function is
    accepted by [is_uncoll_function_arg]. *)
Definition is_uncoll_collect red :=
match red with
| RedId ⇒ false
| RedCollect reduce ⇒ is_uncoll_function_arg reduce
| RedOp op ⇒ false
| RedSingleton ⇒ false
end.

(** [suppress_uncoll_in_collect_reduce] handles a reduce whose function
    [(x, let a := ... in n')] is accepted by [is_uncoll_function_arg]. It
    drops the element-extraction prefix and returns
    [Some (RedCollect (a, n'))]. For any other reduce it returns [None].
    [suppress_uncoll_in_collect_reduce_correct] relates the two forms when
    the original reduce receives a single collection. *)
Definition suppress_uncoll_in_collect_reduce red :=
match red with
| RedId ⇒ None
| RedCollect f ⇒
if is_uncoll_function_arg f then
let (x, n) := f in
match n with
| NNRCLet a _ n' ⇒ Some (RedCollect (a, n'))
| _ ⇒ None
end
else
None
| RedOp op ⇒ None
| RedSingleton ⇒ None
end.
Lemma suppress_uncoll_in_collect_reduce_correct reduce coll:
∀ reduce',
reduce_well_formed reduce →
is_uncoll_collect reduce = true →
suppress_uncoll_in_collect_reduce reduce = Some reduce' →
mr_reduce_eval h reduce (dcoll coll :: nil) =
mr_reduce_eval h reduce' coll.

(** * Recognizers for whole map/reduce stages *)

(** [is_id_dist_mr] recognizes a stage made of an identity distributed map
    ([is_id_dist_map]) and [RedId]. See [is_id_dist_mr_correct]. *)
Definition is_id_dist_mr mr :=
match mr.(mr_reduce) with
| RedId ⇒ is_id_dist_map mr.(mr_map)
| RedCollect reduce ⇒ false
| RedOp op ⇒ false
| RedSingleton ⇒ false
end.
Lemma is_id_dist_mr_correct (mr:mr) :
is_id_dist_mr mr = true →
∀ loc_d,
map_well_localized mr.(mr_map) loc_d →
match loc_d with
| Ddistr coll ⇒
mr_eval h mr loc_d = Some (Ddistr coll)
| Dlocal d ⇒
mr_eval h mr loc_d = Some (Ddistr (d::nil))
end.

(** [is_id_scalar_mr] recognizes a stage made of an identity scalar map
    ([is_id_scalar_map]) and [RedSingleton]. See
    [is_id_scalar_mr_correct]. *)
Definition is_id_scalar_mr mr :=
match mr.(mr_reduce) with
| RedId ⇒ false
| RedCollect reduce ⇒ false
| RedOp op ⇒ false
| RedSingleton ⇒ is_id_scalar_map mr.(mr_map)
end.
Lemma is_id_scalar_mr_correct (mr:mr) :
is_id_scalar_mr mr = true →
∀ loc_d,
map_well_localized mr.(mr_map) loc_d →
match loc_d with
| Ddistr coll ⇒ False
| Dlocal d ⇒
mr_eval h mr loc_d = Some (Dlocal d)
end.

(** [is_kindofflatten_mr] recognizes an identity distributed map followed
    by a flattening collect. *)
Definition is_kindofflatten_mr mr :=
is_id_dist_map mr.(mr_map) && is_flatten_collect mr.(mr_reduce).
Lemma is_kindofflatten_mr_correct (mr:mr) (loc_d: ddata) :
is_kindofflatten_mr mr = true →
map_well_localized mr.(mr_map) loc_d →
match loc_d with
| Ddistr coll ⇒
mr_eval h mr loc_d = lift (fun l ⇒ Dlocal (dcoll l)) (rflatten coll)
| Dlocal d ⇒
mr_eval h mr loc_d = lift (fun l ⇒ Dlocal (dcoll l)) (rflatten (d::nil))
end.

(** [is_collect_mr] recognizes an identity distributed map followed by an
    identity collect, i.e. a stage that gathers its input locally (see
    [mr_collect_collects]). *)
Definition is_collect_mr mr :=
is_id_dist_map mr.(mr_map) && is_id_collect mr.(mr_reduce).
Lemma mr_collect_collects (mr:mr) (loc_d:ddata) :
is_collect_mr mr = true →
map_well_localized mr.(mr_map) loc_d →
match loc_d with
| Ddistr coll ⇒
mr_eval h mr loc_d = Some (Dlocal (dcoll coll))
| Dlocal d ⇒
mr_eval h mr loc_d = Some (Dlocal (dcoll (d::nil)))
end.

(** [is_dispatch_mr] recognizes a scalar identity map with [RedId], i.e. a
    stage that distributes a local collection (see
    [mr_dispatch_correct]). *)
Definition is_dispatch_mr mr :=
match mr.(mr_reduce) with
| RedId ⇒
is_dispatch_map mr.(mr_map)
| RedCollect reduce ⇒ false
| RedOp _ ⇒ false
| RedSingleton ⇒ false
end.
Lemma mr_dispatch_correct (mr:mr) (coll:list data) :
is_dispatch_mr mr = true →
mr_eval h mr (Dlocal (dcoll coll)) = Some (Ddistr coll).

(** * Single-stage rewrites *)

(** [map_collect_flatten_to_map_flatten_collect] rewrites a stage of the
    form [MapDist f] followed by a flattening collect. The result is the
    one-stage chain [MapDistFlatten f] with [RedCollect id_function],
    keeping the input and output variables. It returns [None] for every
    other stage. *)
Definition map_collect_flatten_to_map_flatten_collect mr :=
match mr.(mr_map) with
| MapDist f ⇒
if is_flatten_collect mr.(mr_reduce) then
let mr' :=
mkMR
mr.(mr_input)
mr.(mr_output)
(MapDistFlatten f)
(RedCollect id_function)
in
Some (mr'::nil)
else
None
| _ ⇒ None
end.
Lemma map_collect_flatten_to_map_flatten_collect_correct mr:
∀ mr_chain,
map_collect_flatten_to_map_flatten_collect mr = Some mr_chain →
∀ env,
mr_chain_eval h env (mr::nil) = mr_chain_eval h env mr_chain.

(** * Merging two consecutive stages
    Every merge below first requires that the output variable of [mr1] is
    the input variable of [mr2]. *)

(** [merge_correct mf m1 m2] is the correctness statement for a merge
    function [mf]. Its hypotheses are the three variable-distinctness
    conditions, [mf m1 m2 = Some m3], and a well-localized input for [m1].
    Under these hypotheses, evaluating the chain [m1; m2] and the single
    stage [m3] from the input of [m1] yields the same chain result. *)
Definition merge_correct (mf:mr → mr → option mr) (m1 m2: mr) :=
∀ (m3:mr),
m1.(mr_output) ≠ m1.(mr_input) →
m2.(mr_output) ≠ m2.(mr_input) →
m2.(mr_output) ≠ m1.(mr_input) →
mf m1 m2 = Some m3 →
∀ (loc_d: ddata),
map_well_localized m1.(mr_map) loc_d →
get_mr_chain_result
(mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m1::m2::nil)) =
get_mr_chain_result
(mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m3::nil)).

(** [merge_correct_weak mf] is a weaker variant quantified over all
    stages: whenever the two-stage chain produces a result, the merged
    stage produces the same result. *)
Definition merge_correct_weak (mf:mr → mr → option mr) :=
∀ (m1 m2 m3:mr),
m1.(mr_output) ≠ m1.(mr_input) →
m2.(mr_output) ≠ m2.(mr_input) →
m2.(mr_output) ≠ m1.(mr_input) →
mf m1 m2 = Some m3 →
∀ (loc_d: ddata),
∀ (result: data),
map_well_localized m1.(mr_map) loc_d →
get_mr_chain_result
(mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m1::m2::nil)) = Some result →
get_mr_chain_result
(mr_chain_eval h ((m1.(mr_input),loc_d)::nil) (m3::nil)) = Some result.

(** [merge_collect_dispatch mr1 mr2] applies when [mr1] ends with an
    identity collect and [mr2] starts with a dispatch map. The pair is
    replaced by the map of [mr1] combined with the reduce of [mr2]. *)
Definition merge_collect_dispatch mr1 mr2 :=
if (equiv_decb mr1.(mr_output) mr2.(mr_input))
&& is_id_collect mr1.(mr_reduce)
&& is_dispatch_map mr2.(mr_map) then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
mr1.(mr_map)
mr2.(mr_reduce)
in
Some mr
else
None.
Lemma merge_collect_dispatch_correct mr1 mr2:
merge_correct merge_collect_dispatch mr1 mr2.

(** [merge_mr_id_dist_l mr1 mr2] drops a leading identity distributed
    stage [mr1]. The merged stage reads the input of [mr1] and uses the map
    and reduce of [mr2]. *)
Definition merge_mr_id_dist_l mr1 mr2 :=
if equiv_decb mr1.(mr_output) mr2.(mr_input) && is_id_dist_mr mr1 then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
mr2.(mr_map)
mr2.(mr_reduce)
in
Some mr
else
None.
Lemma merge_mr_id_dist_l_correct mr1 mr2:
merge_correct merge_mr_id_dist_l mr1 mr2.

(** [merge_mr_id_scalar_l mr1 mr2] does the same as [merge_mr_id_dist_l]
    for a leading identity scalar stage [mr1]. *)
Definition merge_mr_id_scalar_l mr1 mr2 :=
if equiv_decb mr1.(mr_output) mr2.(mr_input) && is_id_scalar_mr mr1 then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
mr2.(mr_map)
mr2.(mr_reduce)
in
Some mr
else
None.
Lemma merge_mr_id_scalar_l_correct mr1 mr2:
merge_correct merge_mr_id_scalar_l mr1 mr2.

(** [merge_id_reduce_id_dist_map mr1 mr2] applies when [mr1] has [RedId]
    and [mr2] has an identity distributed map. It keeps the map of [mr1]
    and the reduce of [mr2]. *)
Definition merge_id_reduce_id_dist_map mr1 mr2 :=
if equiv_decb mr1.(mr_output) mr2.(mr_input)
&& is_id_reduce mr1.(mr_reduce)
&& is_id_dist_map mr2.(mr_map) then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
mr1.(mr_map)
mr2.(mr_reduce)
in
Some mr
else
None.
Lemma merge_id_reduce_id_dist_map_correct mr1 mr2:
merge_correct merge_id_reduce_id_dist_map mr1 mr2.

(** [merge_singleton_reduce_id_scalar_map mr1 mr2] applies when [mr1] ends
    with [RedSingleton] and [mr2] starts with an identity scalar map. It
    keeps the map of [mr1] and the reduce of [mr2]. The correctness lemma
    needs the extra hypothesis that the map of [mr1] always yields exactly
    one value. *)
Definition merge_singleton_reduce_id_scalar_map mr1 mr2 :=
if equiv_decb mr1.(mr_output) mr2.(mr_input)
&& is_singleton_reduce mr1.(mr_reduce)
&& is_id_scalar_map mr2.(mr_map) then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
mr1.(mr_map)
mr2.(mr_reduce)
in
Some mr
else
None.
Lemma merge_singleton_reduce_id_scalar_map_correct mr1 mr2:
(∀ loc_d, ∃ d, mr_map_eval h mr1.(mr_map) loc_d = Some (d::nil)) →
merge_correct merge_singleton_reduce_id_scalar_map mr1 mr2.

(** [merge_reduce_id_flatten_map mr1 mr2] applies when [mr1] is
    [MapDist map1] with [RedId] and [mr2] only flattens its input. The pair
    becomes [MapDistFlatten map1] with the reduce of [mr2]. *)
Definition merge_reduce_id_flatten_map (mr1 mr2:mr) :=
match mr1.(mr_map) with
| MapDist map1 ⇒
if equiv_decb mr1.(mr_output) mr2.(mr_input)
&& is_id_reduce mr1.(mr_reduce)
&& is_flatten_dist_map mr2.(mr_map) then
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
(MapDistFlatten map1)
mr2.(mr_reduce)
in
Some mr
else
None
| _ ⇒ None
end.
Lemma merge_reduce_id_flatten_map_correct mr1 mr2:
merge_correct merge_reduce_id_flatten_map mr1 mr2.

(** [merge_scalar_singleton_scalar mr1 mr2] fuses two scalar maps that are
    separated by the [RedSingleton] of [mr1]. The body [coll n1] of [mr1]
    is unwrapped and bound to the parameter of [mr2], giving
    [MapScalar (x1, let x2 := n1 in n2)]. *)
Definition merge_scalar_singleton_scalar (mr1 mr2: mr) :=
if equiv_decb mr1.(mr_output) mr2.(mr_input)
&& is_singleton_reduce mr1.(mr_reduce) then
match mr1.(mr_map), mr2.(mr_map) with
| MapScalar (x1, NNRCUnop AColl n1), MapScalar (x2, n2) ⇒
let map :=
MapScalar (x1, NNRCLet x2 n1 n2)
in
let mr :=
mkMR
mr1.(mr_input)
mr2.(mr_output)
map
mr2.(mr_reduce)
in
Some mr
| _, _ ⇒ None
end
else
None.
Lemma merge_scalar_singleton_scalar_correct mr1 mr2:
mr_well_formed mr2 →
merge_correct merge_scalar_singleton_scalar mr1 mr2.

(** * Merging with the final computation *)

(** [merge_mr_last mr last] tries to absorb [mr] into the final NNRC
    computation [last = ((params, n), args)]. All of the following must
    hold:
    - [last] has exactly one parameter and one scalar argument;
    - that argument is the output of [mr];
    - [mr] has [RedSingleton] and a scalar map of the form [coll n1].
    The new final computation reads the input of [mr] directly and
    evaluates [let x := (let x1 := input in n1) in n]. *)
Definition merge_mr_last mr (last: ((list var × nnrc) × list (var × dlocalization)) ) :=
let '((params, n), args) := last in
match (params, args) with
| (x::nil, (output, Vscalar)::nil) ⇒
if equiv_decb output mr.(mr_output) && is_singleton_reduce mr.(mr_reduce) then
match mr.(mr_map) with
| MapScalar (x1, NNRCUnop AColl n1) ⇒
Some ((mr.(mr_input)::nil, NNRCLet x
(NNRCLet x1 (NNRCVar mr.(mr_input)) n1)
n),
(mr.(mr_input), Vscalar)::nil)
| _ ⇒ None
end
else
None
| (_, _) ⇒ None
end.

(** [merge_last mrl] scans the chain from its end and absorbs trailing
    stages into [mr_last] with [merge_mr_last]. It stops at the first stage
    that cannot be absorbed; that stage and all earlier ones are kept
    unchanged. The second accumulator component stays [None] until such a
    stage is met, and then holds its output variable.
    NOTE(review): when the chain is empty or every stage is absorbed, the
    result is [None], so the caller discards the absorptions. This is
    presumably intentional to avoid producing an empty chain; confirm. *)
Definition merge_last mrl :=
let '(chain, output, last) :=
List.fold_right
(fun mr chain_output_last ⇒
match chain_output_last with
| (nil, None, last) ⇒
match merge_mr_last mr last with
| Some last' ⇒ (nil, None, last')
| None ⇒ (mr::nil, Some (mr_output mr), last)
end
| (acc, output, last) ⇒ (mr::acc, output, last)
end)
(nil, None, mrl.(mr_last)) mrl.(mr_chain)
in
match output with
| None ⇒ None
| Some _ ⇒ Some (mkMRChain mrl.(mr_inputs_loc) chain last)
end.

(** * Applying rewrites and merges to whole chains *)

(** [mr_chain_apply_rewrite rew l] replaces every stage [mr] of [l] for
    which [rew mr = Some chain] by [chain], and keeps the other stages. *)
Definition mr_chain_apply_rewrite (rew: mr → option (list mr)) l :=
List.flat_map
(fun mr ⇒
match rew mr with
| None ⇒ mr::nil
| Some mr_chain ⇒ mr_chain
end)
l.

(** [apply_rewrite] applies [mr_chain_apply_rewrite] to the chain of an
    [nnrcmr], leaving its inputs and [mr_last] unchanged. *)
Definition apply_rewrite (rew: mr → option (list mr)) mrl :=
mkMRChain
mrl.(mr_inputs_loc)
(mr_chain_apply_rewrite rew mrl.(mr_chain))
mrl.(mr_last).

(** [mr_chain_apply_merge merge l] processes each stage [mr1] whose output
    variable is written by at most one stage of [l]. Every later stage
    [mr2] with [merge mr1 mr2 = Some mr12] is replaced by [mr12]. The stage
    [mr1] itself is kept; stages that become dead are removed afterwards by
    [mr_cleanup]. *)
Definition mr_chain_apply_merge (merge: mr → mr → option mr) l :=
let output_vars : list var := List.fold_left (fun acc mr ⇒ mr.(mr_output) :: acc) l nil in
List.fold_right
(fun mr1 acc ⇒
if leb (mult output_vars mr1.(mr_output)) 1 then
let l_optimized :=
List.fold_right
(fun mr2 acc ⇒
match merge mr1 mr2 with
| None ⇒ mr2 :: acc
| Some mr12 ⇒ mr12 :: acc
end)
nil acc
in
mr1 :: l_optimized
else
mr1 :: acc)
nil l.

(** [apply_merge] applies [mr_chain_apply_merge] to the chain of an
    [nnrcmr], leaving its inputs and [mr_last] unchanged. *)
Definition apply_merge (merge: mr → mr → option mr) mrl :=
mkMRChain
mrl.(mr_inputs_loc)
(mr_chain_apply_merge merge mrl.(mr_chain))
mrl.(mr_last).

(** [mr_chain_cleanup l to_keep] removes dead stages. The chain is scanned
    from its end, and a stage is kept only if its output is in [to_keep].
    When a stage is kept, its input is added to [to_keep] for the earlier
    stages. The traversal is a single [fold_right], so this is a plain
    (non-recursive) definition. *)
Definition mr_chain_cleanup l (to_keep: list var) :=
let (_, res) :=
List.fold_right
(fun r (acc: list var × list mr) ⇒
let (to_keep, res) := acc in
if in_dec equiv_dec r.(mr_output) to_keep then
(r.(mr_input)::to_keep, r::res)
else
(to_keep, res))
(to_keep, nil) l
in
res.

(** [mr_cleanup mrl to_keep] applies [mr_chain_cleanup] to the chain of
    [mrl], leaving its inputs and [mr_last] unchanged. *)
Definition mr_cleanup mrl to_keep :=
mkMRChain
mrl.(mr_inputs_loc)
(mr_chain_cleanup mrl.(mr_chain) to_keep)
mrl.(mr_last).

(** * Optimization driver *)

(** [mr_optimize_step l] performs one round of optimization:
    - the variables read by [mr_last] are the roots of every cleanup;
    - the single-stage rewrite is applied first;
    - each merge is then applied, followed by a cleanup;
    - [merge_last] runs last, and the chain is left unchanged when it
      returns [None]. *)
Definition mr_optimize_step (l: nnrcmr): nnrcmr :=
let to_keep := List.map fst (snd l.(mr_last)) in
let l := apply_rewrite map_collect_flatten_to_map_flatten_collect l in
let l := apply_merge merge_id_reduce_id_dist_map l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_singleton_reduce_id_scalar_map l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_reduce_id_flatten_map l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_collect_dispatch l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_mr_id_dist_l l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_mr_id_scalar_l l in
let l := mr_cleanup l to_keep in
let l := apply_merge merge_scalar_singleton_scalar l in
let l := mr_cleanup l to_keep in
let l :=
match merge_last l with
| None ⇒ l
| Some l ⇒ l
end
in
l.

(** [mr_optimize_loop n l] applies [mr_optimize_step] [n] times. *)
Fixpoint mr_optimize_loop n (l: nnrcmr) :=
match n with
| 0 ⇒ l
| S n ⇒
let l := mr_optimize_step l in
mr_optimize_loop n l
end.

(** [mr_optimize l] runs a fixed number (10) of rounds of
    [mr_optimize_step]. It does not check whether a fixpoint is reached. *)
Definition mr_optimize (l: nnrcmr) :=
mr_optimize_loop 10 l.

(** * Variable helpers *)

(** [fresh_mr_var prefix vars] returns the variable [fresh_var prefix vars]
    (presumably a name not in [vars]), paired with [vars] extended with
    that variable. *)
Definition fresh_mr_var (prefix: string) (vars: list var) :=
let x := fresh_var prefix vars in
(x, x::vars).

(** [get_mr_chain_vars] lists the input and output variables of every
    stage of the chain. The list is in reverse stage order and may contain
    duplicates. *)
Definition get_mr_chain_vars mr_chain :=
List.fold_left
(fun acc mr ⇒ mr.(mr_input) :: mr.(mr_output) :: acc)
mr_chain nil.

(** [get_nnrcmr_vars mrl] lists the stage variables of the chain of [mrl].
    NOTE(review): variables that appear only in [mr_inputs_loc] or
    [mr_last] are not included; confirm this is enough for callers that
    generate fresh names. *)
Definition get_nnrcmr_vars mrl :=
get_mr_chain_vars mrl.(mr_chain).
End NRewMR.