(* Module Qcert.Translation.Lang.tDNNRCtoSparkDF *)


Require Import String.

Require Import List.
Require Import Peano_dec.
Require Import EquivDec.
Require Import Utils.
Require Import DataSystem.
Require Import NNRCRuntime.
Require Import tDNNRCSystem.
Require Import ForeignToScala.
Require Import ForeignEJson.
Require Import ForeignDataToEJson.
Require Import SparkDFRuntime.
Require Import DatatoSparkDF.

Local Open Scope string_scope.

Section tDNNRCtoSparkDF.

  Section Helpers.
    Definition eol := (String.String (Ascii.ascii_of_nat 10) EmptyString).
  End Helpers.

  (* Section-wide implicit parameters. Definitions in this section are
     generalized over the ones they use when the section is closed. *)

  (* Foreign runtime instance. *)
  Context {fruntime:foreign_runtime}.
  (* Foreign EJson model type, its instance, and the type of foreign EJson
     runtime operators; [fdatatoejson] relates foreign data to that model.
     NOTE(review): presumably needed by [data_to_blob] -- confirm. *)
  Context {foreign_ejson_model:Set}.
  Context {fejson:foreign_ejson foreign_ejson_model}.
  Context {foreign_ejson_runtime_op : Set}.
  Context {fdatatoejson:foreign_to_ejson foreign_ejson_model foreign_ejson_runtime_op}.
  (* Brand relation; [initBrandInheritance] reads [brand_relation_brands]. *)
  Context {h:brand_relation_t}.
  (* Foreign types and the brand model used by the typed DNNRC layer. *)
  Context {ftype:foreign_type}.
  Context {m:brand_model}.
  (* Scala code generation hooks for foreign types and foreign unary
     operators ([foreign_to_scala_spark_datatype],
     [foreign_to_scala_unary_op]). *)
  Context {fts: ForeignToScala.foreign_to_scala}.

  Definition quote_string (s: string) : string :=
    """" ++ s ++ """".

  (** Get code for a Spark DataType corresponding to an rtype.
   *
   * These are things like StringType, ArrayType(...), StructType(Seq(StructField(...), ...)), ...
   *
   * This function encodes details of our data representation, e.g. records are StructTypes
   * with two toplevel fields $blob and $known, and the $known contains a record with the
   * statically known fields and their types. *)
  Fixpoint rtype_to_spark_DataType (r: rtype₀) : string :=
    match r with
    | Bottom₀ => "NullType"
    | Top₀ => "StringType"
    | Unit₀ => "NullType"
    | Nat₀ => "LongType"
    | Float₀ => "DoubleType"
    | Bool₀ => "BooleanType"
    | String₀ => "StringType"
    | Colle => "ArrayType(" ++ rtype_to_spark_DataType e ++ ")"
    | Rec_ fields =>
      let known_fields: list string :=
          map (fun p => "StructField(""" ++ fst p ++ """, " ++ rtype_to_spark_DataType (snd p) ++ ")")
              fields in
      let known_struct := "StructType(Seq(" ++ String.concat ", " known_fields ++ "))" in
      "StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_struct ++ ")))"
    | Eitherl r =>
      "StructType(Seq(StructField(""$left"", " ++ rtype_to_spark_DataType l ++ "), StructField(""$right"", " ++ rtype_to_spark_DataType r ++ ")))"
    | Brand_ =>
      "StructType(Seq(StructField(""$class"", ArrayType(StringType)), StructField(""$data"", StringType)))"
    | Arrow_ _ => "ARROW TYPE?"
    | Foreignft => foreign_to_scala_spark_datatype ft
    end.

  (** Scala-level type of an rtype.
   *
   * These are things like Long, String, Boolean, Array[...], Row.
   *
   * We need to annotate some expressions with Scala-level types
   * (e.g. Array[Row]() for an empty Array of Records) to help
   * the Scala compiler because it does not infer types everywhere. *)
  Fixpoint rtype_to_scala_type (t: rtype₀): string :=
    match t with
    | Bottom₀ => "BOTTOM?"
    | Top₀ => "TOP?"
    | Unit₀ => "Unit"
    | Nat₀ => "Long"
    | Float₀ => "Double"
    | Bool₀ => "Boolean"
    | String₀ => "String"
    | Collr => "Array[" ++ rtype_to_scala_type r ++ "]"
    | Rec_ _ => "Record"
    | Eithertl tr => "Either"
    | Brandbs => "BrandedValue"
    | Arrowtin t => "CANNOT PUT AN ARROW INTO A DATAFRAME"
    | Foreignf => "FOREIGN?"
    end.

  Definition drtype_to_scala (t: drtype): string :=
    match t with
    | Tlocal r => rtype_to_scala_type (proj1_sig r)
    | Tdistr r => "Dataset[" ++ rtype_to_scala_type (proj1_sig r) ++ "]"
    end.

  Fixpoint scala_literal_data (d: data) (t: rtype₀) {struct t}: string :=
    match t, d with
    | Unit₀, d => "()"
    | Nat₀, dnat i => Z_to_string10 i
    | Bool₀, dbool true => "true"
    | Bool₀, dbool false => "false"
    | String₀, dstring s => quote_string (s)
    | Collr, dcoll xs =>
      let element_type := rtype_to_scala_type r in
      let elements := map (fun x => scala_literal_data x r) xs in
      "Array[" ++ element_type ++ "](" ++ String.concat ", " elements ++ ")"
    | Rec_ fts, drec fds =>
      let blob := quote_string (data_to_blob d) in
      let known_schema :=
          "StructType(Seq("
            ++ String.concat ", "
            (map (fun ft =>
                    "StructField(""" ++ fst ft ++ """, " ++ rtype_to_spark_DataType (snd ft) ++ ")")
                 fts)
            ++ "))" in
      let fields := map (fun ft => match lookup string_dec fds (fst ft) with
                                   | Some d => scala_literal_data d (snd ft)
                                   | None => "FIELD_IN_TYPE_BUT_NOT_IN_DATA"
                                   end) fts in
      let known := "srow(" ++ String.concat ", " (known_schema :: fields) ++ ")" in
      "srow(StructType(Seq(StructField(""$blob"", StringType), StructField(""$known"", " ++ known_schema ++ "))), " ++ blob ++ ", " ++ known ++ ")"
    | _, _ => "UNIMPLEMENTED_SCALA_LITERAL_DATA"
    end.

  Fixpoint code_of_column (c: column) : string :=
    match c with
    | CCol s => "column(""" ++ s ++ """)"
    | CDot fld c => code_of_column c ++ ".getField(" ++ quote_string (fld) ++ ")"
    | CEq c1 c2 => code_of_column c1 ++ ".equalTo(" ++ code_of_column c2 ++ ")"
    | CLessThan c1 c2 => code_of_column c1 ++ ".lt(" ++ code_of_column c2 ++ ")"
    | CLit (d, r) => "lit(" ++ scala_literal_data d r ++ ")"
    | CNeg c => "not(" ++ code_of_column c ++ ")"
    | CPlus c1 c2 => code_of_column c1 ++ ".plus(" ++ code_of_column c2 ++ ")"
    | CSConcat c1 c2 =>
      "concat(" ++ code_of_column c1 ++ ", " ++ code_of_column c2 ++ ")"
    | CToString c =>
      "toQcertStringUDF(" ++ code_of_column c ++ ")"
    | CUDFCast bs c =>
      "castUDF(" ++ String.concat ", " ("INHERITANCE"%string :: map quote_string bs) ++ ")(" ++ code_of_column c ++ ")"
    | CUDFUnbrand t c =>
      "unbrandUDF(" ++ rtype_to_spark_DataType t ++ ")(" ++ code_of_column c ++ ")"
    end.

  Fixpoint code_of_dataframe (e: dataframe) : string :=
    match e with
    | DSVar s => s
    | DSSelect cs d =>
      let columns :=
          map (fun nc => code_of_column (snd nc) ++ ".as(""" ++ fst nc ++ """)") cs in
      code_of_dataframe d ++ ".select(" ++ String.concat ", " columns ++ ")"
    | DSFilter c d => code_of_dataframe d ++ ".filter(" ++ code_of_column c ++ ")"
    | DSCartesian d1 d2 => code_of_dataframe d1 ++ ".join(" ++ (code_of_dataframe d2) ++ ")"
    | DSExplode s d1 => code_of_dataframe d1 ++ ".select(explode(" ++ code_of_column (CCol s) ++ ").as(""" ++ s ++ """))"
    end.

  Definition spark_of_unary_op (op: unary_op) (x: string) : string :=
    match op with
      | OpCount => x ++ ".count()"
      | OpNatSum => x ++ ".select(sum(""value"")).first().getLong(0)"
      | OpFlatten => x ++ ".flatMap(r => r)"
      | _ => "SPARK_OF_UNARY_OP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_unary_op (required_type: drtype) (op: unary_op) (x: string) : string :=
    let prefix s := s ++ "(" ++ x ++ ")" in
    let postfix s := x ++ "." ++ s in
    match op with
    | OpIdentity => prefix "identity"
    | OpNeg => "(!" ++ x ++ ")"
    | OpRec n =>
      match lift_tlocal required_type with
      | Some (exist _ (RecClosed ((_, ft)::nil)) _) =>
        "singletonRecord(" ++ quote_string (n) ++ ", " ++ rtype_to_spark_DataType ft ++ ", " ++ x ++ ")"
      | _ => "AREC_WITH_UNEXPECTED_REQUIRED_TYPE"
      end
    | OpDot n =>
      match lift_tlocal required_type with
      | Some r =>
        prefix ("dot[" ++ rtype_to_scala_type (proj1_sig r) ++ "](""" ++ n ++ """)")
      | None => "NONLOCAL EXPECTED TYPE IN DOT"
      end
    | OpRecProject fs => prefix ("recProject(" ++ String.concat ", " (map quote_string fs) ++ ")")
    | OpBag => prefix "Array"
    | OpFlatten => postfix "flatten.sorted(QcertOrdering)"
    | OpDistinct => postfix "distinct"
    | OpCount => postfix "length"
    | OpToText
    | OpToString => prefix "toQcertString"
    | OpLength => "(" ++ x ++ ").length()"
    | OpSubstring start olen =>
      "(" ++ x ++ ").substring(" ++ toString start ++
          match olen with
          | Some len => ", " ++ toString len
          | None => ""
          end ++ ")"
    | OpLike pat =>
      "ALike currently implemented. Please implement as in the java backend"
    | OpLeft => prefix "left"
    | OpRight => prefix "right"
    | OpBrand bs => "brand(" ++ String.concat ", " (x::bs) ++ ")"
    | OpCast bs =>
      "cast(INHERITANCE, " ++ x ++ ", " ++ String.concat ", " bs ++ ")"
    | OpUnbrand =>
      match lift_tlocal required_type with
      | Some (exist _ r _) =>
        let schema := rtype_to_spark_DataType r in
        let scala := rtype_to_scala_type r in
        "unbrand[" ++ scala ++ "](" ++ schema ++ ", " ++ x ++ ")"
      | None => "UNBRAND_REQUIRED_TYPE_ISSUE"
      end
    | OpNatUnary NatAbs => prefix "Math.abs"
    | OpNatSum => postfix "sum"
    | OpNatMax => prefix "anummax"
    | OpNatMin => prefix "anummin"
    | OpNatMean => prefix "arithMean"
    | OpForeignUnary o => foreign_to_scala_unary_op o x
    | OpFloatSum => postfix "sum"

    | OpFloatOfNat => prefix "FLOAT OF NAT??"
    | OpFloatUnary _ => prefix "FLOAT ARITH??"
    | OpFloatTruncate => prefix "TRUNCATE??"
    | OpFloatBagMax => prefix "MAX??"
    | OpFloatBagMin => prefix "MIN??"
    | OpFloatMean => prefix "MEAN??"
    | OpOrderBy scl => "SORT???"
    | OpRecRemove _ => "ARECREMOVE???"
    | OpSingleton => "SINGLETON???"
    | OpNatUnary NatLog2 => "LOG2???"
    | OpNatUnary NatSqrt => "SQRT???"
    end.

  Definition spark_of_binary_op (op: binary_op) (x: string) (y: string) : string :=
    match op with
    | OpBagUnion => x ++ ".union(" ++ y ++ ")"
    | _ => "SPARK_OF_BINARY_OP don't know how to generate Spark code for this operator"
    end.

  Definition scala_of_binary_op (op: binary_op) (l: string) (r: string) : string :=
    let infix s := l ++ "." ++ s ++ "(" ++ r ++ ")" in
    let infix' s := "(" ++ l ++ " " ++ s ++ " " ++ r ++ ")" in
    let prefix s := s ++ "(" ++ l ++ ", " ++ r ++ ")" in
    match op with
    | OpEqual => prefix "equal"
    | OpRecConcat => prefix "recordConcat"
    | OpRecMerge => prefix "mergeConcat"
    | OpAnd => infix "&&"
    | OpOr => infix "||"
    | OpLe => prefix "lessOrEqual"
    | OpLt => prefix "lessThan"
    | OpBagUnion => infix "++" ++ ".sorted(QcertOrdering)"
    | OpBagDiff => r ++ ".diff(" ++ l ++ ")"
    | OpBagMax => l ++ ".++(" ++ r ++ ".diff(" ++ l ++ "))"
    | OpBagMin => l ++ ".diff(" ++ l ++ ".diff(" ++ r ++ "))"
    | OpBagNth => prefix "bagNth"
    | OpContains => prefix "AContains"
    | OpStringConcat => infix "+"
    | OpStringJoin => prefix "AStringJoin"
    | OpNatBinary NatDiv => infix "/"
    | OpNatBinary NatMax => infix "max"
    | OpNatBinary NatMin => infix "min"
    | OpNatBinary NatMinus => infix "-"
    | OpNatBinary NatMult => infix "*"
    | OpNatBinary NatPlus => infix "+"
    | OpNatBinary NatRem => infix "%"
    | OpFloatBinary FloatDiv => infix "/"
    | OpFloatBinary FloatMax => infix "max"
    | OpFloatBinary FloatMin => infix "min"
    | OpFloatBinary FloatMinus => infix "-"
    | OpFloatBinary FloatMult => infix "*"
    | OpFloatBinary FloatPlus => infix "+"
    | OpFloatBinary FloatRem => infix "%"
    | OpFloatCompare FloatLt => infix "<"
    | OpFloatCompare FloatLe => infix "<="
    | OpFloatCompare FloatGt => infix ">"
    | OpFloatCompare FloatGe => infix ">="

    | OpForeignBinary op => "FOREIGNBINARYOP???"
    end.

  Definition primitive_type (t: rtype) :=
    match proj1_sig t with
    | ⊥₀ | ⊤₀ | Unit₀ | Nat₀ | Float₀ | String₀ | Bool₀ => true
    | Coll_ | Rec_ _ | Either_ _ | Arrow_ _ | Brand_ => false
    | Foreign_ => false
    end.

  Fixpoint scala_of_dnnrc_base {A: Set} (d:@dnnrc_base _ (type_annotation A) dataframe) : string :=
    let code :=
        match d with
        | DNNRCGetConstant t n => n
        | DNNRCVar t n => n
        | DNNRCConst t c =>
          match (lift_tlocal (di_required_typeof d)) with
          | Some r => scala_literal_data c (proj1_sig r)
          | None => "Don't know how to construct a distributed constant"
          end
        | DNNRCBinop t op x y =>
          match di_typeof x, di_typeof y with
          | Tlocal _, Tlocal _ => scala_of_binary_op op (scala_of_dnnrc_base x) (scala_of_dnnrc_base y)
          | Tdistr _, Tdistr _ => spark_of_binary_op op (scala_of_dnnrc_base x) (scala_of_dnnrc_base y)
          | _, _ => "DONT_SUPPORT_MIXED_LOCAL_DISTRIBUTED_BINARY_OPERATORS"
          end
        | DNNRCUnop t op x =>
          match di_typeof x with
          | Tlocal _ => scala_of_unary_op (di_required_typeof d) op (scala_of_dnnrc_base x)
          | Tdistr _ => spark_of_unary_op op (scala_of_dnnrc_base x)
          end
        | DNNRCLet t n x y =>
          "((( " ++ n ++ ": " ++ drtype_to_scala (di_typeof x) ++ ") => " ++ scala_of_dnnrc_base y ++ ")(" ++ scala_of_dnnrc_base x ++ "))"
        | DNNRCFor t n x y =>
          scala_of_dnnrc_base x ++ ".map((" ++ n ++ ") => { " ++ scala_of_dnnrc_base y ++ " })"
        | DNNRCIf t x y z =>
          "(if (" ++ scala_of_dnnrc_base x ++ ") " ++ scala_of_dnnrc_base y ++ " else " ++ scala_of_dnnrc_base z ++ ")"
        | DNNRCEither t x vy y vz z =>
          match olift tuneither (lift_tlocal (di_required_typeof x)) with
          | Some (lt, rt) =>
            "either(" ++ scala_of_dnnrc_base x ++ ", (" ++
                      vy ++ ": " ++ rtype_to_scala_type (proj1_sig lt) ++
                      ") => { " ++ scala_of_dnnrc_base y ++ " }, (" ++
                      vz ++ ": " ++ rtype_to_scala_type (proj1_sig rt) ++
                      ") => { " ++ scala_of_dnnrc_base z ++ " })"
          | None => "DNNRCEither's first argument is not of type Either"
          end
        | DNNRCGroupBy t g sl x =>
          "DNNRC_GROUPBY_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        | DNNRCCollect t x =>
          let postfix :=
              match olift tuncoll (lift_tdistr (di_typeof x)) with
              | Some rt =>
                match proj1_sig rt with
                | Nat₀ => ".map((row) => row.getLong(0))"
                | Foreign_ => ".map((row) => row.getFloat(0))"
                | _ => ""
                end
              | None => "ARGUMENT_TO_COLLECT_SHOULD_BE_A_DISTRIBUTED_COLLECTION"
              end in
          scala_of_dnnrc_base x ++ ".collect()" ++ postfix
        | DNNRCDispatch t x =>
          match olift tuncoll (lift_tlocal (di_typeof x)) with
          | Some et =>
            "dispatch(" ++ rtype_to_spark_DataType (proj1_sig et) ++ ", " ++ scala_of_dnnrc_base x ++ ")"
          | None => "Argument to dispatch is not a local collection."
          end
        | DNNRCAlg t a ((x, d)::nil) =>
          "{ val " ++ x ++ " = " ++ scala_of_dnnrc_base d ++ "; " ++
                   code_of_dataframe a
                   ++ " }"
        | DNNRCAlg _ _ _ =>
          "NON_UNARY_ALG_CODEGEN_IS_NOT_CURRENTLY_IMPLEMENTED"
        end in
    if di_typeof d == di_required_typeof d then code else
      match lift_tlocal (di_required_typeof d) with
      | Some r => "identity/*CAST*/(" ++ code ++ ")"
      | None => "CANTCASTTODISTRIBUTEDTYPE"
      end.

  Definition initBrandInheritance : string :=
    let lines :=
        map (fun p => "(""" ++ fst p ++ """ -> """ ++ snd p ++ """)")
            brand_relation_brands in
    String.concat ", " lines.

  Fixpoint emitGlobals (tenv: tdbindings) (fileArgCounter: nat) : string :=
    match tenv with
    | nil => ""
    | (name, Tlocal lt) :: rest => "LOCAL INPUT UNIMPLEMENTED"
    | (name, Tdistr elt) :: rest =>
      "val " ++ name ++ " = sparkSession.read.schema(" ++ rtype_to_spark_DataType (proj1_sig elt) ++ ").json(args(" ++ nat_to_string10 fileArgCounter ++ "))" ++ eol ++
      emitGlobals rest (fileArgCounter + 1)
    end.

  (** Toplevel entry to Spark2/Scala codegen *)
  Definition dnnrc_typed_to_spark_df_top {A : Set} (tenv:tdbindings) (name: string)
             (e:dnnrc_typed) : spark_df :=
    ""
      ++ "import org.apache.spark.SparkContext" ++ eol
      ++ "import org.apache.spark.sql.functions._" ++ eol
      ++ "import org.apache.spark.sql.SparkSession" ++ eol
      ++ "import org.apache.spark.sql.types._" ++ eol
      ++ "import org.qcert.QcertRuntime" ++ eol
      ++ "import org.qcert.QcertRuntime._" ++ eol

      ++ "object " ++ name ++ " {" ++ eol
      ++ "def main(args: Array[String]): Unit = {" ++ eol
      ++ "val INHERITANCE = QcertRuntime.makeInheritance(" ++ initBrandInheritance ++ ")" ++ eol
      ++ "val sparkContext = new SparkContext()" ++ eol
      ++ "org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.WARN)" ++ eol
      ++ "val sparkSession = SparkSession.builder().getOrCreate()" ++ eol
      ++ emitGlobals tenv 0
      ++ "import sparkSession.implicits._" ++ eol
      ++ "QcertRuntime.beforeQuery()" ++ eol
      ++ "println(QcertRuntime.toBlob(" ++ eol
      ++ scala_of_dnnrc_base e ++ eol
      ++ "))" ++ eol
      ++ "QcertRuntime.afterQuery()" ++ eol
      ++ "sparkContext.stop()" ++ eol
      ++ "}" ++ eol
      ++ "}" ++ eol
  .

End tDNNRCtoSparkDF.