(* Module Qcert.Translation.tDNNRCtoSparkDF *)


(* The library is required before the section is opened: a [Require] issued
   inside a section is fragile and makes Coq emit its require-in-section
   warning. *)
Require Import String.

Section Helpers.
  (** [eol] is the one-character string holding ASCII code 10 (line feed). *)
  Definition eol : string := String (Ascii.ascii_of_nat 10) EmptyString.
End Helpers.

Require Import List.
Require Import String.
Require Import Peano_dec.
Require Import EquivDec.
Require Import Utils.
Require Import CommonSystem.
Require Import NNRCRuntime.
Require Import tDNNRCSystem.
Require Import ForeignToScala.
Require Import DatatoSparkDF.

Local Open Scope string_scope.

Section tDNNRCtoSparkDF.

  (** Section-wide implicit parameters. Definitions below that use any of
      these instances are generalized over them when the section is closed. *)
  Context {f:foreign_runtime}.
  Context {h:brand_relation_t}.
  Context {ftype:foreign_type}.
  Context {m:brand_model}.
  Context {fdtyping:foreign_data_typing}.
  Context {fboptyping:foreign_binary_op_typing}.
  Context {fuoptyping:foreign_unary_op_typing}.
  Context {fttjs: ForeignToJavaScript.foreign_to_javascript}.
  Context {fts: ForeignToScala.foreign_to_scala}.
  Context {ftjson:foreign_to_JSON}.

  Definition quote_string (s: string) : string :=
    """" ++ s ++ """".

  (** Get code for a Spark DataType corresponding to an rtype.
   *
   * These are things like StringType, ArrayType(...), StructType(Seq(StructField(...), ...)), ...
   *
   * This function encodes details of our data representation, e.g. records are StructTypes
   * with two toplevel fields $blob and $known, and the $known contains a record with the
   * statically known fields and their types. *)
  Fixpoint rtype_to_spark_DataType (r: rtype₀) : string :=
    match r with
    | Bottom₀ => "NullType"
    | Top₀ => "StringType"
    | Unit₀ => "NullType"
    | Nat₀ => "LongType"
    | Bool₀ => "BooleanType"
    | String₀ => "StringType"
    | Colle => "ArrayType(" ++ rtype_to_spark_DataType e ++ ")"
    | Rec_ fields =>
      let known_fields: list string :=
          map (fun p => "StructField(""" ++ fst p ++ """, " ++ rtype_to_spark_DataType (snd p) ++ ")")
              fields in
      let known_struct := "StructType(Seq(" ++ joinStrings ", " known_fields ++ "))" in
      "StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_struct ++ ")))"
    | Eitherl r =>
      "StructType(Seq(StructField(""$left"", " ++ rtype_to_spark_DataType l ++ "), StructField(""$right"", " ++ rtype_to_spark_DataType r ++ ")))"
    | Brand_ =>
      "StructType(Seq(StructField(""$data"", StringType), StructField(""$type"", ArrayType(StringType))))"
    | Arrow_ _ => "ARROW TYPE?"
    | Foreignft => foreign_to_scala_spark_datatype ft
    end.

  (** Scala-level type of an rtype.
   *
   * These are things like Long, String, Boolean, Array[...], Row.
   *
   * We need to annotate some expressions with Scala-level types
   * (e.g. Array[Row]() for an empty Array of Records) to help
   * the Scala compiler because it does not infer types everywhere. *)
  Fixpoint rtype_to_scala_type (t: rtype₀): string :=
    match t with
    | Bottom₀ => "BOTTOM?"
    | Top₀ => "TOP?"
    | Unit₀ => "Unit"
    | Nat₀ => "Long"
    | Bool₀ => "Boolean"
    | String₀ => "String"
    | Collr => "Array[" ++ rtype_to_scala_type r ++ "]"
    | Rec_ _ => "Record"
    | Eithertl tr => "Either"
    | Brandbs => "BrandedValue"
    | Arrowtin t => "CANNOT PUT AN ARROW INTO A DATAFRAME"
    | Foreignf => "FOREIGN?"
    end.

  Definition drtype_to_scala (t: drtype): string :=
    match t with
    | Tlocal r => rtype_to_scala_type (proj1_sig r)
    | Tdistr r => "Dataset[" ++ rtype_to_scala_type (proj1_sig r) ++ "]"
    end.

  Fixpoint scala_literal_data (d: data) (t: rtype₀) {struct t}: string :=
    match t, d with
    | Unit₀, d => "()"
    | Nat₀, dnat i => Z_to_string10 i
    | Bool₀, dbool true => "true"
    | Bool₀, dbool false => "false"
    | String₀, dstring s => quote_string s
    | Collr, dcoll xs =>
      let element_type := rtype_to_scala_type r in
      let elements := map (fun x => scala_literal_data x r) xs in
      "Array[" ++ element_type ++ "](" ++ joinStrings ", " elements ++ ")"
    | Rec_ fts, drec fds =>
      let blob := quote_string (data_to_blob d) in
      let known_schema :=
          "StructType(Seq("
            ++ joinStrings ", "
            (map (fun ft =>
                    "StructField(""" ++ fst ft ++ """, " ++ rtype_to_spark_DataType (snd ft) ++ ")")
                 fts)
            ++ "))" in
      let fields := map (fun ft => match lookup string_dec fds (fst ft) with
                                   | Some d => scala_literal_data d (snd ft)
                                   | None => "FIELD_IN_TYPE_BUT_NOT_IN_DATA"
                                   end) fts in
      let known := "srow(" ++ joinStrings ", " (known_schema :: fields) ++ ")" in
      "srow(StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_schema ++ "))), " ++ blob ++ ", " ++ known ++ ")"
    | _, _ => "UNIMPLEMENTED_SCALA_LITERAL_DATA"
    end.

  Fixpoint code_of_column (c: column) : string :=
    match c with
    | CCol s => "column(""" ++ s ++ """)"
    | CDot fld c => code_of_column c ++ ".getField(" ++ quote_string fld ++ ")"
    | CEq c1 c2 => code_of_column c1 ++ ".equalTo(" ++ code_of_column c2 ++ ")"
    | CLessThan c1 c2 => code_of_column c1 ++ ".lt(" ++ code_of_column c2 ++ ")"
    | CLit (d, r) => "lit(" ++ scala_literal_data d r ++ ")"
    | CNeg c => "not(" ++ code_of_column c ++ ")"
    | CPlus c1 c2 => code_of_column c1 ++ ".plus(" ++ code_of_column c2 ++ ")"
    | CSConcat c1 c2 =>
      "concat(" ++ code_of_column c1 ++ ", " ++ code_of_column c2 ++ ")"
    | CToString c =>
      "toQcertStringUDF(" ++ code_of_column c ++ ")"
    | CUDFCast bs c =>
      "castUDF(" ++ joinStrings ", " ("HIERARCHY"%string :: map quote_string bs) ++ ")(" ++ code_of_column c ++ ")"
    | CUDFUnbrand t c =>
      "unbrandUDF(" ++ rtype_to_spark_DataType t ++ ")(" ++ code_of_column c ++ ")"
    end.

  Fixpoint code_of_dataframe (e: dataframe) : string :=
    match e with
    | DSVar s => s
    | DSSelect cs d =>
      let columns :=
          map (fun nc => code_of_column (snd nc) ++ ".as(""" ++ fst nc ++ """)") cs in
      code_of_dataframe d ++ ".select(" ++ joinStrings ", " columns ++ ")"
    | DSFilter c d => code_of_dataframe d ++ ".filter(" ++ code_of_column c ++ ")"
    | DSCartesian d1 d2 => code_of_dataframe d1 ++ ".join(" ++ (code_of_dataframe d2) ++ ")"
    | DSExplode s d1 => code_of_dataframe d1 ++ ".select(explode(" ++ code_of_column (CCol s) ++ ").as(""" ++ s ++ """))"
    end.

  Definition spark_of_unary_op (op: unary_op) (x: string) : string :=
    match op with
      | OpCount => x ++ ".count()"
      | OpSum => x ++ ".select(sum(""value"")).first().getLong(0)"
      | OpFlatten => x ++ ".flatMap(r => r)"
      | _ => "SPARK_OF_UNARY_OP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_unary_op (required_type: drtype) (op: unary_op) (x: string) : string :=
    let prefix s := s ++ "(" ++ x ++ ")" in
    let postfix s := x ++ "." ++ s in
    match op with
    | OpIdentity => prefix "identity"
    | OpNeg => "(!" ++ x ++ ")"
    | OpRec n =>
      match lift_tlocal required_type with
      | Some (exist _ (RecClosed ((_, ft)::nil)) _) =>
        "singletonRecord(" ++ quote_string n ++ ", " ++ rtype_to_spark_DataType ft ++ ", " ++ x ++ ")"
      | _ => "AREC_WITH_UNEXPECTED_REQUIRED_TYPE"
      end
    | OpDot n =>
      match lift_tlocal required_type with
      | Some r =>
        prefix ("dot[" ++ rtype_to_scala_type (proj1_sig r) ++ "](""" ++ n ++ """)")
      | None => "NONLOCAL EXPECTED TYPE IN DOT"
      end
    | OpRecProject fs => prefix ("recProject(" ++ joinStrings ", " (map quote_string fs) ++ ")")
    | OpBag => prefix "Array"
    | OpFlatten => postfix "flatten.sorted(QcertOrdering)"
    | OpDistinct => postfix "distinct"
    | OpCount => postfix "length"
    | OpSum => postfix "sum"
    | OpNumMax => prefix "anummax"
    | OpNumMin => prefix "anummin"
    | OpNumMean => prefix "arithMean"
    | OpToString => prefix "toQcertString"
    | OpSubstring start olen =>
      "(" ++ x ++ ").substring(" ++ toString start ++
          match olen with
          | Some len => ", " ++ toString len
          | None => ""
          end ++ ")"
    | OpLike pat oescape =>
      "ALike currently implemented. Please implement as in the java backend"
    | OpLeft => prefix "left"
    | OpRight => prefix "right"
    | OpBrand bs => "brand(" ++ joinStrings ", " (x::(map quote_string bs)) ++ ")"
    | OpCast bs =>
      "cast(HIERARCHY, " ++ x ++ ", " ++ joinStrings ", " (map quote_string bs) ++ ")"
    | OpUnbrand =>
      match lift_tlocal required_type with
      | Some (exist _ r _) =>
        let schema := rtype_to_spark_DataType r in
        let scala := rtype_to_scala_type r in
        "unbrand[" ++ scala ++ "](" ++ schema ++ ", " ++ x ++ ")"
      | None => "UNBRAND_REQUIRED_TYPE_ISSUE"
      end
    | OpArithUnary ArithAbs => prefix "Math.abs"
    | OpForeignUnary o => foreign_to_scala_unary_op o x

    | OpOrderBy scl => "SORT???"
    | OpRecRemove _ => "ARECREMOVE???"
    | OpSingleton => "SINGLETON???"
    | OpArithUnary ArithLog2 => "LOG2???"
    | OpArithUnary ArithSqrt => "SQRT???"
    end.

  Definition spark_of_binary_op (op: binary_op) (x: string) (y: string) : string :=
    match op with
    | OpBagUnion => x ++ ".union(" ++ y ++ ")"
    | _ => "SPARK_OF_BINARY_OP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_binary_op (op: binary_op) (l: string) (r: string) : string :=
    let infix s := l ++ "." ++ s ++ "(" ++ r ++ ")" in
    let infix' s := "(" ++ l ++ " " ++ s ++ " " ++ r ++ ")" in
    let prefix s := s ++ "(" ++ l ++ ", " ++ r ++ ")" in
    match op with
    | OpEqual => prefix "equal"
    | OpRecConcat => prefix "recordConcat"
    | OpRecMerge => prefix "mergeConcat"
    | OpAnd => infix "&&"
    | OpOr => infix "||"
    | OpLe => prefix "lessOrEqual"
    | OpLt => prefix "lessThan"
    | OpBagUnion => infix "++" ++ ".sorted(QcertOrdering)"
    | OpBagDiff => r ++ ".diff(" ++ l ++ ")"
    | OpBagMax => l ++ ".++(" ++ r ++ ".diff(" ++ l ++ "))"
    | OpBagMin => l ++ ".diff(" ++ l ++ ".diff(" ++ r ++ "))"
    | OpContains => prefix "AContains"
    | OpStringConcat => infix "+"
    | OpArithBinary ArithDivide => infix "/"
    | OpArithBinary ArithMax => infix "max"
    | OpArithBinary ArithMin => infix "min"
    | OpArithBinary ArithMinus => infix "-"
    | OpArithBinary ArithMult => infix "*"
    | OpArithBinary ArithPlus => infix "+"
    | OpArithBinary ArithRem => infix "%"

    | OpForeignBinary op => "FOREIGNBINARYOP???"
    end.

  Definition primitive_type (t: rtype) :=
    match proj1_sig t with
    | ⊥₀ | ⊤₀ | Unit₀ | Nat₀ | String₀ | Bool₀ => true
    | Coll_ | Rec_ _ | Either_ _ | Arrow_ _ | Brand_ => false
    | Foreign_ => false
    end.

  Fixpoint scala_of_dnnrc_base {A: Set} (d:@dnnrc_base _ (type_annotation A) dataframe) : string :=
    let code :=
        match d with
        | DNNRCGetConstant t n => n
        | DNNRCVar t n => n
        | DNNRCConst t c =>
          match (lift_tlocal (di_required_typeof d)) with
          | Some r => scala_literal_data c (proj1_sig r)
          | None => "Don't know how to construct a distributed constant"
          end
        | DNNRCBinop t op x y =>
          match di_typeof x, di_typeof y with
          | Tlocal _, Tlocal _ => scala_of_binary_op op (scala_of_dnnrc_base x) (scala_of_dnnrc_base y)
          | Tdistr _, Tdistr _ => spark_of_binary_op op (scala_of_dnnrc_base x) (scala_of_dnnrc_base y)
          | _, _ => "DONT_SUPPORT_MIXED_LOCAL_DISTRIBUTED_BINARY_OPERATORS"
          end
        | DNNRCUnop t op x =>
          match di_typeof x with
          | Tlocal _ => scala_of_unary_op (di_required_typeof d) op (scala_of_dnnrc_base x)
          | Tdistr _ => spark_of_unary_op op (scala_of_dnnrc_base x)
          end
        | DNNRCLet t n x y =>
          "((( " ++ n ++ ": " ++ drtype_to_scala (di_typeof x) ++ ") => " ++ scala_of_dnnrc_base y ++ ")(" ++ scala_of_dnnrc_base x ++ "))"
        | DNNRCFor t n x y =>
          scala_of_dnnrc_base x ++ ".map((" ++ n ++ ") => { " ++ scala_of_dnnrc_base y ++ " })"
        | DNNRCIf t x y z =>
          "(if (" ++ scala_of_dnnrc_base x ++ ") " ++ scala_of_dnnrc_base y ++ " else " ++ scala_of_dnnrc_base z ++ ")"
        | DNNRCEither t x vy y vz z =>
          match olift tuneither (lift_tlocal (di_required_typeof x)) with
          | Some (lt, rt) =>
            "either(" ++ scala_of_dnnrc_base x ++ ", (" ++
                      vy ++ ": " ++ rtype_to_scala_type (proj1_sig lt) ++
                      ") => { " ++ scala_of_dnnrc_base y ++ " }, (" ++
                      vz ++ ": " ++ rtype_to_scala_type (proj1_sig rt) ++
                      ") => { " ++ scala_of_dnnrc_base z ++ " })"
          | None => "DNNRCEither's first argument is not of type Either"
          end
        | DNNRCGroupBy t g sl x =>
          "DNNRC_GROUPBY_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        | DNNRCCollect t x =>
          let postfix :=
              match olift tuncoll (lift_tdistr (di_typeof x)) with
              | Some rt =>
                match proj1_sig rt with
                | Nat₀ => ".map((row) => row.getLong(0))"
                | Foreign_ => ".map((row) => row.getFloat(0))"
                | _ => ""
                end
              | None => "ARGUMENT_TO_COLLECT_SHOULD_BE_A_DISTRIBUTED_COLLECTION"
              end in
          scala_of_dnnrc_base x ++ ".collect()" ++ postfix
        | DNNRCDispatch t x =>
          match olift tuncoll (lift_tlocal (di_typeof x)) with
          | Some et =>
            "dispatch(" ++ rtype_to_spark_DataType (proj1_sig et) ++ ", " ++ scala_of_dnnrc_base x ++ ")"
          | None => "Argument to dispatch is not a local collection."
          end
        | DNNRCAlg t a ((x, d)::nil) =>
          "{ val " ++ x ++ " = " ++ scala_of_dnnrc_base d ++ "; " ++
                   code_of_dataframe a
                   ++ " }"
        | DNNRCAlg _ _ _ =>
          "NON_UNARY_ALG_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        end in
    if di_typeof d == di_required_typeof d then code else
      match lift_tlocal (di_required_typeof d) with
      | Some r => "identity/*CAST*/(" ++ code ++ ")"
      | None => "CANTCASTTODISTRIBUTEDTYPE"
      end.

  Definition initBrandHierarchy : string :=
    let lines :=
        map (fun p => "(""" ++ fst p ++ """ -> """ ++ snd p ++ """)")
            brand_relation_brands in
    joinStrings ", " lines.

  Fixpoint emitGlobals (tenv: tdbindings) (fileArgCounter: nat) :=
    match tenv with
    | nil => ""
    | (name, Tlocal lt) :: rest => "LOCAL INPUT UNIMPLEMENTED"
    | (name, Tdistr elt) :: rest =>
      "val " ++ name ++ " = sparkSession.read.schema(" ++ rtype_to_spark_DataType (proj1_sig elt) ++ ").json(args(" ++ nat_to_string10 fileArgCounter ++ "))" ++ eol ++
      emitGlobals rest (fileArgCounter + 1)
    end.

  (** Toplevel entry to Spark2/Scala codegen *)
  Definition dnnrc_typed_to_spark_df_top {A : Set} (tenv:tdbindings) (name: string)
             (e:dnnrc_typed) : string :=
    ""
      ++ "import org.apache.spark.SparkContext" ++ eol
      ++ "import org.apache.spark.sql.functions._" ++ eol
      ++ "import org.apache.spark.sql.SparkSession" ++ eol
      ++ "import org.apache.spark.sql.types._" ++ eol
      ++ "import org.qcert.QcertRuntime" ++ eol
      ++ "import org.qcert.QcertRuntime._" ++ eol

      ++ "object " ++ name ++ " {" ++ eol
      ++ "def main(args: Array[String]): Unit = {" ++ eol
      ++ "val HIERARCHY = QcertRuntime.makeHierarchy(" ++ initBrandHierarchy ++ ")" ++ eol
      ++ "val sparkContext = new SparkContext()" ++ eol
      ++ "org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.WARN)" ++ eol
      ++ "val sparkSession = SparkSession.builder().getOrCreate()" ++ eol
      ++ emitGlobals tenv 0
      ++ "import sparkSession.implicits._" ++ eol
      ++ "QcertRuntime.beforeQuery()" ++ eol
      ++ "println(QcertRuntime.toBlob(" ++ eol
      ++ scala_of_dnnrc_base e ++ eol
      ++ "))" ++ eol
      ++ "QcertRuntime.afterQuery()" ++ eol
      ++ "sparkContext.stop()" ++ eol
      ++ "}" ++ eol
      ++ "}" ++ eol
  .

End tDNNRCtoSparkDF.