Qcert.DNNRC.Optim.DNNRCOptimizer




Section DNNRCDatasetRewrites.

  Context {fruntime:foreign_runtime}.

  Fixpoint dnnrc_map_plug {A: Set} {P: Set}
           (f: P P)
           (e: dnnrc A P) : dnnrc A P
    := match e with
       | DNNRCVar a e0DNNRCVar a e0
       | DNNRCConst a e0DNNRCConst a e0
       | DNNRCBinop a b e1 e2
         DNNRCBinop a b (dnnrc_map_plug f e1) (dnnrc_map_plug f e2)
       | DNNRCUnop a u e0
         DNNRCUnop a u (dnnrc_map_plug f e0)
       | DNNRCLet a x e1 e2
         DNNRCLet a x (dnnrc_map_plug f e1) (dnnrc_map_plug f e2)
       | DNNRCFor a x e1 e2
         DNNRCFor a x (dnnrc_map_plug f e1) (dnnrc_map_plug f e2)
       | DNNRCIf a e1 e2 e3
         DNNRCIf a
                (dnnrc_map_plug f e1)
                (dnnrc_map_plug f e2)
                (dnnrc_map_plug f e3)
       | DNNRCEither a e0 x1 e1 x2 e2
         DNNRCEither a (dnnrc_map_plug f e0) x1 (dnnrc_map_plug f e1) x2 (dnnrc_map_plug f e2)
       | DNNRCGroupBy a g sl e0
         DNNRCGroupBy a g sl (dnnrc_map_plug f e0)
       | DNNRCCollect a e0
         DNNRCCollect a (dnnrc_map_plug f e0)
       | DNNRCDispatch a e0
         DNNRCDispatch a (dnnrc_map_plug f e0)
       | DNNRCAlg a p sdl
         DNNRCAlg a (f p) (map (fun sd(fst sd, (dnnrc_map_plug f (snd sd)))) sdl)
    end.

  Lemma dnnrc_map_plug_correct {A: Set} {P: Set}
        {plug:AlgPlug P}
        {f: P P}
        (pf: (a:A) e env, dnnrc_eq (DNNRCAlg a e env) (DNNRCAlg a (f e) env))
        (e: dnnrc A P) :
    dnnrc_eq e (dnnrc_map_plug f e).

  Fixpoint dnnrc_map_deep {A: Set} {P: Set}
           (f: dnnrc A P dnnrc A P)
           (e: dnnrc A P) : dnnrc A P
    := match e with
       | DNNRCVar a e0
         f (DNNRCVar a e0)
       | DNNRCConst a e0
         f (DNNRCConst a e0)
       | DNNRCBinop a b e1 e2
         f (DNNRCBinop a b (dnnrc_map_deep f e1) (dnnrc_map_deep f e2))
       | DNNRCUnop a u e0
         f (DNNRCUnop a u (dnnrc_map_deep f e0))
       | DNNRCLet a x e1 e2
         f (DNNRCLet a x (dnnrc_map_deep f e1) (dnnrc_map_deep f e2))
       | DNNRCFor a x e1 e2
         f (DNNRCFor a x (dnnrc_map_deep f e1) (dnnrc_map_deep f e2))
       | DNNRCIf a e1 e2 e3
         f (DNNRCIf a
                (dnnrc_map_deep f e1)
                (dnnrc_map_deep f e2)
                (dnnrc_map_deep f e3))
       | DNNRCEither a e0 x1 e1 x2 e2
         f (DNNRCEither a (dnnrc_map_deep f e0) x1 (dnnrc_map_deep f e1) x2 (dnnrc_map_deep f e2))
       | DNNRCGroupBy a g sl e0
         f (DNNRCGroupBy a g sl (dnnrc_map_deep f e0))
       | DNNRCCollect a e0
         f (DNNRCCollect a (dnnrc_map_deep f e0))
       | DNNRCDispatch a e0
         f (DNNRCDispatch a (dnnrc_map_deep f e0))
       | DNNRCAlg a p sdl
         f (DNNRCAlg a p (map (fun sd(fst sd, (dnnrc_map_deep f (snd sd)))) sdl))
    end.

    Lemma dnnrc_map_deep_correctness {A: Set} {P: Set}
          {plug:AlgPlug P}
          {f: dnnrc A P dnnrc A P}
          (pf: e, dnnrc_eq e (f e))
          (e: dnnrc A P) :
      dnnrc_eq e (dnnrc_map_deep f e).

    Context {ftype:foreign_type}.
    Context {h:brand_relation_t}.
    Context {m:brand_model}.

  (** Discover the traditional "casting the world" pattern:

      Iterate over a collection (the world), cast each element and perform
      some action on success, return the empty collection otherwise, and
      flatten the result at the end.

      We can translate this into a filter with a user-defined cast function.

      We do not inline unbranding, as we would have to make sure that we do
      not use the branded value anywhere. *)

  Definition rec_cast_to_filter {A: Set}
             (e: dnnrc (type_annotation A) dataset) :
    dnnrc (type_annotation A) dataset
    := match e with
    | DNNRCUnop t1 AFlatten
               (DNNRCFor t2 x
                        (DNNRCCollect t3 xs)
                        (DNNRCEither _ (DNNRCUnop t4 (ACast brands) (DNNRCVar _ x'))
                                    leftVar
                                    leftE
                                    _
                                    (DNNRC.DNNRCConst _ (dcoll nil)))) ⇒
      if (x == x')
      then
        match olift tuneither (lift_tlocal (ta_inferred t4)) with
        | Some (castSuccessType, _)
          let algTypeA := ta_mk (ta_base t4) (Tdistr castSuccessType) in
          let collectTypeA := ta_mk (ta_base t3) (Tlocal (Coll castSuccessType)) in
          
          let ALG := (DNNRCAlg algTypeA
                            (DSFilter (CUDFCast brands (CCol "$type"))
                                      (DSVar "map_cast"))
                            (("map_cast"%string, xs)::nil)) in
          (DNNRCUnop t1 AFlatten
                         (DNNRCFor t2 leftVar (DNNRCCollect collectTypeA ALG)
                                  leftE))
        | _e
        end
      else e
    | _e
    end.

  (** Packages [rec_cast_to_filter] as an optimizer step via
      [mkOptimizerStep].  The three strings are presumably the step name, a
      description (still a "???" placeholder) and the name of the underlying
      function/lemma -- TODO confirm against the [OptimizerStep] record. *)
  Definition rec_cast_to_filter_step {A:Set}
    := mkOptimizerStep
         "rec cast filter"
         "???"
         "rec_cast_to_filter"
         (@rec_cast_to_filter A) .

  Fixpoint rewrite_unbrand_or_fail
           {A: Set} {P: Set}
           (s: string)
           (e: dnnrc A P) : option (dnnrc A P)
    := match e with
    | DNNRCUnop t1 AUnbrand (DNNRCVar t2 v) ⇒
      if (s == v)
      then Some (DNNRCVar t1 s)
      else None
    | DNNRCVar _ v
      if (s == v)
      then None
      else Some e
    | DNNRCConst _ _Some e
    | DNNRCBinop a b x y
      lift2 (DNNRCBinop a b) (rewrite_unbrand_or_fail s x) (rewrite_unbrand_or_fail s y)
    | DNNRCUnop a b x
      lift (DNNRCUnop a b) (rewrite_unbrand_or_fail s x)
    | DNNRCLet a b x y
      lift2 (DNNRCLet a b) (rewrite_unbrand_or_fail s x) (rewrite_unbrand_or_fail s y)
    | DNNRCFor a b x y
      lift2 (DNNRCFor a b) (rewrite_unbrand_or_fail s x) (rewrite_unbrand_or_fail s y)
    | DNNRCIf a x y z
      match rewrite_unbrand_or_fail s x, rewrite_unbrand_or_fail s y, rewrite_unbrand_or_fail s z with
      | Some x', Some y', Some z'Some (DNNRCIf a x' y' z')
      | _, _, _None
      end
    | DNNRCEither a x b y c z
      match rewrite_unbrand_or_fail s x, rewrite_unbrand_or_fail s y, rewrite_unbrand_or_fail s z with
      | Some x', Some y', Some z'Some (DNNRCEither a x' b y' c z')
      | _, _, _None
      end
    | DNNRCGroupBy a g sl x
      lift (DNNRCGroupBy a g sl) (rewrite_unbrand_or_fail s x)
    | DNNRCCollect a x
      lift (DNNRCCollect a) (rewrite_unbrand_or_fail s x)
    | DNNRCDispatch a x
      lift (DNNRCDispatch a) (rewrite_unbrand_or_fail s x)
    
    | DNNRCAlg _ _ _None
    end.

  Definition rec_lift_unbrand
             {A : Set}
             (e: dnnrc (type_annotation A) dataset):
    (dnnrc (type_annotation _) dataset) :=
    match e with
    | DNNRCFor t1 x (DNNRCCollect t2 xs as c) body
      match lift_tlocal (di_required_typeof c) with
      | Some (exist _ (Coll₀ (Brand₀ bs)) _) ⇒
        let t := proj1_sig (brands_type bs) in
        match rewrite_unbrand_or_fail x body with
        | Some e'
          let ALG :=
              
              DNNRCAlg (dnnrc_annotation_get xs)
                       (DSSelect (("$blob"%string, CCol "unbranded.$blob")
                                    :: ("$known"%string, CCol "unbranded.$known")::nil)
                                 (DSSelect (("unbranded"%string, CUDFUnbrand t (CCol "$data"))::nil)
                                           (DSVar "lift_unbrand")))
                       (("lift_unbrand"%string, xs)::nil)
          in
          DNNRCFor t1 x (DNNRCCollect t2 ALG) e'
        | Nonee
        end
      | _e
      end
    | _e
    end.

    Definition rec_lift_unbrand_step {A:Set}
    := mkOptimizerStep
         "rec lift unbrand"
         "???"
         " rec_lift_unbrand"
         (@rec_lift_unbrand A) .

  Fixpoint spark_equality_matches_qcert_equality_for_type (r: rtype₀) :=
    match r with
    | Nat₀
    | Bool₀
    | String₀true
    | Rec₀ Closed fs
      forallb (compose spark_equality_matches_qcert_equality_for_type snd) fs
    | Either₀ l r
      andb (spark_equality_matches_qcert_equality_for_type l)
           (spark_equality_matches_qcert_equality_for_type r)
    | Bottom₀
    | Top₀
    | Unit₀
    | Coll₀ _
    | Rec₀ Open _
    | Arrow₀ _ _
    | Brand₀ _
    | Foreign₀ _false
    end.

  Fixpoint condition_to_column {A: Set}
           (e: dnnrc (type_annotation A) dataset)
           (binding: (string × column)) :=
    match e with
    
    | DNNRCUnop _ (ADot fld) (DNNRCVar _ n) ⇒
      let (var, _) := binding in
      if (n == var)
      then Some (CCol ("$known."%string ++ fld))
      else None
    
    | DNNRCConst _ d
      lift (fun tCLit (d, (proj1_sig t))) (lift_tlocal (di_required_typeof e))
    | DNNRCBinop _ AEq l r
      let types_are_okay :=
          lift2 (fun lt rtandb (equiv_decb lt rt)
                                   (spark_equality_matches_qcert_equality_for_type (proj1_sig lt)))
                (lift_tlocal (di_typeof l)) (lift_tlocal (di_typeof r)) in
      match types_are_okay, condition_to_column l binding, condition_to_column r binding with
      | Some true, Some l', Some r'
        Some (CEq l' r')
      | _, _, _None
      end
    | DNNRCBinop _ ASConcat l r
      lift2 CSConcat
            (condition_to_column l binding)
            (condition_to_column r binding)
    | DNNRCBinop _ ALt l r
      lift2 CLessThan
            (condition_to_column l binding)
            (condition_to_column r binding)
    
    | DNNRCUnop _ AToString x
      lift CToString
           (condition_to_column x binding)

    | _None
    end.

  Definition rec_if_else_empty_to_filter {A: Set}
             (e: dnnrc (type_annotation A) dataset):
    (dnnrc (type_annotation A) dataset) :=
    match e with
    | DNNRCUnop t1 AFlatten
               (DNNRCFor t2 x (DNNRCCollect t3 xs)
                        (DNNRCIf _ condition
                                thenE
                                (DNNRC.DNNRCConst _ (dcoll nil)))) ⇒
      match condition_to_column condition (x, CCol "abc") with
      | Some c'
        let ALG :=
            DNNRCAlg (dnnrc_annotation_get xs)
                    (DSFilter c' (DSVar "if_else_empty_to_filter"))
                    (("if_else_empty_to_filter"%string, xs)::nil)
        in
        DNNRCUnop t1 AFlatten
                       (DNNRCFor t2 x (DNNRCCollect t3 ALG)
                                thenE)
      | Nonee
      end
    | _e
    end.

  (** Packages [rec_if_else_empty_to_filter] as an optimizer step via
      [mkOptimizerStep].  The three strings are presumably the step name, a
      description (currently empty) and the name of the underlying
      function/lemma -- TODO confirm against the [OptimizerStep] record. *)
  Definition rec_if_else_empty_to_filter_step {A:Set}
    := mkOptimizerStep
         "rec/if/empty"
         ""
         "rec_if_else_empty_to_filter"
         (@rec_if_else_empty_to_filter A) .

  Definition rec_remove_map_singletoncoll_flatten {A: Set}
             (e: dnnrc (type_annotation A) dataset):
    dnnrc (type_annotation A) dataset :=
    match e with
    | DNNRCUnop t1 AFlatten
               (DNNRCFor t2 x xs
                        (DNNRCUnop t3 AColl e)) ⇒
      DNNRCFor t1 x xs e
    | _e
    end.

  Definition rec_remove_map_singletoncoll_flatten_step {A:Set}
    := mkOptimizerStep
         "flatten/for/coll"
         "Simplifes flatten of a for comprehension that creates singleton bags"
         "rec_remove_map_singletoncoll_flatten"
         (@rec_remove_map_singletoncoll_flatten A) .

  Definition rec_for_to_select {A: Set}
             (e: dnnrc (type_annotation A) dataset):
    dnnrc (type_annotation A) dataset :=
    match e with
    | DNNRCFor t1 x (DNNRCCollect t2 xs) body
      match lift_tlocal (di_typeof body) with
      
      | Some String
        
        match condition_to_column body (x, CCol "abc") with
        | Some body'
          
          let ALG_type := Tdistr String in
          let ALG :=
              DNNRCAlg (ta_mk (ta_base t1) ALG_type)
                      (DSSelect (("value"%string, body')::nil) (DSVar "for_to_select"))
                      (("for_to_select"%string, xs)::nil)
          in
          DNNRCCollect t1 ALG
        | Nonee
        end
      | _e
      end
    | _e
    end.

  (** Packages [rec_for_to_select] as an optimizer step via
      [mkOptimizerStep].  The three strings are presumably the step name, a
      description (still a "???" placeholder) and the name of the underlying
      function/lemma -- TODO confirm against the [OptimizerStep] record. *)
  Definition rec_for_to_select_step {A:Set}
    := mkOptimizerStep
         "rec for to select"
         "???"
         "rec_for_to_select"
         (@rec_for_to_select A) .


  (** The optimizer steps available for dataset rewriting; this is the list
      handed to [run_phase] by [run_dnnrc_optims]. *)
  Definition dnnrc_optim_list {A} :
    list (OptimizerStep (dnnrc (type_annotation A) dataset))
    := [
        rec_cast_to_filter_step
        ; rec_lift_unbrand_step
        ; rec_if_else_empty_to_filter_step
        ; rec_remove_map_singletoncoll_flatten_step
        ; rec_for_to_select_step
      ].

  (** States that [dnnrc_optim_list] satisfies [optim_list_distinct]
      (presumably: no two steps share a name -- verify against the
      definition of [optim_list_distinct]). *)
  Lemma dnnrc_optim_list_distinct {A:Set}:
    optim_list_distinct (@dnnrc_optim_list A).

  Definition run_dnnrc_optims {A}
             {logger:optimizer_logger string (dnnrc (type_annotation A) dataset)}
             (phaseName:string)
             (optims:list string)
             (iterationsBetweenCostCheck:nat)
    : dnnrc (type_annotation A) dataset dnnrc (type_annotation A) dataset :=
    run_phase dnnrc_map_deep (dnnrc_size dataset_size) dnnrc_optim_list
              ("[dnnrc] " ++ phaseName) optims iterationsBetweenCostCheck.

  (** Names of the steps enabled by default, obtained with [optim_step_name]
      from each step instantiated at [unit].  The names appear in the reverse
      of the order used in [dnnrc_optim_list]. *)
  Definition dnnrc_default_optim_list : list string
    := [
        optim_step_name (@rec_for_to_select_step unit)
          ; optim_step_name (@rec_remove_map_singletoncoll_flatten_step unit)
          ; optim_step_name (@rec_if_else_empty_to_filter_step unit)
          ; optim_step_name (@rec_lift_unbrand_step unit)
          ; optim_step_name (@rec_cast_to_filter_step unit)
        ].

  (** States that [valid_optims] applied to [dnnrc_optim_list] and the
      default names returns the default list paired with [nil] (presumably
      the second component collects names that match no step -- verify
      against [valid_optims]). *)
  Remark dnnrc_default_optim_list_all_valid {A:Set}
    : valid_optims (@dnnrc_optim_list A) dnnrc_default_optim_list = (dnnrc_default_optim_list,nil).

  (** Default dataset rewrite: [run_dnnrc_optims] with an empty phase name,
      the default step names [dnnrc_default_optim_list], and 6 as the
      [iterationsBetweenCostCheck] argument. *)
  Definition dnnrcToDatasetRewrite {A:Set}
             {logger:optimizer_logger string (dnnrc (type_annotation A) dataset)}
    := run_dnnrc_optims "" dnnrc_default_optim_list 6.

End DNNRCDatasetRewrites.