Require Import Basics.
Require Import List.
Require Import String.
Require Import Peano_dec.
Require Import EquivDec.
Require Import Utils.
Require Import DataSystem.
Require Import NNRCRuntime.
Require Import DNNRCSystem.
Require Import tDNNRC.
Require Import tDNNRCInfer.
Require Import OptimizerStep.
Require Import OptimizerLogger.
Section tDNNRCOptimizer.
  
(* Foreign runtime made available to every definition of this section. *)
Context {fruntime:foreign_runtime}.
  
  
(** [dnnrc_base_map_plug f e] rebuilds [e], applying [f] to the plugged
    payload of every [DNNRCAlg] node.  Every other constructor is reproduced
    with the same annotation and arguments, the traversal continuing into its
    sub-expressions (including the expressions stored in the binding list of
    a [DNNRCAlg]). *)
Fixpoint dnnrc_base_map_plug {A: Set} {P: Set}
         (f: P -> P)
         (e: @dnnrc_base _ A P) : @dnnrc_base _ A P
  := match e with
     (* Leaves: nothing to traverse. *)
     | DNNRCGetConstant ann v => DNNRCGetConstant ann v
     | DNNRCVar ann v => DNNRCVar ann v
     | DNNRCConst ann d => DNNRCConst ann d
     (* Interior nodes: rebuild around the mapped children. *)
     | DNNRCBinop ann op lhs rhs =>
       DNNRCBinop ann op
                  (dnnrc_base_map_plug f lhs)
                  (dnnrc_base_map_plug f rhs)
     | DNNRCUnop ann op arg =>
       DNNRCUnop ann op (dnnrc_base_map_plug f arg)
     | DNNRCLet ann x bound body =>
       DNNRCLet ann x
                (dnnrc_base_map_plug f bound)
                (dnnrc_base_map_plug f body)
     | DNNRCFor ann x coll body =>
       DNNRCFor ann x
                (dnnrc_base_map_plug f coll)
                (dnnrc_base_map_plug f body)
     | DNNRCIf ann cnd thn els =>
       DNNRCIf ann
               (dnnrc_base_map_plug f cnd)
               (dnnrc_base_map_plug f thn)
               (dnnrc_base_map_plug f els)
     | DNNRCEither ann scrut xl el xr er =>
       DNNRCEither ann (dnnrc_base_map_plug f scrut)
                   xl (dnnrc_base_map_plug f el)
                   xr (dnnrc_base_map_plug f er)
     | DNNRCGroupBy ann g sl arg =>
       DNNRCGroupBy ann g sl (dnnrc_base_map_plug f arg)
     | DNNRCCollect ann arg =>
       DNNRCCollect ann (dnnrc_base_map_plug f arg)
     | DNNRCDispatch ann arg =>
       DNNRCDispatch ann (dnnrc_base_map_plug f arg)
     (* The only place where [f] is actually applied. *)
     | DNNRCAlg ann plugged algenv =>
       DNNRCAlg ann (f plugged)
                (map (fun sd => (fst sd, (dnnrc_base_map_plug f (snd sd)))) algenv)
     end.
  
(** Statement: if replacing the plugged payload [p] of any [DNNRCAlg] node by
    [f p] yields a node related to the original by [dnnrc_base_eq]
    (hypothesis [pf]), then every expression [e] is related by
    [dnnrc_base_eq] to [dnnrc_base_map_plug f e].
    NOTE(review): no proof script follows [Proof.] in this view of the file. *)
Lemma dnnrc_base_map_plug_correct {A: Set} {P: Set}
      {plug:AlgPlug P}
      {f: P -> P}
      (pf:forall (ann:A) p algenv,
          dnnrc_base_eq (DNNRCAlg ann p algenv) (DNNRCAlg ann (f p) algenv))
      (e: @dnnrc_base _ A P) :
  dnnrc_base_eq e (dnnrc_base_map_plug f e).
Proof.
  Fixpoint dnnrc_base_map_deep {
A: 
Set} {
P: 
Set}
           (
f: @
dnnrc_base _ A P -> @
dnnrc_base _ A P)
           (
e: @
dnnrc_base _ A P) : @
dnnrc_base _ A P
    := 
match e with
       | 
DNNRCGetConstant a e0 =>
         
f (
DNNRCGetConstant a e0)
       | 
DNNRCVar a e0 =>
         
f (
DNNRCVar a e0)
       | 
DNNRCConst a e0 =>
         
f (
DNNRCConst a e0)
       | 
DNNRCBinop a b e1 e2 =>
         
f (
DNNRCBinop a b (
dnnrc_base_map_deep f e1) (
dnnrc_base_map_deep f e2))
       | 
DNNRCUnop a u e0 =>
         
f (
DNNRCUnop a u (
dnnrc_base_map_deep f e0))
       | 
DNNRCLet a x e1 e2 =>
         
f (
DNNRCLet a x (
dnnrc_base_map_deep f e1) (
dnnrc_base_map_deep f e2))
       | 
DNNRCFor a x e1 e2 =>
         
f (
DNNRCFor a x (
dnnrc_base_map_deep f e1) (
dnnrc_base_map_deep f e2))
       | 
DNNRCIf a e1 e2 e3 =>
         
f (
DNNRCIf a
                (
dnnrc_base_map_deep f e1)
                (
dnnrc_base_map_deep f e2)
                (
dnnrc_base_map_deep f e3))
       | 
DNNRCEither a e0 x1 e1 x2 e2 =>
         
f (
DNNRCEither a (
dnnrc_base_map_deep f e0) 
x1 (
dnnrc_base_map_deep f e1) 
x2 (
dnnrc_base_map_deep f e2))
       | 
DNNRCGroupBy a g sl e0 =>
         
f (
DNNRCGroupBy a g sl (
dnnrc_base_map_deep f e0))
       | 
DNNRCCollect a e0 =>
         
f (
DNNRCCollect a (
dnnrc_base_map_deep f e0))
       | 
DNNRCDispatch a e0 =>
         
f (
DNNRCDispatch a (
dnnrc_base_map_deep f e0))
       | 
DNNRCAlg a p sdl =>
         
f (
DNNRCAlg a p (
map (
fun sd => (
fst sd, (
dnnrc_base_map_deep f (
snd sd)))) 
sdl))
    
end.
    
(** Statement: if [f] maps every expression to one related to it by
    [dnnrc_base_eq] (hypothesis [pf]), then every expression [e] is related
    by [dnnrc_base_eq] to [dnnrc_base_map_deep f e].
    NOTE(review): no proof script follows [Proof.] in this view of the file. *)
Lemma dnnrc_base_map_deep_correctness {A: Set} {P: Set}
      {plug:AlgPlug P}
      {f: @dnnrc_base _ A P -> @dnnrc_base _ A P}
      (pf:forall e0, dnnrc_base_eq e0 (f e0))
      (e: @dnnrc_base _ A P) :
  dnnrc_base_eq e (dnnrc_base_map_deep f e).
Proof.
    Context {
ftype:
foreign_type}.
    
Context {
h:
brand_relation_t}.
    
Context {
m:
brand_model}.
(** Detect the traditional 'cast the world' pattern:
   * Iterate over a collection (the world), cast each element and perform some action on success, return the empty collection otherwise, and flatten at the end.
   * We can translate this into a filter with a user-defined cast function.
   * We do not inline unbranding, as we would have to make sure that the branded value is not used anywhere else. *)
   
  (** [rec_cast_to_filter e] rewrites one specific expression shape and
      returns [e] unchanged in every other case.

      Shape matched:
        flatten (for x in collect xs:
                   either (cast brands x') leftVar leftE _ (const (dcoll nil)))

      The rewrite fires only when [x == x'] and
      [olift tuneither (lift_tlocal (ta_inferred t4))] yields
      [Some (castSuccessType, _)].  The result is
        flatten (for leftVar in collect ALG: leftE)
      where [ALG] is a [DNNRCAlg] node holding a [DSFilter] built from
      [CUDFCast brands] over a [DSVar] that the node's binding list maps to
      [xs].

      NOTE(review): the string literals below are split across physical lines
      in this file; the code is left byte-for-byte untouched here. *)
  Definition rec_cast_to_filter {
A: 
Set}
             (
e: @
dnnrc_base _ (
type_annotation A) 
dataframe) :
    @
dnnrc_base _ (
type_annotation A) 
dataframe
    := 
match e with
    | 
DNNRCUnop t1 OpFlatten
               (
DNNRCFor t2 x
                        (
DNNRCCollect t3 xs)
                        (
DNNRCEither _ (
DNNRCUnop t4 (
OpCast brands) (
DNNRCVar _ x'))
                                    
leftVar
                                    leftE
                                    _
                                    (
DNNRCConst _ (
dcoll nil)))) =>
      
      (* The cast must be applied to the loop variable itself. *)
if (
x == 
x')
      
then
        match olift tuneither (
lift_tlocal (
ta_inferred t4)) 
with
        | 
Some (
castSuccessType, 
_) =>
          
          (* New annotations: a distributed collection of the cast-success
             type for the algebra node, a local collection of it for the
             enclosing collect. *)
let algTypeA := 
ta_mk (
ta_base t4) (
Tdistr castSuccessType) 
in
          let collectTypeA := 
ta_mk (
ta_base t3) (
Tlocal (
Coll castSuccessType)) 
in
          let ALG := (
DNNRCAlg algTypeA
                            (
DSFilter (
CUDFCast brands (
CCol "$
class"))
                                      (
DSVar "
map_cast"))
                            (("
map_cast"%
string, 
xs)::
nil)) 
in
          (
DNNRCUnop t1 OpFlatten
                         (
DNNRCFor t2 leftVar (
DNNRCCollect collectTypeA ALG)
                                  
leftE))
        | 
_ => 
e
        end
      else e
    | 
_ => 
e
    end.
  
(** Wraps [rec_cast_to_filter] (at annotation type [A]) into an optimizer
    step with [mkOptimizerStep].
    NOTE(review): the three string arguments are presumably the step name, a
    description (currently the placeholder ???) and the name of the rewrite;
    confirm against the definition of [mkOptimizerStep]. *)
Definition rec_cast_to_filter_step {
A:
Set}
    := 
mkOptimizerStep
         "
rec cast filter" 
         "???" 
         "
rec_cast_to_filter" 
         (@
rec_cast_to_filter A) .
  
  
(** [rewrite_unbrand_or_fail s e] tries to adapt [e] to a context in which
    the name [s] already denotes an unbranded value.

    - [unbrand] applied directly to a variable or constant named [s] is
      replaced by that variable or constant, carrying the annotation of the
      original unbrand node.  [unbrand] applied directly to a variable or
      constant with any other name fails.
    - Any other occurrence of a variable or constant named [s] makes the
      rewrite fail ([None]); other variables and constants are kept.
    - [DNNRCAlg] nodes are not inspected: the rewrite fails on them.
    - All remaining nodes are rebuilt from their rewritten children and fail
      as soon as one child fails.

    Scoping: when a [DNNRCLet], [DNNRCFor] or [DNNRCEither] branch binds the
    same name [s], occurrences of [s] under that binder denote the new
    binding, not the value being unbranded.  Such a sub-expression is
    therefore kept untouched; only the parts evaluated outside the binder are
    rewritten.  Rewriting under a shadowing binder would strip an [unbrand]
    that applies to the shadowing variable.

    NOTE(review): constants and variables named [s] are treated alike here;
    confirm that this is intended. *)
Fixpoint rewrite_unbrand_or_fail
         {A: Set} {P: Set}
         (s: string)
         (e: @dnnrc_base _ A P) : option (@dnnrc_base _ A P)
  := match e with
     | DNNRCUnop t1 OpUnbrand (DNNRCGetConstant t2 v) =>
       if (s == v)
       then Some (DNNRCGetConstant t1 s)
       else None
     | DNNRCUnop t1 OpUnbrand (DNNRCVar t2 v) =>
       if (s == v)
       then Some (DNNRCVar t1 s)
       else None
     | DNNRCVar t1 v =>
       if (s == v)
       then None
       else Some (DNNRCVar t1 v)
     | DNNRCGetConstant t1 v =>
       if (s == v)
       then None
       else Some (DNNRCGetConstant t1 v)
     | DNNRCConst _ _ => Some e
     | DNNRCBinop a b x y =>
       lift2 (DNNRCBinop a b)
             (rewrite_unbrand_or_fail s x)
             (rewrite_unbrand_or_fail s y)
     | DNNRCUnop a b x =>
       lift (DNNRCUnop a b) (rewrite_unbrand_or_fail s x)
     | DNNRCLet a b x y =>
       (* [x] is evaluated outside the binder; [y] sees [b]. *)
       lift2 (DNNRCLet a b)
             (rewrite_unbrand_or_fail s x)
             (if (s == b) then Some y else rewrite_unbrand_or_fail s y)
     | DNNRCFor a b x y =>
       (* [x] is evaluated outside the binder; [y] sees [b]. *)
       lift2 (DNNRCFor a b)
             (rewrite_unbrand_or_fail s x)
             (if (s == b) then Some y else rewrite_unbrand_or_fail s y)
     | DNNRCIf a x y z =>
       match rewrite_unbrand_or_fail s x,
             rewrite_unbrand_or_fail s y,
             rewrite_unbrand_or_fail s z with
       | Some x', Some y', Some z' => Some (DNNRCIf a x' y' z')
       | _, _, _ => None
       end
     | DNNRCEither a x b y c z =>
       (* [y] sees [b] and [z] sees [c]; the scrutinee [x] sees neither. *)
       match rewrite_unbrand_or_fail s x,
             (if (s == b) then Some y else rewrite_unbrand_or_fail s y),
             (if (s == c) then Some z else rewrite_unbrand_or_fail s z) with
       | Some x', Some y', Some z' => Some (DNNRCEither a x' b y' c z')
       | _, _, _ => None
       end
     | DNNRCGroupBy a g sl x =>
       lift (DNNRCGroupBy a g sl) (rewrite_unbrand_or_fail s x)
     | DNNRCCollect a x =>
       lift (DNNRCCollect a) (rewrite_unbrand_or_fail s x)
     | DNNRCDispatch a x =>
       lift (DNNRCDispatch a) (rewrite_unbrand_or_fail s x)
     | DNNRCAlg _ _ _ => None
     end.
  
(** [rec_lift_unbrand e] rewrites one specific expression shape and returns
    [e] unchanged in every other case.

    Shape matched: [for x in (collect xs): body], where the collect node [c]
    satisfies [lift_tlocal (di_required_typeof c)] = a collection of a brand
    type with brands [bs], and [rewrite_unbrand_or_fail x body] succeeds with
    some [e'].

    Result: [for x in collect ALG: e'], reusing the annotations [t1] and
    [t2].  [ALG] is a [DNNRCAlg] node annotated with
    [dnnrc_base_annotation_get xs]; its plugged query is two nested
    [DSSelect]s over a [DSVar] that the binding list maps to [xs].  The inner
    select computes one column with [CUDFUnbrand t], where
    [t := proj1_sig (brands_type bs)]; the outer select projects two columns
    out of it.

    NOTE(review): the string literals below are split across physical lines
    in this file; the code is left byte-for-byte untouched here. *)
Definition rec_lift_unbrand
             {
A : 
Set}
             (
e: @
dnnrc_base _ (
type_annotation A) 
dataframe):
    (@
dnnrc_base _ (
type_annotation _) 
dataframe) :=
    
match e with
    | 
DNNRCFor t1 x (
DNNRCCollect t2 xs as c) 
body =>
      
match lift_tlocal (
di_required_typeof c) 
with
      | 
Some (
exist _ (
Coll₀ (
Brand₀ 
bs)) 
_) =>
        
let t := 
proj1_sig (
brands_type bs) 
in
        (* The loop body must be expressible without the branded value. *)
        match rewrite_unbrand_or_fail x body with
        | 
Some e' =>
          
let ALG :=
              
DNNRCAlg (
dnnrc_base_annotation_get xs)
                       (
DSSelect (("$
blob"%
string, 
CCol "
unbranded.$
blob")
                                    :: ("$
known"%
string, 
CCol "
unbranded.$
known")::
nil)
                                 (
DSSelect (("
unbranded"%
string, 
CUDFUnbrand t (
CCol "$
data"))::
nil)
                                           (
DSVar "
lift_unbrand")))
                       (("
lift_unbrand"%
string, 
xs)::
nil)
          
in
          DNNRCFor t1 x (
DNNRCCollect t2 ALG) 
e'
        | 
None => 
e
        end
      | 
_ => 
e
      end
    | 
_ => 
e
    end.
    
(** Wraps [rec_lift_unbrand] (at annotation type [A]) into an optimizer step
    with [mkOptimizerStep].
    NOTE(review): the three string arguments are presumably the step name, a
    description (currently the placeholder ???) and the name of the rewrite;
    confirm against the definition of [mkOptimizerStep].  The third literal
    has a space right after its opening quote, unlike the sibling steps;
    check whether that is intended. *)
Definition rec_lift_unbrand_step {
A:
Set}
    := 
mkOptimizerStep
         "
rec lift unbrand" 
         "???" 
         " 
rec_lift_unbrand" 
         (@
rec_lift_unbrand A) .
  
(** Boolean test on raw types, by structural recursion:
    - [true] for naturals, floats, booleans and strings;
    - for a closed record, [true] exactly when every field type passes;
    - for an either type, [true] exactly when both sides pass;
    - [false] for bottom, top, unit, collections, open records, arrows,
      brands and foreign types.
    Presumably (per the name) this marks the types on which Spark equality
    and Q*cert equality agree; that is not established by this code. *)
Fixpoint spark_equality_matches_qcert_equality_for_type (r: rtype₀) :=
  match r with
  | Nat₀ | Float₀ | Bool₀ | String₀ => true
  | Rec₀ Closed flds =>
    forallb (compose spark_equality_matches_qcert_equality_for_type snd) flds
  | Either₀ tl tr =>
    andb (spark_equality_matches_qcert_equality_for_type tl)
         (spark_equality_matches_qcert_equality_for_type tr)
  | Bottom₀ | Top₀ | Unit₀
  | Coll₀ _ | Rec₀ Open _ | Arrow₀ _ _
  | Brand₀ _ | Foreign₀ _ => false
  end.
  
(** [condition_to_column e binding] tries to translate the expression [e]
    into a dataframe column expression; it returns [None] for any shape that
    is not listed below.

    - Field access [OpDot fld] on a constant or variable whose name equals
      the first component of [binding]: a [CCol] whose name is a fixed prefix
      followed by [fld].  Any other name gives [None].
    - A constant [d]: [CLit] pairing [d] with the raw type obtained from
      [lift_tlocal (di_required_typeof e)].
    - [OpEqual l r]: [CEq], provided both sides translate, both have local
      types that are equal ([equiv_decb]), and that type passes
      [spark_equality_matches_qcert_equality_for_type].
    - [OpStringConcat], [OpLt], [OpToString]: [CSConcat], [CLessThan],
      [CToString] over the translated operands.

    The second component of [binding] is not used by the code below.
    NOTE(review): the string literals below are split across physical lines
    in this file; the code is left byte-for-byte untouched here. *)
Fixpoint condition_to_column {
A: 
Set}
           (
e: @
dnnrc_base _ (
type_annotation A) 
dataframe)
           (
binding: (
string * 
column)) :=
    
match e with
    | 
DNNRCUnop _ (
OpDot fld) (
DNNRCGetConstant _ n) =>
      
let (
var, 
_) := 
binding in
      if (
n == 
var)
      
then Some (
CCol ("$
known."%
string ++ 
fld))
      
else None
    | 
DNNRCUnop _ (
OpDot fld) (
DNNRCVar _ n) =>
      
let (
var, 
_) := 
binding in
      if (
n == 
var)
      
then Some (
CCol ("$
known."%
string ++ 
fld))
      
else None
    | 
DNNRCConst _ d =>
      
lift (
fun t => 
CLit (
d, (
proj1_sig t))) (
lift_tlocal (
di_required_typeof e))
    | 
DNNRCBinop _ OpEqual l r =>
      
      (* [Some true] only when both operand types are local, equal, and
         accepted by the equality-compatibility test. *)
let types_are_okay :=
          
lift2 (
fun lt rt => 
andb (
equiv_decb lt rt)
                                   (
spark_equality_matches_qcert_equality_for_type (
proj1_sig lt)))
                (
lift_tlocal (
di_typeof l)) (
lift_tlocal (
di_typeof r)) 
in
      match types_are_okay, 
condition_to_column l binding, 
condition_to_column r binding with
      | 
Some true, 
Some l', 
Some r' =>
        
Some (
CEq l' 
r')
      | 
_, 
_, 
_ => 
None
      end
    | 
DNNRCBinop _ OpStringConcat l r =>
      
lift2 CSConcat
            (
condition_to_column l binding)
            (
condition_to_column r binding)
    | 
DNNRCBinop _ OpLt l r =>
      
lift2 CLessThan
            (
condition_to_column l binding)
            (
condition_to_column r binding)
    | 
DNNRCUnop _ OpToString x =>
      
lift CToString
           (
condition_to_column x binding)
    | 
_ => 
None
    end.
  
(** [rec_if_else_empty_to_filter e] rewrites one specific expression shape
    and returns [e] unchanged in every other case.

    Shape matched:
      flatten (for x in collect xs:
                 if condition then thenE else const (dcoll nil))

    When [condition_to_column] translates [condition] (with [x] as the bound
    name) into a column [c'], the result is
      flatten (for x in collect ALG: thenE)
    reusing the annotations [t1], [t2], [t3].  [ALG] is a [DNNRCAlg] node
    annotated with [dnnrc_base_annotation_get xs] whose plugged query is
    [DSFilter c'] over a [DSVar] that the binding list maps to [xs].

    The [CCol] passed as second component of the binding is a dummy as far
    as the visible code shows: [condition_to_column] does not read it.
    NOTE(review): the string literals below are split across physical lines
    in this file; the code is left byte-for-byte untouched here. *)
Definition rec_if_else_empty_to_filter {
A: 
Set}
             (
e: @
dnnrc_base _ (
type_annotation A) 
dataframe):
    (@
dnnrc_base _ (
type_annotation A) 
dataframe) :=
    
match e with
    | 
DNNRCUnop t1 OpFlatten
               (
DNNRCFor t2 x (
DNNRCCollect t3 xs)
                        (
DNNRCIf _ condition
                                thenE
                                (
DNNRCConst _ (
dcoll nil)))) =>
      
match condition_to_column condition (
x, 
CCol "
abc") 
with
      | 
Some c' =>
        
let ALG :=
            
DNNRCAlg (
dnnrc_base_annotation_get xs)
                    (
DSFilter c' (
DSVar "
if_else_empty_to_filter"))
                    (("
if_else_empty_to_filter"%
string, 
xs)::
nil)
        
in
        DNNRCUnop t1 OpFlatten
                       (
DNNRCFor t2 x (
DNNRCCollect t3 ALG)
                                
thenE)
      | 
None => 
e
      end
    | 
_ => 
e
    end.
  
(** Wraps [rec_if_else_empty_to_filter] (at annotation type [A]) into an
    optimizer step with [mkOptimizerStep].
    NOTE(review): the three string arguments are presumably the step name, a
    description (empty here) and the name of the rewrite; confirm against the
    definition of [mkOptimizerStep]. *)
Definition rec_if_else_empty_to_filter_step {
A:
Set}
    := 
mkOptimizerStep
         "
rec/
if/
empty" 
         "" 
         "
rec_if_else_empty_to_filter" 
         (@
rec_if_else_empty_to_filter A) .
  
(** Rewrites [flatten (for x in xs: bag elem)] into [for x in xs: elem].
    The resulting loop carries the annotation of the original flatten node;
    the annotations of the original loop and of the bag node are dropped.
    Any other expression is returned unchanged. *)
Definition rec_remove_map_singletoncoll_flatten {A: Set}
           (e: @dnnrc_base _ (type_annotation A) dataframe):
  @dnnrc_base _ (type_annotation A) dataframe :=
  match e with
  | DNNRCUnop flattenAnn OpFlatten
              (DNNRCFor _ x xs (DNNRCUnop _ OpBag elem)) =>
    DNNRCFor flattenAnn x xs elem
  | _ => e
  end.
  
(** Wraps [rec_remove_map_singletoncoll_flatten] (at annotation type [A])
    into an optimizer step with [mkOptimizerStep].
    The spelling of the description string is corrected here (Simplifies);
    the literals are otherwise left exactly as laid out in the file.
    NOTE(review): the three string arguments are presumably the step name, a
    description and the name of the rewrite; confirm against the definition
    of [mkOptimizerStep]. *)
Definition rec_remove_map_singletoncoll_flatten_step {
A:
Set}
    := 
mkOptimizerStep
         "
flatten/
for/
coll" 
         "
Simplifies flatten of a for comprehension that creates singleton bags" 
         "
rec_remove_map_singletoncoll_flatten" 
         (@
rec_remove_map_singletoncoll_flatten A) .
  
(** [rec_for_to_select e] rewrites one specific expression shape and returns
    [e] unchanged in every other case.

    Shape matched: [for x in collect xs: body], where
    [lift_tlocal (di_typeof body)] matches the pattern [Some String] and
    [condition_to_column] translates [body] (with [x] as the bound name) into
    a column [body'].

    Result: [collect ALG], annotated with [t1].  [ALG] is a [DNNRCAlg] node
    annotated with [ta_mk (ta_base t1) (Tdistr String)] whose plugged query
    is a [DSSelect] of the single column [body'] over a [DSVar] that the
    binding list maps to [xs].

    NOTE(review): if [String] does not resolve to a constructor at this
    point, Coq treats it as a pattern variable: the branch then matches any
    local type and [Tdistr String] uses that bound type.  Confirm which
    reading is intended.
    NOTE(review): the string literals below are split across physical lines
    in this file; the code is left byte-for-byte untouched here. *)
Definition rec_for_to_select {
A: 
Set}
             (
e: @
dnnrc_base _ (
type_annotation A) 
dataframe):
    @
dnnrc_base _ (
type_annotation A) 
dataframe :=
    
match e with
    | 
DNNRCFor t1 x (
DNNRCCollect t2 xs) 
body =>
      
match lift_tlocal (
di_typeof body) 
with
      | 
Some String =>
        
match condition_to_column body (
x, 
CCol "
abc") 
with
        | 
Some body' =>
          
let ALG_type := 
Tdistr String in
          let ALG :=
              
DNNRCAlg (
ta_mk (
ta_base t1) 
ALG_type)
                      (
DSSelect (("
value"%
string, 
body')::
nil) (
DSVar "
for_to_select"))
                      (("
for_to_select"%
string, 
xs)::
nil)
          
in
          DNNRCCollect t1 ALG
        | 
None => 
e
        end
      | 
_ => 
e
      end
    | 
_ => 
e
    end.
  
Definition rec_for_to_select_step {
A:
Set}
    := 
mkOptimizerStep
         "
rec for to select" 
         "???" 
         "
rec_for_to_select" 
         (@
rec_for_to_select A) .
  
Import ListNotations.
  
(** All optimizer steps defined in this section, in declaration order. *)
Definition dnnrc_optim_list {A} :
  list (OptimizerStep (@dnnrc_base _ (type_annotation A) dataframe))
  := [ rec_cast_to_filter_step
     ; rec_lift_unbrand_step
     ; rec_if_else_empty_to_filter_step
     ; rec_remove_map_singletoncoll_flatten_step
     ; rec_for_to_select_step
     ].
  
(** Statement: [dnnrc_optim_list] satisfies [optim_list_distinct].
    NOTE(review): no proof script follows [Proof.] in this view of the file. *)
Lemma dnnrc_optim_list_distinct {A:Set}:
  optim_list_distinct (@dnnrc_optim_list A).
Proof.
  
  (** [dnnrc_optim_top optims iterationsBetweenCostCheck] is the expression
      transformer obtained from [run_phase], instantiated with
      [dnnrc_base_map_deep], [dnnrc_base_size], the step list
      [dnnrc_optim_list], a fixed string tag, and the two parameters, which
      are passed through unchanged.
      NOTE(review): presumably [optims] names the steps to enable and the
      size function serves as cost measure; confirm against [run_phase].
      NOTE(review): the string literal below is split across physical lines
      in this file; the code is left byte-for-byte untouched here. *)
  Definition dnnrc_optim_top {
A}
             {
logger:
optimizer_logger string (@
dnnrc_base _ (
type_annotation A) 
dataframe)}
             (
optims:
list string)
             (
iterationsBetweenCostCheck:
nat)
    : @
dnnrc_base _ (
type_annotation A) 
dataframe -> @
dnnrc_base _ (
type_annotation A) 
dataframe :=
    
run_phase dnnrc_base_map_deep (
dnnrc_base_size ) 
dnnrc_optim_list
              "[
dnnrc] " 
optims iterationsBetweenCostCheck.
  
(** Names of the steps enabled by default, read off the step records
    themselves (instantiated at [unit]) so that they cannot drift from the
    step definitions. *)
Definition dnnrc_default_optim_list : list string
  := [ optim_step_name (@rec_for_to_select_step unit)
     ; optim_step_name (@rec_remove_map_singletoncoll_flatten_step unit)
     ; optim_step_name (@rec_if_else_empty_to_filter_step unit)
     ; optim_step_name (@rec_lift_unbrand_step unit)
     ; optim_step_name (@rec_cast_to_filter_step unit)
     ].
  
(** Running [valid_optims] on the default names against [dnnrc_optim_list]
    returns the default list itself paired with an empty list; checked by
    computation. *)
Remark dnnrc_default_optim_list_all_valid {A:Set}
  : valid_optims (@dnnrc_optim_list A) dnnrc_default_optim_list
    = (dnnrc_default_optim_list, nil).
Proof.
  vm_compute; trivial.
Qed.
  Definition dnnrc_optim_top_default {
A:
Set}
             {
logger:
optimizer_logger string (@
dnnrc_base _ (
type_annotation A) 
dataframe)}
    := 
dnnrc_optim_top dnnrc_default_optim_list 6.
End tDNNRCOptimizer.