Module NNRCMROptimizer



Section NNRCMROptimizer.
  Require Import String.
  Require Import List ListSet.
  Require Import Arith.
  Require Import Equivalence.
  Require Import Morphisms.
  Require Import Setoid.
  Require Import EquivDec.
  Require Import Program.

  Require Import Utils BasicSystem.
  Require Import cNNRCSystem.
  Require Import NNRCOptim.
  Require Import OptimizerLogger.
  Require Import OptimizerStep.

  Require Import NNRCMR NNRCMRRewrite.
  Require Import ForeignReduceOps.
  Definition trew_nnrcmr
             {fruntime:foreign_runtime} {fredop:foreign_reduce_op} {logger:optimizer_logger string nnrc}
             (l: nnrcmr) :=
    let inputs_loc := l.(mr_inputs_loc) in
    let chain :=
        List.map
          (fun mr =>
             let map :=
                 match mr.(mr_map) with
                 | MapDist (x, n) => MapDist (x, run_nnrc_optims_default n)
                 | MapDistFlatten (x, n) => MapDistFlatten (x, run_nnrc_optims_default n)
                 | MapScalar (x, n) => MapScalar (x, run_nnrc_optims_default n)
                 end
             in
             let reduce :=
                 match mr.(mr_reduce) with
                 | RedId => RedId
                 | RedCollect (x, n) => RedCollect (x, run_nnrc_optims_default n)
                 | RedOp op => RedOp op
                 | RedSingleton => RedSingleton
                 end
             in
             mkMR mr.(mr_input) mr.(mr_output) map reduce)
          l.(mr_chain)
    in
    let last :=
        let '((params, n), args) := l.(mr_last) in
        ((params, run_nnrc_optims_default n), args)
    in
    mkMRChain
      inputs_loc
      chain
      last.

  Definition run_nnrcmr_optims
             {fruntime:foreign_runtime} {fredop:foreign_reduce_op} {logger:optimizer_logger string nnrc}
             q :=
    let q := trew_nnrcmr (mr_optimize q) in
    trew_nnrcmr q.
  
End NNRCMROptimizer.