
Section Helpers.
  (* Line terminator used when emitting generated source text: the
     one-character string whose only character has ASCII code 10. *)
  Definition eol : string :=
    let line_feed := Ascii.ascii_of_nat 10 in
    String line_feed EmptyString.

End Helpers.

Section DNNRCtoScala.

  (* Section-wide implicit parameters. Each [Context {x : T}] declares an
     implicit section variable [x] of type [T]; every definition in this
     section that depends on one of them is generalized over it when the
     section is closed. None of the types named here is defined in this
     section.
     NOTE(review): [fts] is presumably what provides
     [foreign_to_scala_spark_datatype] and [foreign_to_scala_unary_op],
     which are used below -- confirm against ForeignToScala. *)
  Context {f:foreign_runtime}.
  Context {h:brand_relation_t}.
  Context {ftype:foreign_type}.
  Context {m:brand_model}.
  Context {fdtyping:foreign_data_typing}.
  Context {fboptyping:foreign_binary_op_typing}.
  Context {fuoptyping:foreign_unary_op_typing}.
  Context {fttjs: ForeignToJavascript.foreign_to_javascript}.
  Context {fts: ForeignToScala.foreign_to_scala}.

  Definition quote_string (s: string) : string :=
    """" ++ s ++ """".

  (** Get code for a Spark DataType corresponding to an rtype.

      These are things like StringType, ArrayType(...), StructType(Seq(StructField(...), ...)), ...

      This function encodes details of our data representation, e.g. records are StructTypes
      with two top-level fields, $$blob and $$known, where $$known is a StructType holding the
      statically known fields and their types. *)

  Fixpoint rtype_to_spark_DataType (r: rtype₀) : string :=
    match r with
    | Bottom₀ ⇒ "NullType"
    | Top₀ ⇒ "StringType"
    | Unit₀ ⇒ "NullType"
    | Nat₀ ⇒ "LongType"
    | Bool₀ ⇒ "BooleanType"
    | String₀ ⇒ "StringType"
    | Coll₀ e ⇒ "ArrayType(" ++ rtype_to_spark_DataType e ++ ")"
    | Rec₀ _ fields
      let known_fields: list string :=
          map (fun p ⇒ "StructField(""" ++ fst p ++ """, " ++ rtype_to_spark_DataType (snd p) ++ ")")
              fields in
      let known_struct := "StructType(Seq(" ++ joinStrings ", " known_fields ++ "))" in
      "StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_struct ++ ")))"
    | Either₀ l r
      "StructType(Seq(StructField(""$left"", " ++ rtype_to_spark_DataType l ++ "), StructField(""$right"", " ++ rtype_to_spark_DataType r ++ ")))"
    | Brand₀ _
      "StructType(Seq(StructField(""$data"", StringType), StructField(""$type"", ArrayType(StringType))))"
    | Arrow₀ _ _ ⇒ "ARROW TYPE?"
    | Foreign₀ ftforeign_to_scala_spark_datatype ft

  (** Scala-level type of an rtype.

      These are things like Long, String, Boolean, Array of some element type, Record.

      We need to annotate some expressions with Scala-level types
      (e.g. the element type of an empty Array of Records) to help
      the Scala compiler because it does not infer types everywhere. *)

  Fixpoint rtype_to_scala_type (t: rtype₀): string :=
    match t with
    | Bottom₀ ⇒ "BOTTOM?"
    | Top₀ ⇒ "TOP?"
    | Unit₀ ⇒ "Unit"
    | Nat₀ ⇒ "Long"
    | Bool₀ ⇒ "Boolean"
    | String₀ ⇒ "String"
    | Coll₀ r ⇒ "Array[" ++ rtype_to_scala_type r ++ "]"
    | Rec₀ _ _ ⇒ "Record"
    | Either₀ tl tr ⇒ "Either"
    | Brand₀ bs ⇒ "BrandedValue"
    | Foreign₀ f ⇒ "FOREIGN?"

  Definition drtype_to_scala (t: drtype): string :=
    match t with
    | Tlocal rrtype_to_scala_type (proj1_sig r)
    | Tdistr r ⇒ "Dataset[" ++ rtype_to_scala_type (proj1_sig r) ++ "]"

  Fixpoint scala_literal_data (d: data) (t: rtype₀) {struct t}: string :=
    match t, d with
    | Unit₀, d ⇒ "()"
    | Nat₀, dnat iZ_to_string10 i
    | Bool₀, dbool true ⇒ "true"
    | Bool₀, dbool false ⇒ "false"
    | String₀, dstring squote_string s
    | Coll₀ r, dcoll xs
      let element_type := rtype_to_scala_type r in
      let elements := map (fun xscala_literal_data x r) xs in
      "Array[" ++ element_type ++ "](" ++ joinStrings ", " elements ++ ")"
    | Rec₀ _ fts, drec fds
      let blob := quote_string (data_to_blob d) in
      let known_schema :=
            ++ joinStrings ", "
            (map (fun ft
                    "StructField(""" ++ fst ft ++ """, " ++ rtype_to_spark_DataType (snd ft) ++ ")")
            ++ "))" in
      let fields := map (fun ftmatch lookup string_dec fds (fst ft) with
                                   | Some dscala_literal_data d (snd ft)
                                   | None ⇒ "FIELD_IN_TYPE_BUT_NOT_IN_DATA"
                                   end) fts in
      let known := "srow(" ++ joinStrings ", " (known_schema :: fields) ++ ")" in
      "srow(StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_schema ++ "))), " ++ blob ++ ", " ++ known ++ ")"

  Fixpoint code_of_column (c: column) : string :=
    match c with
    | CCol s ⇒ "column(""" ++ s ++ """)"
    | CDot fld ccode_of_column c ++ ".getField(" ++ quote_string fld ++ ")"
    | CEq c1 c2code_of_column c1 ++ ".equalTo(" ++ code_of_column c2 ++ ")"
    | CLessThan c1 c2code_of_column c1 ++ ".lt(" ++ code_of_column c2 ++ ")"
    | CLit (d, r) ⇒ "lit(" ++ scala_literal_data d r ++ ")"
    | CNeg c ⇒ "not(" ++ code_of_column c ++ ")"
    | CPlus c1 c2code_of_column c1 ++ ".plus(" ++ code_of_column c2 ++ ")"
    | CSConcat c1 c2
      "concat(" ++ code_of_column c1 ++ ", " ++ code_of_column c2 ++ ")"
    | CToString c
      "toQcertStringUDF(" ++ code_of_column c ++ ")"
    | CUDFCast bs c
      "castUDF(" ++ joinStrings ", " ("HIERARCHY"%string :: map quote_string bs) ++ ")(" ++ code_of_column c ++ ")"
    | CUDFUnbrand t c
      "unbrandUDF(" ++ rtype_to_spark_DataType t ++ ")(" ++ code_of_column c ++ ")"

  Fixpoint code_of_dataset (e: dataset) : string :=
    match e with
    | DSVar ss
    | DSSelect cs d
      let columns :=
          map (fun nccode_of_column (snd nc) ++ ".as(""" ++ fst nc ++ """)") cs in
      code_of_dataset d ++ ".select(" ++ joinStrings ", " columns ++ ")"
    | DSFilter c dcode_of_dataset d ++ ".filter(" ++ code_of_column c ++ ")"
    | DSCartesian d1 d2code_of_dataset d1 ++ ".join(" ++ (code_of_dataset d2) ++ ")"
    | DSExplode s d1code_of_dataset d1 ++ ".select(explode(" ++ code_of_column (CCol s) ++ ").as(""" ++ s ++ """))"

  Definition spark_of_unop (op: unaryOp) (x: string) : string :=
    match op with
      | ACountx ++ ".count()"
      | ASumx ++ ".select(sum(""value"")).first().getLong(0)"
      | AFlattenx ++ ".flatMap(r => r)"
      | _ ⇒ "SPARK_OF_UNOP don't know how to generate Spark code for this operator"

  Definition scala_of_unop (required_type: drtype) (op: unaryOp) (x: string) : string :=
    let prefix s := s ++ "(" ++ x ++ ")" in
    let postfix s := x ++ "." ++ s in
    match op with
    | AArithMeanprefix "arithMean"
    | ABrand bs ⇒ "brand(" ++ joinStrings ", " (x::(map quote_string bs)) ++ ")"
    | ACast bs
      "cast(HIERARCHY, " ++ x ++ ", " ++ joinStrings ", " (map quote_string bs) ++ ")"
    | ACollprefix "Array"
    | ACountpostfix "length"
    | ADot n
      match lift_tlocal required_type with
      | Some r
        prefix ("dot[" ++ rtype_to_scala_type (proj1_sig r) ++ "](""" ++ n ++ """)")
    | AFlattenpostfix "flatten.sorted(QcertOrdering)"
    | AIdOpprefix "identity"
    | ALeftprefix "left"
    | ANeg ⇒ "(!" ++ x ++ ")"
    | ANumMaxprefix "anummax"
    | ANumMinprefix "anummin"
    | ARec n
      match lift_tlocal required_type with
      | Some (exist _ (Rec₀ Closed ((_, ft)::nil)) _) ⇒
        "singletonRecord(" ++ quote_string n ++ ", " ++ rtype_to_spark_DataType ft ++ ", " ++ x ++ ")"
    | ARecProject fsprefix ("recProject(" ++ joinStrings ", " (map quote_string fs) ++ ")")
    | ARightprefix "right"
    | ASumpostfix "sum"
    | AToStringprefix "toQcertString"
    | ASubstring start olen
      "(" ++ x ++ ").substring(" ++ toString start ++
          match olen with
          | Some len ⇒ ", " ++ toString len
          | None ⇒ ""
          end ++ ")"
    | ALike pat oescape
      "ALike currently implemented. Please implement as in the java backend"

    | AUArith ArithAbsprefix "Math.abs"
    | AUnbrand
      match lift_tlocal required_type with
      | Some (exist _ r _) ⇒
        let schema := rtype_to_spark_DataType r in
        let scala := rtype_to_scala_type r in
        "unbrand[" ++ scala ++ "](" ++ schema ++ ", " ++ x ++ ")"
    | ADistinctpostfix "distinct"
    | AOrderBy scl ⇒ "SORT???"
    | AForeignUnaryOp oforeign_to_scala_unary_op o x

    | ARecRemove _ ⇒ "ARECREMOVE???"
    | ASingleton ⇒ "SINGLETON???"
    | AUArith ArithLog2 ⇒ "LOG2???"
    | AUArith ArithSqrt ⇒ "SQRT???"

  Definition spark_of_binop (op: binOp) (x: string) (y: string) : string :=
    match op with
    | AUnionx ++ ".union(" ++ y ++ ")"
    | _ ⇒ "SPARK_OF_BINOP don't know how to generate Spark code for this operator"

  Definition scala_of_binop (op: binOp) (l: string) (r: string) : string :=
    let infix s := l ++ "." ++ s ++ "(" ++ r ++ ")" in
    let infix' s := "(" ++ l ++ " " ++ s ++ " " ++ r ++ ")" in
    let prefix s := s ++ "(" ++ l ++ ", " ++ r ++ ")" in
    match op with
    | AAndinfix "&&"
    | ABArith ArithDivideinfix "/"
    | ABArith ArithMaxinfix "max"
    | ABArith ArithMininfix "min"
    | ABArith ArithMinusinfix "-"
    | ABArith ArithMultinfix "*"
    | ABArith ArithPlusinfix "+"
    | ABArith ArithReminfix "%"
    | AConcatprefix "recordConcat"
    | AContainsprefix "AContains"
    | AEqprefix "equal"
    | ALeprefix "lessOrEqual"
    | ALtprefix "lessThan"
    | AMaxl ++ ".++(" ++ r ++ ".diff(" ++ l ++ "))"
    | AMinl ++ ".diff(" ++ l ++ ".diff(" ++ r ++ "))"
    | AMinusr ++ ".diff(" ++ l ++ ")"
    | AMergeConcatprefix "mergeConcat"
    | AOrinfix "||"
    | ASConcatinfix "+"
    | AUnioninfix "++" ++ ".sorted(QcertOrdering)"

    | AForeignBinaryOp op ⇒ "FOREIGNBINARYOP???"

  Definition primitive_type (t: rtype) :=
    match proj1_sig t with
    | | | Unit₀ | Nat₀ | String₀ | Bool₀true
    | Coll₀ _ | Rec₀ _ _ | Either₀ _ _ | Arrow₀ _ _ | Brand₀ _false
    | Foreign₀ _false

  (* Scala source text for the type-annotated DNNRC expression [d].
     The local [code] is the text produced for [d] according to its
     constructor. If the inferred type of [d] equals its required type,
     [code] is returned unchanged; otherwise, when the required type is
     local, [code] is wrapped in an identity call tagged with a CAST marker.

     NOTE(review): this listing is incomplete. Every [match] has lost its
     closing [end], several arrows between pattern and body are missing, and
     the bodies of the DNNRCGroupBy case and of the general DNNRCAlg case are
     absent. The code is left untouched here; restore it from the upstream
     source rather than by guesswork. *)
  Fixpoint scala_of_dnnrc {A: Set} (d: dnnrc (type_annotation A) dataset) : string :=
    let code :=
        match d with
        (* A variable is emitted as its name (pattern and body are fused on this line). *)
        | DNNRCVar t nn
        (* Constants are only supported at a local required type. *)
        | DNNRCConst t c
          match (lift_tlocal (di_required_typeof d)) with
          | Some rscala_literal_data c (proj1_sig r)
          | None ⇒ "Don't know how to construct a distributed constant"
        (* Two local operands use scala_of_binop; two distributed operands use
           spark_of_binop.
           NOTE(review): no branch for mixed local/distributed operands is visible. *)
        | DNNRCBinop t op x y
          match di_typeof x, di_typeof y with
          | Tlocal _, Tlocal _scala_of_binop op (scala_of_dnnrc x) (scala_of_dnnrc y)
          | Tdistr _, Tdistr _spark_of_binop op (scala_of_dnnrc x) (scala_of_dnnrc y)
        (* A local operand uses scala_of_unop with the required type of [d];
           a distributed operand uses spark_of_unop. *)
        | DNNRCUnop t op x
          match di_typeof x with
          | Tlocal _scala_of_unop (di_required_typeof d) op (scala_of_dnnrc x)
          | Tdistr _spark_of_unop op (scala_of_dnnrc x)
        (* Let is emitted as a lambda over [n], annotated with the Scala type
           of [x], applied immediately to the code for [x]. *)
        | DNNRCLet t n x y
          "((( " ++ n ++ ": " ++ drtype_to_scala (di_typeof x) ++ ") => " ++ scala_of_dnnrc y ++ ")(" ++ scala_of_dnnrc x ++ "))"
        (* For is emitted as a map call over the code for [x]. *)
        | DNNRCFor t n x y
          scala_of_dnnrc x ++ ".map((" ++ n ++ ") => { " ++ scala_of_dnnrc y ++ " })"
        | DNNRCIf t x y z
          "(if (" ++ scala_of_dnnrc x ++ ") " ++ scala_of_dnnrc y ++ " else " ++ scala_of_dnnrc z ++ ")"
        (* Either needs the required type of [x] to be a local either type so
           that both lambda parameters can be annotated with a Scala type. *)
        | DNNRCEither t x vy y vz z
          match olift tuneither (lift_tlocal (di_required_typeof x)) with
          | Some (lt, rt)
            "either(" ++ scala_of_dnnrc x ++ ", (" ++
                      vy ++ ": " ++ rtype_to_scala_type (proj1_sig lt) ++
                      ") => { " ++ scala_of_dnnrc y ++ " }, (" ++
                      vz ++ ": " ++ rtype_to_scala_type (proj1_sig rt) ++
                      ") => { " ++ scala_of_dnnrc z ++ " })"
          | None ⇒ "DNNRCEither's first argument is not of type Either"
        (* NOTE(review): the body of the DNNRCGroupBy case is missing. *)
        | DNNRCGroupBy t g sl x
        (* Collect appends a collect call, followed by a row-unwrapping map
           when the distributed element type is Nat₀ (getLong) or Foreign₀
           (getFloat), and by nothing for other element types.
           NOTE(review): the branch for a non-distributed argument and one
           closing [end] are missing. *)
        | DNNRCCollect t x
          let postfix :=
              match olift tuncoll (lift_tdistr (di_typeof x)) with
              | Some rt
                match proj1_sig rt with
                | Nat₀ ⇒ ".map((row) => row.getLong(0))"
                | Foreign₀ _ ⇒ ".map((row) => row.getFloat(0))"
                | _ ⇒ ""
              end in
          scala_of_dnnrc x ++ ".collect()" ++ postfix
        (* Dispatch passes the Spark DataType of the local collection's
           element type together with the code for [x]. *)
        | DNNRCDispatch t x
          match olift tuncoll (lift_tlocal (di_typeof x)) with
          | Some et
            "dispatch(" ++ rtype_to_spark_DataType (proj1_sig et) ++ ", " ++ scala_of_dnnrc x ++ ")"
          | None ⇒ "Argument to dispatch is not a local collection."
        (* An algebra expression with exactly one binding: bind the code for
           the bound expression to a val named [x], then emit the dataset code. *)
        | DNNRCAlg t a ((x, d)::nil) ⇒
          "{ val " ++ x ++ " = " ++ scala_of_dnnrc d ++ "; " ++
                   code_of_dataset a
                   ++ " }"
        (* NOTE(review): the body of this remaining DNNRCAlg case is missing. *)
        | DNNRCAlg _ _ _
        end in
    (* Wrap [code] only when the inferred and required types differ.
       NOTE(review): the branch for a non-local required type and the closing
       [end.] are missing. *)
    if di_typeof d == di_required_typeof d then code else
      match lift_tlocal (di_required_typeof d) with
      | Some r ⇒ "identity/*CAST*/(" ++ code ++ ")"

  Definition initBrandHierarchy : string :=
    let lines :=
        map (fun p ⇒ "(""" ++ fst p ++ """ -> """ ++ snd p ++ """)")
            brand_relation_brands in
    joinStrings ", " lines.

  (** Top-level entry point for the Spark2/Scala code generator. *)

  Definition type_name_of_var (var:string) : string :=
    var ++ "$TYPE".

  Definition scala_type_of_tbinding (bind:string × rtype) :=
    "val " ++ (type_name_of_var (fst bind)) ++ " = " ++ rtype_to_spark_DataType (proj1_sig (snd bind)) ++ eol.

  Definition scala_var_of_tbinding (bind:string × rtype) :=
    "val " ++ (fst bind) ++ " = sparkSession.read.schema(" ++ (type_name_of_var (fst bind)) ++ ").json(args(0))" ++ eol.

  Definition dnnrcToSpark2Top {A : Set} (tenv:tdbindings) (name: string)
             (e: dnnrc (type_annotation A) dataset) : string :=
    let inputType := lookup equiv_dec tenv "CONST$WORLD"%string in
      ++ "import org.apache.spark.SparkContext" ++ eol
      ++ "import org.apache.spark.sql.functions._" ++ eol
      ++ "import org.apache.spark.sql.SparkSession" ++ eol
      ++ "import org.apache.spark.sql.types._" ++ eol
      ++ "import org.qcert.QcertRuntime" ++ eol
      ++ "import org.qcert.QcertRuntime._" ++ eol

      ++ "object " ++ name ++ " {" ++ eol
      ++ "def main(args: Array[String]): Unit = {" ++ eol
      ++ (joinStrings "" (map scala_type_of_tbinding (unlocalize_tdbindings tenv)))
      ++ "val HIERARCHY = QcertRuntime.makeHierarchy(" ++ initBrandHierarchy ++ ")" ++ eol
      ++ "val sparkContext = new SparkContext()" ++ eol
      ++ "val sparkSession = SparkSession.builder().getOrCreate()" ++ eol
      ++ (joinStrings "" (map scala_var_of_tbinding (unlocalize_tdbindings tenv)))
      ++ "import sparkSession.implicits._" ++ eol
      ++ "QcertRuntime.beforeQuery()" ++ eol
      ++ "println(QcertRuntime.toBlob(" ++ eol
      ++ scala_of_dnnrc e ++ eol
      ++ "))" ++ eol
      ++ "QcertRuntime.afterQuery()" ++ eol
      ++ "sparkContext.stop()" ++ eol
      ++ "}" ++ eol
      ++ "}"

End DNNRCtoScala.