(** * Qcert.Backend.DNNRCtoScala *)


Section Helpers.
  (** A one-character string holding a line feed (character code 10);
      it terminates the lines of the generated Scala program. *)
  Definition eol : string :=
    let newline := Ascii.ascii_of_nat 10 in
    String newline EmptyString.

End Helpers.




Section DNNRCtoScala.

  (* Section parameters.  Every definition below that uses one of these,
     directly or through another definition, is parameterized by it as an
     implicit argument once the section is closed.  The declaration order
     below therefore fixes the order of those arguments.
     - f: the foreign runtime;
     - h, m: the brand relation and brand model;
     - ftype, fdtyping, fboptyping, fuoptyping: foreign types and the typing
       of foreign data and foreign binary/unary operators;
     - fttjs, fts: foreign-to-JavaScript and foreign-to-Scala instances
       (presumably the source of the foreign_to_scala_* functions used
       below; verify). *)
  Context {f:foreign_runtime}.
  Context {h:brand_relation_t}.
  Context {ftype:foreign_type}.
  Context {m:brand_model}.
  Context {fdtyping:foreign_data_typing}.
  Context {fboptyping:foreign_binary_op_typing}.
  Context {fuoptyping:foreign_unary_op_typing}.
  Context {fttjs: ForeignToJavascript.foreign_to_javascript}.
  Context {fts: ForeignToScala.foreign_to_scala}.

  Definition quote_string (s: string) : string :=
    """" ++ s ++ """".

  (** Get code for a Spark DataType corresponding to an rtype.
      These are things like StringType, ArrayType(...),
      StructType(Seq(StructField(...), ...)), ...

      This function encodes details of our data representation, e.g. records
      are StructTypes with two top-level fields: $blob, a string
      serialization of the whole record, and $known, a StructType holding the
      statically known fields and their types. *)

  Fixpoint rtype_to_spark_DataType (r: rtype₀) : string :=
    match r with
    | Bottom₀ ⇒ "NullType"
    | Top₀ ⇒ "StringType"
    | Unit₀ ⇒ "NullType"
    | Nat₀ ⇒ "LongType"
    | Bool₀ ⇒ "BooleanType"
    | String₀ ⇒ "StringType"
    | Coll₀ e ⇒ "ArrayType(" ++ rtype_to_spark_DataType e ++ ")"
    | Rec₀ _ fields
      let known_fields: list string :=
          map (fun p ⇒ "StructField(""" ++ fst p ++ """, " ++ rtype_to_spark_DataType (snd p) ++ ")")
              fields in
      let known_struct := "StructType(Seq(" ++ joinStrings ", " known_fields ++ "))" in
      "StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_struct ++ ")))"
    | Either₀ l r
      "StructType(Seq(StructField(""$left"", " ++ rtype_to_spark_DataType l ++ "), StructField(""$right"", " ++ rtype_to_spark_DataType r ++ ")))"
    | Brand₀ _
      "StructType(Seq(StructField(""$data"", StringType), StructField(""$type"", ArrayType(StringType))))"
    
    | Arrow₀ _ _ ⇒ "ARROW TYPE?"
    | Foreign₀ ftforeign_to_scala_spark_datatype ft
    end.

  (** Scala-level type of an rtype.
      These are things like Long, String, Boolean, Array[...] and Record.

      We need to annotate some expressions with Scala-level types
      (e.g. Array[Record]() for an empty array of records) to help
      the Scala compiler, because it does not infer types everywhere. *)

  Fixpoint rtype_to_scala_type (t: rtype₀): string :=
    match t with
    | Bottom₀ ⇒ "BOTTOM?"
    | Top₀ ⇒ "TOP?"
    | Unit₀ ⇒ "Unit"
    | Nat₀ ⇒ "Long"
    | Bool₀ ⇒ "Boolean"
    | String₀ ⇒ "String"
    | Coll₀ r ⇒ "Array[" ++ rtype_to_scala_type r ++ "]"
    | Rec₀ _ _ ⇒ "Record"
    | Either₀ tl tr ⇒ "Either"
    | Brand₀ bs ⇒ "BrandedValue"
    | Arrow₀ tin t ⇒ "CANNOT PUT AN ARROW INTO A DATASET"
    | Foreign₀ f ⇒ "FOREIGN?"
    end.

  Definition drtype_to_scala (t: drtype): string :=
    match t with
    | Tlocal rrtype_to_scala_type (proj1_sig r)
    | Tdistr r ⇒ "Dataset[" ++ rtype_to_scala_type (proj1_sig r) ++ "]"
    end.

  Fixpoint scala_literal_data (d: data) (t: rtype₀) {struct t}: string :=
    match t, d with
    | Unit₀, d ⇒ "()"
    | Nat₀, dnat iZ_to_string10 i
    | Bool₀, dbool true ⇒ "true"
    | Bool₀, dbool false ⇒ "false"
    | String₀, dstring squote_string s
    | Coll₀ r, dcoll xs
      let element_type := rtype_to_scala_type r in
      let elements := map (fun xscala_literal_data x r) xs in
      "Array[" ++ element_type ++ "](" ++ joinStrings ", " elements ++ ")"
    | Rec₀ _ fts, drec fds
      let blob := quote_string (data_to_blob d) in
      let known_schema :=
          "StructType(Seq("
            ++ joinStrings ", "
            (map (fun ft
                    "StructField(""" ++ fst ft ++ """, " ++ rtype_to_spark_DataType (snd ft) ++ ")")
                 fts)
            ++ "))" in
      let fields := map (fun ftmatch lookup string_dec fds (fst ft) with
                                   | Some dscala_literal_data d (snd ft)
                                   | None ⇒ "FIELD_IN_TYPE_BUT_NOT_IN_DATA"
                                   end) fts in
      let known := "srow(" ++ joinStrings ", " (known_schema :: fields) ++ ")" in
      "srow(StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_schema ++ "))), " ++ blob ++ ", " ++ known ++ ")"
    
    | _, _ ⇒ "UNIMPLEMENTED_SCALA_LITERAL_DATA"
    end.

  Fixpoint code_of_column (c: column) : string :=
    match c with
    | CCol s ⇒ "column(""" ++ s ++ """)"
    | CDot fld ccode_of_column c ++ ".getField(" ++ quote_string fld ++ ")"
    | CEq c1 c2code_of_column c1 ++ ".equalTo(" ++ code_of_column c2 ++ ")"
    | CLessThan c1 c2code_of_column c1 ++ ".lt(" ++ code_of_column c2 ++ ")"
    | CLit (d, r) ⇒ "lit(" ++ scala_literal_data d r ++ ")"
    | CNeg c ⇒ "not(" ++ code_of_column c ++ ")"
    | CPlus c1 c2code_of_column c1 ++ ".plus(" ++ code_of_column c2 ++ ")"
    | CSConcat c1 c2
      "concat(" ++ code_of_column c1 ++ ", " ++ code_of_column c2 ++ ")"
    | CToString c
      "toQcertStringUDF(" ++ code_of_column c ++ ")"
    | CUDFCast bs c
      "castUDF(" ++ joinStrings ", " ("HIERARCHY"%string :: map quote_string bs) ++ ")(" ++ code_of_column c ++ ")"
    | CUDFUnbrand t c
      "unbrandUDF(" ++ rtype_to_spark_DataType t ++ ")(" ++ code_of_column c ++ ")"
    end.

  Fixpoint code_of_dataset (e: dataset) : string :=
    match e with
    | DSVar ss
    | DSSelect cs d
      let columns :=
          map (fun nccode_of_column (snd nc) ++ ".as(""" ++ fst nc ++ """)") cs in
      code_of_dataset d ++ ".select(" ++ joinStrings ", " columns ++ ")"
    | DSFilter c dcode_of_dataset d ++ ".filter(" ++ code_of_column c ++ ")"
    | DSCartesian d1 d2code_of_dataset d1 ++ ".join(" ++ (code_of_dataset d2) ++ ")"
    | DSExplode s d1code_of_dataset d1 ++ ".select(explode(" ++ code_of_column (CCol s) ++ ").as(""" ++ s ++ """))"
    end.

  Definition spark_of_unop (op: unaryOp) (x: string) : string :=
    match op with
      
      | ACountx ++ ".count()"
      | ASumx ++ ".select(sum(""value"")).first().getLong(0)"
      
      
      | AFlattenx ++ ".flatMap(r => r)"
      | _ ⇒ "SPARK_OF_UNOP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_unop (required_type: drtype) (op: unaryOp) (x: string) : string :=
    let prefix s := s ++ "(" ++ x ++ ")" in
    let postfix s := x ++ "." ++ s in
    match op with
    | AArithMeanprefix "arithMean"
    | ABrand bs ⇒ "brand(" ++ joinStrings ", " (x::(map quote_string bs)) ++ ")"
    | ACast bs
      "cast(HIERARCHY, " ++ x ++ ", " ++ joinStrings ", " (map quote_string bs) ++ ")"
    | ACollprefix "Array"
    | ACountpostfix "length"
    | ADot n
      match lift_tlocal required_type with
      | Some r
        prefix ("dot[" ++ rtype_to_scala_type (proj1_sig r) ++ "](""" ++ n ++ """)")
      | None ⇒ "NONLOCAL EXPECTED TYPE IN DOT"
      end
    | AFlattenpostfix "flatten.sorted(QcertOrdering)"
    | AIdOpprefix "identity"
    | ALeftprefix "left"
    | ANeg ⇒ "(!" ++ x ++ ")"
    | ANumMaxprefix "anummax"
    | ANumMinprefix "anummin"
    | ARec n
      match lift_tlocal required_type with
      | Some (exist _ (Rec₀ Closed ((_, ft)::nil)) _) ⇒
        "singletonRecord(" ++ quote_string n ++ ", " ++ rtype_to_spark_DataType ft ++ ", " ++ x ++ ")"
      | _ ⇒ "AREC_WITH_UNEXPECTED_REQUIRED_TYPE"
      end
    | ARecProject fsprefix ("recProject(" ++ joinStrings ", " (map quote_string fs) ++ ")")
    | ARightprefix "right"
    | ASumpostfix "sum"
    | AToStringprefix "toQcertString"
    | ASubstring start olen
      "(" ++ x ++ ").substring(" ++ toString start ++
          match olen with
          | Some len ⇒ ", " ++ toString len
          | None ⇒ ""
          end ++ ")"
    | ALike pat oescape
      "ALike currently implemented. Please implement as in the java backend"

    | AUArith ArithAbsprefix "Math.abs"
    | AUnbrand
      match lift_tlocal required_type with
      | Some (exist _ r _) ⇒
        let schema := rtype_to_spark_DataType r in
        let scala := rtype_to_scala_type r in
        "unbrand[" ++ scala ++ "](" ++ schema ++ ", " ++ x ++ ")"
      | None ⇒ "UNBRAND_REQUIRED_TYPE_ISSUE"
      end
    | ADistinctpostfix "distinct"
    | AOrderBy scl ⇒ "SORT???"
    | AForeignUnaryOp oforeign_to_scala_unary_op o x

    
    | ARecRemove _ ⇒ "ARECREMOVE???"
    | ASingleton ⇒ "SINGLETON???"
    | AUArith ArithLog2 ⇒ "LOG2???"
    | AUArith ArithSqrt ⇒ "SQRT???"
    end.

  Definition spark_of_binop (op: binOp) (x: string) (y: string) : string :=
    match op with
    | AUnionx ++ ".union(" ++ y ++ ")"
    | _ ⇒ "SPARK_OF_BINOP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_binop (op: binOp) (l: string) (r: string) : string :=
    
    let infix s := l ++ "." ++ s ++ "(" ++ r ++ ")" in
    let infix' s := "(" ++ l ++ " " ++ s ++ " " ++ r ++ ")" in
    let prefix s := s ++ "(" ++ l ++ ", " ++ r ++ ")" in
    match op with
    | AAndinfix "&&"
    | ABArith ArithDivideinfix "/"
    | ABArith ArithMaxinfix "max"
    | ABArith ArithMininfix "min"
    | ABArith ArithMinusinfix "-"
    | ABArith ArithMultinfix "*"
    | ABArith ArithPlusinfix "+"
    | ABArith ArithReminfix "%"
    | AConcatprefix "recordConcat"
    | AContainsprefix "AContains"
    
    | AEqprefix "equal"
    | ALeprefix "lessOrEqual"
    | ALtprefix "lessThan"
    
    | AMaxl ++ ".++(" ++ r ++ ".diff(" ++ l ++ "))"
    | AMinl ++ ".diff(" ++ l ++ ".diff(" ++ r ++ "))"
    | AMinusr ++ ".diff(" ++ l ++ ")"
    | AMergeConcatprefix "mergeConcat"
    | AOrinfix "||"
    | ASConcatinfix "+"
    | AUnioninfix "++" ++ ".sorted(QcertOrdering)"

    
    | AForeignBinaryOp op ⇒ "FOREIGNBINARYOP???"
    end.

  Definition primitive_type (t: rtype) :=
    match proj1_sig t with
    | | | Unit₀ | Nat₀ | String₀ | Bool₀true
    | Coll₀ _ | Rec₀ _ _ | Either₀ _ _ | Arrow₀ _ _ | Brand₀ _false
    
    | Foreign₀ _false
    end.

  Fixpoint scala_of_dnnrc {A: Set} (d: dnnrc (type_annotation A) dataset) : string :=
    let code :=
        match d with
        | DNNRCVar t nn
        | DNNRCConst t c
          match (lift_tlocal (di_required_typeof d)) with
          | Some rscala_literal_data c (proj1_sig r)
          | None ⇒ "Don't know how to construct a distributed constant"
          end
        | DNNRCBinop t op x y
          match di_typeof x, di_typeof y with
          | Tlocal _, Tlocal _scala_of_binop op (scala_of_dnnrc x) (scala_of_dnnrc y)
          | Tdistr _, Tdistr _spark_of_binop op (scala_of_dnnrc x) (scala_of_dnnrc y)
          | _, _ ⇒ "DONT_SUPPORT_MIXED_LOCAL_DISTRIBUTED_BINARY_OPERATORS"
          end
        | DNNRCUnop t op x
          match di_typeof x with
          | Tlocal _scala_of_unop (di_required_typeof d) op (scala_of_dnnrc x)
          | Tdistr _spark_of_unop op (scala_of_dnnrc x)
          end
        | DNNRCLet t n x y
          "((( " ++ n ++ ": " ++ drtype_to_scala (di_typeof x) ++ ") => " ++ scala_of_dnnrc y ++ ")(" ++ scala_of_dnnrc x ++ "))"
        | DNNRCFor t n x y
          
          scala_of_dnnrc x ++ ".map((" ++ n ++ ") => { " ++ scala_of_dnnrc y ++ " })"
        | DNNRCIf t x y z
          "(if (" ++ scala_of_dnnrc x ++ ") " ++ scala_of_dnnrc y ++ " else " ++ scala_of_dnnrc z ++ ")"
        | DNNRCEither t x vy y vz z
          match olift tuneither (lift_tlocal (di_required_typeof x)) with
          | Some (lt, rt)
            "either(" ++ scala_of_dnnrc x ++ ", (" ++
                      vy ++ ": " ++ rtype_to_scala_type (proj1_sig lt) ++
                      ") => { " ++ scala_of_dnnrc y ++ " }, (" ++
                      vz ++ ": " ++ rtype_to_scala_type (proj1_sig rt) ++
                      ") => { " ++ scala_of_dnnrc z ++ " })"
          | None ⇒ "DNNRCEither's first argument is not of type Either"
          end
        
        | DNNRCGroupBy t g sl x
          "DNNRC_GROUPBY_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        | DNNRCCollect t x
          
          let postfix :=
              match olift tuncoll (lift_tdistr (di_typeof x)) with
              | Some rt
                match proj1_sig rt with
                | Nat₀ ⇒ ".map((row) => row.getLong(0))"
                | Foreign₀ _ ⇒ ".map((row) => row.getFloat(0))"
                | _ ⇒ ""
                         
                end
              | None ⇒ "ARGUMENT_TO_COLLECT_SHOULD_BE_A_DISTRIBUTED_COLLECTION"
              end in
          scala_of_dnnrc x ++ ".collect()" ++ postfix
        
        | DNNRCDispatch t x
          match olift tuncoll (lift_tlocal (di_typeof x)) with
          | Some et
            "dispatch(" ++ rtype_to_spark_DataType (proj1_sig et) ++ ", " ++ scala_of_dnnrc x ++ ")"
          | None ⇒ "Argument to dispatch is not a local collection."
          end
        | DNNRCAlg t a ((x, d)::nil) ⇒
          
          "{ val " ++ x ++ " = " ++ scala_of_dnnrc d ++ "; " ++
                   code_of_dataset a
                   ++ " }"
        | DNNRCAlg _ _ _
          "NON_UNARY_ALG_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        end in
    if di_typeof d == di_required_typeof d then code else
      match lift_tlocal (di_required_typeof d) with
      
      | Some r ⇒ "identity/*CAST*/(" ++ code ++ ")"
      | None ⇒ "CANTCASTTODISTRIBUTEDTYPE"
      end.

  Definition initBrandHierarchy : string :=
    let lines :=
        map (fun p ⇒ "(""" ++ fst p ++ """ -> """ ++ snd p ++ """)")
            brand_relation_brands in
    joinStrings ", " lines.

  (** Top-level entry point for Spark 2 / Scala code generation. *)


  Definition type_name_of_var (var:string) : string :=
    var ++ "$TYPE".

  Definition scala_type_of_tbinding (bind:string × rtype) :=
    "val " ++ (type_name_of_var (fst bind)) ++ " = " ++ rtype_to_spark_DataType (proj1_sig (snd bind)) ++ eol.

  Definition scala_var_of_tbinding (bind:string × rtype) :=
    "val " ++ (fst bind) ++ " = sparkSession.read.schema(" ++ (type_name_of_var (fst bind)) ++ ").json(args(0))" ++ eol.

  Definition dnnrcToSpark2Top {A : Set} (tenv:tdbindings) (name: string)
             (e: dnnrc (type_annotation A) dataset) : string :=
    
    let inputType := lookup equiv_dec tenv "CONST$WORLD"%string in
    ""
      ++ "import org.apache.spark.SparkContext" ++ eol
      ++ "import org.apache.spark.sql.functions._" ++ eol
      ++ "import org.apache.spark.sql.SparkSession" ++ eol
      ++ "import org.apache.spark.sql.types._" ++ eol
      ++ "import org.qcert.QcertRuntime" ++ eol
      ++ "import org.qcert.QcertRuntime._" ++ eol

      ++ "object " ++ name ++ " {" ++ eol
      ++ "def main(args: Array[String]): Unit = {" ++ eol
      
      
      ++ (joinStrings "" (map scala_type_of_tbinding (unlocalize_tdbindings tenv)))
      ++ "val HIERARCHY = QcertRuntime.makeHierarchy(" ++ initBrandHierarchy ++ ")" ++ eol
      ++ "val sparkContext = new SparkContext()" ++ eol
      ++ "val sparkSession = SparkSession.builder().getOrCreate()" ++ eol
      
      
      ++ (joinStrings "" (map scala_var_of_tbinding (unlocalize_tdbindings tenv)))
      ++ "import sparkSession.implicits._" ++ eol
      ++ "QcertRuntime.beforeQuery()" ++ eol
      ++ "println(QcertRuntime.toBlob(" ++ eol
      ++ scala_of_dnnrc e ++ eol
      ++ "))" ++ eol
      ++ "QcertRuntime.afterQuery()" ++ eol
      ++ "sparkContext.stop()" ++ eol
      ++ "}" ++ eol
      ++ "}"
  .

End DNNRCtoScala.