// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir_adaptor/translator/op_translator.h"

#include <algorithm>
#include <cctype>
#include <functional>
#include <iterator>
#include <numeric>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"

namespace paddle {
namespace translator {

namespace {

// Position of a translated result: the index of the result within the
// operation, and the index of the element inside that result when it is a
// VectorType (0 is used for non-vector results).
using IdxInOp = size_t;
using IdxInVector = size_t;
using ResultIdx = std::tuple<IdxInOp, IdxInVector>;
using OpDesc = paddle::framework::OpDesc;
using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;
// Result types of a translated operation, in result order.
using OpOutputTypeList = std::vector<ir::Type>;
// Legacy output variable name -> position of the result that defines it.
using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
using OpInputInfo = dialect::OpInputInfo;
using OpInputInfoList = std::vector<dialect::OpInputInfo>;
using OpAttributeInfo = dialect::OpAttributeInfo;
using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>;
using OpOutputInfo = dialect::OpOutputInfo;
using OpOutputInfoList = std::vector<dialect::OpOutputInfo>;
// Hook that produces the operand for one yaml input, bypassing the generic
// input translation (see OpTranscriber::GetSpecialInputHandlers).
using InputHandlerFn = std::function<ir::OpResult(ir::IrContext*,
                                                  TranslationContext*,
                                                  const OpDesc&,
                                                  const std::string&,
                                                  const OpInputInfo&,
                                                  ir::Program*)>;
// Hook that produces one yaml attribute, bypassing the generic attribute
// translation (see OpTranscriber::GetSpecialAttributeHandlers).
using AttributeHandlerFn = std::function<ir::Attribute(
    ir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;
// Prepended to the normalized op name to form the registered op name.
constexpr char kTargetDialectPrefix[] = "pd.";  // NOLINT
// Legacy variable name marking an empty slot of a vector output; such slots
// are translated to a null element type.
constexpr char kEmptyVarName[] = "@EMPTY@";     // NOLINT

// Ops that are never translated to their inplace variant, whatever their
// inputs/outputs look like (currently empty).
static const std::unordered_set<std::string> SpecialNonInplaceOps = {};

// Ops that are always translated to their inplace (`_`-suffixed) variant.
static const std::unordered_set<std::string> SpecialInplaceOps = {
    "adagrad",
    "adam",
    "adamax",
    "adamw",
};

/// Decide whether `op_desc` should be translated to the inplace variant of its
/// op, i.e. the one whose name ends with `_`.
///
/// Ops listed in SpecialNonInplaceOps / SpecialInplaceOps are forced to the
/// corresponding answer. Otherwise an op counts as inplace when at least one
/// variable name appears both among its input arguments and among its output
/// arguments.
inline bool IsInplace(const OpDesc& op_desc) {
  if (SpecialNonInplaceOps.count(op_desc.Type())) {
    return false;
  }
  if (SpecialInplaceOps.count(op_desc.Type())) {
    return true;
  }
  auto input_names = op_desc.InputArgumentNames();
  auto output_names = op_desc.OutputArgumentNames();
  if (input_names.empty() || output_names.empty()) {
    return false;
  }

  std::vector<std::string> name_intersection;
  std::sort(input_names.begin(), input_names.end());
  std::sort(output_names.begin(), output_names.end());
  std::set_intersection(input_names.begin(),
                        input_names.end(),
                        output_names.begin(),
                        output_names.end(),
                        std::back_inserter(name_intersection));

  if (name_intersection.empty()) {
    return false;
  }

  // The comma-joined list is built inside the VLOG stream so that, with
  // glog's lazy VLOG, it is only computed when verbose level 4 is enabled.
  VLOG(4) << "Following variables occur both in inputs and outputs: "
          << std::accumulate(std::next(name_intersection.begin()),
                             name_intersection.end(),
                             name_intersection.front(),
                             [](std::string joined, const std::string& name) {
                               joined += ',';
                               joined += name;
                               return joined;
                             });
  return true;
}

// Translate a legacy (fluid) op type into its new-IR op name via the
// OpNameNormalizer registry. Callers add kTargetDialectPrefix themselves.
inline std::string OpNameCompatibleMapping(std::string op_name) {
  return OpNameNormalizer::instance()[op_name];
}

/// Append a builtin CombineOp to `program`'s block. The CombineOp packs the
/// current values of the variables named in `args` (looked up in `param_map`)
/// into a single VectorType value. Returns the created operation.
///
/// Every name in `args` is looked up with `param_map->at`.
inline ir::Operation* InsertCombineOperationForTarget(
    ir::IrContext* ctx,
    TranslationContext* param_map,
    ir::Program* program,
    const std::vector<std::string>& args) {
  std::string combine_op_name(ir::CombineOp::name());
  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);

  std::vector<ir::OpResult> src_values;
  std::vector<ir::Type> types_in_vec;
  src_values.reserve(args.size());
  types_in_vec.reserve(args.size());
  for (const auto& arg_name : args) {
    // Bind by reference: only `value` is read, no need to copy the whole
    // defining info for every argument.
    const auto& defining_info = param_map->at(arg_name);
    src_values.push_back(defining_info.value);
    types_in_vec.push_back(defining_info.value.type());
  }
  ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
  ir::Operation* operation =
      ir::Operation::Create(src_values, {}, {target_vec_type}, op_info);
  program->block()->push_back(operation);
  return operation;
}

/// Materialize a scalar attribute as a value: append a FullOp of shape [1],
/// placed on CPU, to `program`'s block and return the created operation.
///
/// The payload is always handed to FullOp as a float. The FullOp dtype follows
/// the attribute kind:
///   FloatAttribute -> FLOAT32, DoubleAttribute -> FLOAT64,
///   Int32Attribute -> INT32,   Int64Attribute  -> INT64,
///   BoolAttribute  -> BOOL,    ScalarAttribute -> FLOAT64.
/// Any other attribute kind yields value 0 with dtype UNDEFINED.
/// NOTE(review): the float payload loses precision for large int64/double
/// values, and the UNDEFINED fallback is not rejected here.
inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
                                                           ir::Program* program,
                                                           ir::Attribute attr) {
  // (payload, dtype) describing the FullOp to build.
  using FillSpec = std::tuple<float, phi::DataType>;
  auto [fill_value, fill_dtype] = [&attr]() -> FillSpec {
    if (attr.isa<ir::FloatAttribute>()) {
      return {attr.dyn_cast<ir::FloatAttribute>().data(),
              phi::DataType::FLOAT32};
    }
    if (attr.isa<ir::DoubleAttribute>()) {
      return {static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data()),
              phi::DataType::FLOAT64};
    }
    if (attr.isa<ir::Int32Attribute>()) {
      return {static_cast<float>(attr.dyn_cast<ir::Int32Attribute>().data()),
              phi::DataType::INT32};
    }
    if (attr.isa<ir::Int64Attribute>()) {
      return {static_cast<float>(attr.dyn_cast<ir::Int64Attribute>().data()),
              phi::DataType::INT64};
    }
    if (attr.isa<ir::BoolAttribute>()) {
      return {static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data()),
              phi::DataType::BOOL};
    }
    if (attr.isa<dialect::ScalarAttribute>()) {
      // TODO(phlrain) : need update here, downcast from double to float
      return {static_cast<float>(
                  attr.dyn_cast<dialect::ScalarAttribute>().data().to<double>()),
              phi::DataType::FLOAT64};
    }
    return {0.0f, phi::DataType::UNDEFINED};
  }();

  ir::Builder builder(ctx, program->block());
  dialect::FullOp full_op = builder.Build<dialect::FullOp>(
      std::vector<int64_t>{1}, fill_value, fill_dtype, phi::CPUPlace());
  return full_op.operation();
}

/// Materialize an IntArrayAttribute as a value: append a FullIntArrayOp
/// (INT64, placed on CPU) that holds the array's elements to `program`'s block
/// and return the created operation. Any other attribute kind fails the
/// IR_ENFORCE check.
inline ir::Operation* InsertFullArrayOperationForAttributeInput(
    ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
  IR_ENFORCE(attr.isa<dialect::IntArrayAttribute>(),
             "Encounter non IntArray type when trying to insert IntArray "
             "mutable attribute");

  phi::IntArray array_value =
      attr.dyn_cast<dialect::IntArrayAttribute>().data();
  ir::Builder builder(ctx, program->block());
  return builder
      .Build<dialect::FullIntArrayOp>(
          array_value.GetData(), phi::DataType::INT64, phi::CPUPlace())
      .operation();
}

/// Materialize the legacy attribute behind the mutable-attribute input
/// `input_info` of `op_desc` as a value in `program`.
///
/// The attribute is first translated to `input_info.type_name`. An IntArray
/// attribute then becomes a FullIntArrayOp; any other attribute becomes a
/// FullOp. Returns the first result of the inserted operation.
///
/// Throws (IR_THROW) when `op_desc` lacks the legacy attribute, or when the
/// attribute translator returns a null attribute for it.
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
                                        ir::Program* program,
                                        const OpDesc& op_desc,
                                        const OpInputInfo& input_info) {
  auto& attribute_translator = AttributeTranslator::instance();
  auto& op_normalizer = OpNameNormalizer::instance();

  auto legacy_attr_name =
      op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);

  if (!op_desc.HasAttr(legacy_attr_name)) {
    IR_THROW("Op %s arg %s should not be zero size",
             op_desc.Type(),
             legacy_attr_name);
  }
  paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
  VLOG(10) << "[" << op_desc.Type() << "][attribute]"
           << " name: " << legacy_attr_name << " " << legacy_attr.index();
  ir::Attribute new_attr =
      attribute_translator(input_info.type_name, legacy_attr);
  // Report a failed translation here with a descriptive message rather than
  // handing a null attribute to the FullOp/FullIntArrayOp helpers.
  if (!new_attr) {
    IR_THROW("Op %s attribute %s cannot be translated to %s",
             op_desc.Type(),
             legacy_attr_name,
             input_info.type_name);
  }

  const bool is_int_array =
      input_info.type_name.find("IntArrayAttribute") != std::string::npos;
  ir::Operation* defining_op =
      is_int_array
          ? InsertFullArrayOperationForAttributeInput(ctx, program, new_attr)
          : InsertFullOperationForAttributeInput(ctx, program, new_attr);

  return defining_op->result(0);
}

}  // namespace

/// @brief Translates a legacy OpDesc into an operation of the new IR. It is a
/// functor class and should have no non-static data members, since it is
/// expected to be stateless. Subclasses customize the translation of
/// specific ops by overriding the virtual steps below.
struct OpTranscriber {
 public:
  virtual ~OpTranscriber() = default;

 public:
  // Translate `op_desc`, append the resulting operation to `program`'s block,
  // record its outputs in `param_map`, and return the operation.
  virtual ir::Operation* operator()(ir::IrContext* ctx,
                                    TranslationContext* param_map,
                                    const OpDesc& op_desc,
                                    ir::Program* program);

 public:
  // Resolve the registered OpInfo that `op_desc` translates to (the name is a
  // historical misspelling of "LookUp", kept for compatibility).
  virtual ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc);
  // Build the operand list, one entry per yaml input in `input_infos`.
  virtual std::vector<ir::OpResult> GenerateOperationInput(
      ir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      ir::Program* program);
  // Compute result types plus the legacy output variable -> result position
  // mapping.
  virtual std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
      ir::IrContext* ctx,
      const OpDesc& op_desc,
      const OpOutputInfoList& output_infos);
  // Fill in a yaml attribute that has no legacy counterpart in the OpDesc. By
  // default the attribute translator is called with an empty legacy
  // attribute.
  virtual void HandleNonexistentAttribute(ir::IrContext*,
                                          ir::AttributeMap* attribute_map,
                                          const OpAttributeInfo& info) {
    auto& attribute_translator = AttributeTranslator::instance();
    (*attribute_map)[info.name] =
        attribute_translator(info.type_name, paddle::framework::Attribute());
  }
  // Translate legacy attributes into the yaml attributes in `op_attr_infos`.
  virtual ir::AttributeMap TranslateOpAttribute(
      ir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc);

  // Store, for each legacy output variable, the operation result that now
  // defines it into `param_map`.
  virtual void RecordOpResultMapping(ir::IrContext* ctx,
                                     TranslationContext* param_map,
                                     const OpDesc& op_desc,
                                     ir::Operation* operation,
                                     const OpOutputMapping& arg_to_idx);

 public:
  // Per-input override hook; returning nullptr selects the generic path.
  virtual InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) {
    return nullptr;
  }
  // Per-attribute override hook (the argument is a yaml attribute name);
  // returning nullptr selects the generic path.
  virtual AttributeHandlerFn GetSpecialAttributeHandlers(
      const std::string& input_name) {
    return nullptr;
  }
  // Insert slice operations for inputs whose defining value is an element of
  // a vector result.
  virtual void InsertSliceOperationForInput(ir::IrContext* ctx,
                                            TranslationContext* param_map,
                                            const OpDesc& op_desc,
                                            const OpInputInfoList& input_infos,
                                            ir::Program* program);
};

/// Resolve the registered OpInfo for `op_desc`. The legacy op type is
/// normalized and prefixed with kTargetDialectPrefix. For inplace ops, `_` is
/// appended unless the name already ends with it. Throws (IR_THROW) when no
/// op with the resulting name is registered.
ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
                                        const OpDesc& op_desc) {
  std::string target_op_name =
      kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
  // The name always starts with the non-empty dialect prefix, so back() is
  // safe here.
  if (IsInplace(op_desc) && target_op_name.back() != '_') {
    target_op_name += "_";
  }
  VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to "
          << target_op_name;
  auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
  if (!op_info) {
    // Both arguments are strings, so the conversions must be %s.
    IR_THROW("Op %s should have corresponding OpInfo %s",
             op_desc.Type(),
             target_op_name);
  }

  return op_info;
}

/// Walk every input variable of `op_desc`. A variable is processed when it is
/// already known in `param_map` and its name is not in the set built from
/// `input_infos`. If its defining value was produced as an element of a vector
/// result, a slice operation is inserted so the element can be used on its own.
void OpTranscriber::InsertSliceOperationForInput(
    ir::IrContext* ctx,
    TranslationContext* param_map,
    const OpDesc& op_desc,
    const OpInputInfoList& input_infos,
    ir::Program* program) {
  auto& op_normalizer = OpNameNormalizer::instance();

  // Legacy argument names of the inputs declared in the yaml definition.
  std::set<std::string> covered_args;
  for (const auto& input_info : input_infos) {
    covered_args.insert(
        op_normalizer.GetLegacyArgName(op_desc.Type(), input_info.name));
  }

  // NOTE(review): `covered_args` holds legacy argument (slot) names, but it is
  // compared against variable names below; confirm which was intended.
  for (const auto& input_slot : op_desc.Inputs()) {
    for (const auto& var_name : input_slot.second) {
      if (param_map->count(var_name) == 0 || covered_args.count(var_name)) {
        continue;
      }
      auto defining_info = param_map->at(var_name);
      if (!defining_info.generated_by_vector) {
        continue;
      }
      // The value is an element of a vector<Tensor>; take it out with a slice.
      InsertSliceOperationForTarget(
          ctx, param_map, program, defining_info, var_name);
      VLOG(8) << "[op:" << op_desc.Type()
              << "] insert slice for var: " << var_name;
    }
  }
}

/// Build the operand list for the operation translated from `op_desc`. There is
/// one entry per yaml input in `input_infos`, in that order:
///   * an input with a special handler (GetSpecialInputHandlers) is produced
///     by that handler;
///   * an optional input with no legacy variable becomes a null OpResult;
///   * a mutable-attribute input with no bound variable (neither the legacy
///     argument nor any candidate variable from GetMutableAttributeInfos) is
///     materialized from the legacy attribute by GetAttributeAsInput;
///   * a vector input (VectorType or IntArrayAttribute type name) is packed by
///     a CombineOp, unless it is a single LOD_TENSOR_ARRAY variable, which is
///     passed through like a plain tensor;
///   * any other input must map to exactly one legacy variable, whose value in
///     `param_map` is used.
std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
    ir::IrContext* ctx,
    TranslationContext* param_map,
    const OpDesc& op_desc,
    const std::string& normalized_op_name,
    const OpInputInfoList& input_infos,
    ir::Program* program) {
  VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";

  auto& op_normalizer = OpNameNormalizer::instance();
  const auto* mutable_attributes =
      op_normalizer.GetMutableAttributes(op_desc.Type());

  VLOG(10) << "[op:" << op_desc.Type() << "][input] start";

  std::vector<ir::OpResult> op_inputs;
  op_inputs.reserve(input_infos.size());

  for (const auto& info : input_infos) {
    if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
      ir::OpResult ret = special_handler(
          ctx, param_map, op_desc, normalized_op_name, info, program);
      op_inputs.push_back(ret);
      continue;
    }

    std::string legacy_input_name =
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);

    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << legacy_input_name;

    std::vector<std::string> legacy_input_vars;
    if (op_desc.HasInput(legacy_input_name, true)) {
      legacy_input_vars = op_desc.Input(legacy_input_name, true);
    }

    // An optional arg that is not shown in the OpDesc becomes a null operand.
    if (legacy_input_vars.empty() && info.optional) {
      op_inputs.emplace_back(nullptr);
      continue;
    }
    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << legacy_input_name << " " << legacy_input_vars.size() << "["
             << legacy_input_vars << "]";

    // A required input without variables may be a mutable attribute: try its
    // candidate variables first, then fall back to the attribute value.
    if (legacy_input_vars.empty() && mutable_attributes != nullptr &&
        mutable_attributes->count(info.name) != 0) {
      const auto& candidate_var_names =
          op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
      bool found_candidate_var = false;
      for (const auto& var_name : candidate_var_names) {
        VLOG(10) << "[handle mutable attribute][" << info.name << "]["
                 << var_name << "]";
        if (op_desc.HasInput(var_name)) {
          legacy_input_vars = op_desc.Input(var_name, true);
          if (legacy_input_vars.empty()) continue;
          found_candidate_var = true;
          break;
        }
      }

      if (!found_candidate_var) {
        auto attribute_input = GetAttributeAsInput(ctx, program, op_desc, info);
        op_inputs.push_back(attribute_input);
        continue;
      }
    }

    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
    is_vector |=
        (info.type_name.find("IntArrayAttribute") != std::string::npos);
    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << is_vector << " " << info.type_name;
    // TensorArray needs special handling: it cannot be distinguished from
    // Vector<DenseTensor> by the other conditions, yet it cannot be supported
    // the way Vector<DenseTensor> is.
    if (legacy_input_vars.size() == 1) {
      VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]);
      // Same guard as the output path: never dereference a missing VarDesc.
      IR_ENFORCE(var != nullptr,
                 "[op:%s] Input %s should not be null",
                 op_desc.Type(),
                 legacy_input_vars[0]);
      if (var->GetType() ==
          paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
        is_vector = false;
      }
    }

    if (!is_vector) {
      // The source type is a single Tensor.
      IR_ENFORCE(legacy_input_vars.size() == 1u,
                 "Input %s not found when parsing op %s",
                 info.name,
                 op_desc.Type());
      // NOTE(review): operator[] is kept from the original; if
      // TranslationContext follows std::map semantics, an unknown variable
      // is default-inserted here rather than reported.
      const auto& defining_info = (*param_map)[legacy_input_vars[0]];
      op_inputs.push_back(defining_info.value);
    } else {
      // The source type is Vector<Tensor>; an additional `CombineOp` is
      // needed to assemble it.
      auto* combine_op = InsertCombineOperationForTarget(
          ctx, param_map, program, legacy_input_vars);
      op_inputs.push_back(combine_op->result(0));
    }
  }

  return op_inputs;
}

/// Compute the result types of the translated operation, one per yaml output
/// in `output_infos`. Also returns a mapping from legacy output variable name
/// to (result index, index inside a vector result).
///   * an output absent from `op_desc` must be optional and gets a null type;
///   * a single LOD_TENSOR_ARRAY variable is translated with its own type;
///   * a non-vector output takes the type of its first variable (a null type
///     if the argument lists no variable);
///   * a vector output becomes a VectorType of its variables' types, with a
///     null element for every kEmptyVarName slot.
/// Every referenced variable must be resolvable through FindVarRecursive.
std::tuple<OpOutputTypeList, OpOutputMapping>
OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
                                       const OpDesc& op_desc,
                                       const OpOutputInfoList& output_infos) {
  OpOutputMapping arg_to_idx;
  OpOutputTypeList op_output_types = {};
  op_output_types.reserve(output_infos.size());

  auto& type_translator = TypeTranslator::instance();
  auto& op_normalizer = OpNameNormalizer::instance();

  const BlockDesc* block = op_desc.Block();

  for (const auto& info : output_infos) {
    size_t cur_output_idx = op_output_types.size();
    std::string legacy_output_name =
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);

    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
             << legacy_output_name;

    // Return an empty type if this arg is optional and not shown in OpDesc.
    if (!op_desc.HasOutput(legacy_output_name)) {
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "] optional " << info.name << " :"
               << info.type_name << " " << legacy_output_name;
      IR_ENFORCE(info.optional,
                 "Op %s arg %s should be optional if it can be empty",
                 op_desc.Type(),
                 legacy_output_name);
      op_output_types.emplace_back(nullptr);
      continue;
    }

    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);

    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
             << legacy_output_name << " " << legacy_output_vars.size() << " "
             << is_vector;

    // TensorArray needs special handling: it cannot be distinguished from
    // Vector<DenseTensor> by the other conditions, yet it cannot be supported
    // the way Vector<DenseTensor> is.
    if (legacy_output_vars.size() == 1) {
      VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]);
      IR_ENFORCE(var != nullptr,
                 "[op:%s] Output %s should not be null",
                 op_desc.Type(),
                 legacy_output_vars[0]);
      if (var->GetType() ==
          paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
        ir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
        op_output_types.push_back(translated_var_type);
        arg_to_idx[var->Name()] = {cur_output_idx, 0};
        continue;
      }
    }

    if (!is_vector) {
      // The source type is a single Tensor.
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name << " :"
               << info.type_name << " " << legacy_output_name << " "
               << legacy_output_vars.size();
      if (legacy_output_vars.empty()) {
        op_output_types.emplace_back(nullptr);
        continue;
      }

      auto& var_name = legacy_output_vars[0];
      VarDesc* var = block->FindVarRecursive(var_name);
      IR_ENFORCE(var != nullptr,
                 "[op:%s] Output %s should not be null",
                 op_desc.Type(),
                 var_name);
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name
               << " var: " << var_name << " type: " << var->GetType();

      ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);

      arg_to_idx[var_name] = {cur_output_idx, 0};
      op_output_types.push_back(translated_var_type);
    } else {
      // The source type is Vector<Tensor>.
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name << " :"
               << info.type_name << " var: " << legacy_output_name;
      std::vector<ir::Type> types;
      types.reserve(legacy_output_vars.size());
      for (IdxInVector idx_in_vec = 0; idx_in_vec < legacy_output_vars.size();
           idx_in_vec++) {
        const auto& var_name = legacy_output_vars[idx_in_vec];
        if (var_name == kEmptyVarName) {
          types.emplace_back(nullptr);
          arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
          continue;
        }
        VarDesc* var = block->FindVarRecursive(var_name);
        // Same guard as the non-vector path: never dereference a missing
        // VarDesc.
        IR_ENFORCE(var != nullptr,
                   "[op:%s] Output %s should not be null",
                   op_desc.Type(),
                   var_name);
        VLOG(10) << "[output translating]"
                 << "[" << op_desc.Type() << "]" << info.name
                 << " var: " << var_name << " type: " << var->GetType();
        ir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
        types.push_back(translated_var_type);
        arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
      }
      ir::Type vec_type = ir::VectorType::get(ctx, types);
      op_output_types.push_back(vec_type);
    }
  }
  return {op_output_types, arg_to_idx};
}

/// Translate the legacy attributes of `op_desc` into an AttributeMap keyed by
/// the yaml attribute names in `op_attr_infos`.
/// For each yaml attribute:
///   * a special handler (GetSpecialAttributeHandlers) wins if one exists;
///   * otherwise the legacy attribute is translated to the yaml type, and a
///     null result is stored as-is and reported through VLOG(0);
///   * an attribute missing from the OpDesc goes to
///     HandleNonexistentAttribute.
ir::AttributeMap OpTranscriber::TranslateOpAttribute(
    ir::IrContext* ctx,
    const std::string& normalized_op_name,
    const OpAttributeInfoList& op_attr_infos,
    const OpDesc& op_desc) {
  auto& attribute_translator = AttributeTranslator::instance();
  auto& op_normalizer = OpNameNormalizer::instance();
  ir::AttributeMap translated = {};

  for (const auto& attr_info : op_attr_infos) {
    auto special_handler = this->GetSpecialAttributeHandlers(attr_info.name);
    if (special_handler) {
      translated[attr_info.name] = special_handler(ctx, op_desc, attr_info);
      continue;
    }

    const auto legacy_name =
        op_normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name);
    VLOG(10) << "[op: " << op_desc.Type()
             << "][attr] from: " << legacy_name << " to: " << attr_info.name;

    if (!op_desc.HasAttr(legacy_name)) {
      VLOG(10) << "attribute in " << op_desc.Type()
               << " name: " << legacy_name << " doesn't exist";
      this->HandleNonexistentAttribute(ctx, &translated, attr_info);
      continue;
    }

    paddle::framework::Attribute legacy_value = op_desc.GetAttr(legacy_name);
    VLOG(10) << "attribute in " << op_desc.Type()
             << " name: " << legacy_name << " " << legacy_value.index();
    ir::Attribute converted =
        attribute_translator(attr_info.type_name, legacy_value);
    translated[attr_info.name] = converted;
    if (!converted) {
      VLOG(0) << "empty attribute in " << op_desc.Type()
              << " name: " << attr_info.name;
    }
  }

  return translated;
}

/// For every legacy output variable in `arg_to_idx`, store the operation
/// result that defines it in `param_map`. When that result is a VectorType,
/// the element index inside the vector is recorded as well; otherwise -1 is
/// stored.
void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
                                          TranslationContext* param_map,
                                          const OpDesc& op_desc,
                                          ir::Operation* operation,
                                          const OpOutputMapping& arg_to_idx) {
  for (const auto& [arg_name, idx] : arg_to_idx) {
    const auto& [idx_in_op, idx_in_vec] = idx;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << arg_name << " " << idx_in_op
             << " " << idx_in_vec;
    ir::OpResult value = operation->result(idx_in_op);
    bool generated_by_vector = value.type().isa<ir::VectorType>();

    // Give the conditional an explicit signed type. Mixing size_t with -1
    // silently turns -1 into SIZE_MAX and relies on converting it back.
    const int idx_in_vector =
        generated_by_vector ? static_cast<int>(idx_in_vec) : -1;
    (*param_map)[arg_name] =
        VariableDefiningInfo(value, generated_by_vector, idx_in_vector);
  }
}

/// Default translation of a legacy OpDesc into a new-IR operation:
///   1. resolve the target OpInfo and its yaml input/attribute/output infos;
///   2. insert slices for inputs that were produced as vector elements;
///   3. build operands, result types and attributes;
///   4. create the operation, append it to `program`'s block, and record the
///      legacy output variable -> result mapping in `param_map`.
/// Returns the created operation.
ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
                                         TranslationContext* param_map,
                                         const OpDesc& op_desc,
                                         ir::Program* program) {
  auto op_info = this->LoopkUpOpInfo(ctx, op_desc);
  auto* op_info_concept =
      op_info.GetInterfaceImpl<dialect::OpYamlInfoInterface>();
  // The yaml infos drive the whole translation; report a missing interface
  // instead of dereferencing a null pointer below.
  IR_ENFORCE(op_info_concept != nullptr,
             "[op:%s] OpInfo %s does not implement OpYamlInfoInterface",
             op_desc.Type(),
             op_info.name());

  OpInputInfoList input_infos;
  OpAttributeInfoList attr_infos;
  OpOutputInfoList output_infos;
  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
      op_info_concept->get_op_info_();

  this->InsertSliceOperationForInput(
      ctx, param_map, op_desc, input_infos, program);

  auto op_inputs = this->GenerateOperationInput(
      ctx, param_map, op_desc, op_info.name(), input_infos, program);

  OpOutputMapping arg_to_idx;
  OpOutputTypeList op_output_types;
  std::tie(op_output_types, arg_to_idx) =
      this->GenerateOperationOutput(ctx, op_desc, output_infos);

  auto attribute_map =
      this->TranslateOpAttribute(ctx, op_info.name(), attr_infos, op_desc);
  VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";

  ir::Operation* operation =
      ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info);
  VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end.";
  program->block()->push_back(operation);

  VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end.";
  this->RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

  return operation;
}

/// `cast` reads its target dtype from the legacy attribute `out_dtype` and
/// translates it into the first yaml attribute of the op. When `out_dtype` is
/// missing, an empty legacy attribute is handed to the attribute translator.
struct CastOpTranscriber : public OpTranscriber {
  ir::AttributeMap TranslateOpAttribute(
      ir::IrContext*,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // op_attr_infos[0] is read below; fail cleanly instead of indexing an
    // empty list.
    IR_ENFORCE(!op_attr_infos.empty(),
               "Op %s should have at least one attribute to translate",
               op_desc.Type());
    auto& attribute_translator = AttributeTranslator::instance();
    ir::AttributeMap attribute_map = {};
    const OpAttributeInfo& info = op_attr_infos[0];

    std::string legacy_attr_name("out_dtype");

    paddle::framework::Attribute legacy_attr;
    if (op_desc.HasAttr(legacy_attr_name)) {
      legacy_attr = op_desc.GetAttr(legacy_attr_name);
    }
    VLOG(10) << "attribute in " << op_desc.Type()
             << " name: " << legacy_attr_name << " " << legacy_attr.index();
    ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
    attribute_map[info.name] = new_attr;

    return attribute_map;
  }
};

/// `embedding` may lack `padding_idx` / `sparse` in legacy programs; these
/// get the defaults -1 and false. Any other missing attribute is left unset,
/// unlike the base class, which fills it through the attribute translator.
struct EmbeddingOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(ir::IrContext* ctx,
                                  ir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    auto& attrs = *attribute_map;
    if (info.name == "sparse") {
      attrs[info.name] = ir::BoolAttribute::get(ctx, false);
      return;
    }
    if (info.name == "padding_idx") {
      attrs[info.name] = ir::Int64Attribute::get(ctx, -1);
    }
  }
};

/// `increment` keeps its increment in the legacy attribute `step`, while the
/// translated op expects it as `value`. When `step` is absent, a float 1.0 is
/// used.
struct IncrementOpTranscriber : public OpTranscriber {
  ir::AttributeMap TranslateOpAttribute(
      ir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& attribute_translator = AttributeTranslator::instance();
    ir::AttributeMap attribute_map = {};

    if (!op_desc.HasAttr("step")) {
      attribute_map["value"] = ir::FloatAttribute::get(ctx, 1.0f);
      return attribute_map;
    }

    paddle::framework::Attribute step_attr = op_desc.GetAttr("step");
    VLOG(10) << "attribute in " << op_desc.Type() << " step: "
             << " " << step_attr.index();
    attribute_map["value"] = attribute_translator(step_attr);
    return attribute_map;
  }
};

// `assign_value` in static_ops.yaml differs from the one in legacy_ops.yaml,
// so this transcriber mirrors python/paddle/tensor/creation.py::assign(x,
// output). The op has no operands. Its `shape`, `dtype` and `values`
// attributes come from the legacy OpDesc, and `place` is always CPU.
struct AssignValueOpTranscriber : public OpTranscriber {
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    const std::string target_op_name = "pd.assign_value";
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW(
          "Op assign_value should have corresponding OpInfo pd.assign_value");
    }
    return op_info;
  }

  ir::Operation* operator()(ir::IrContext* ctx,
                            TranslationContext* param_map,
                            const OpDesc& op_desc,
                            ir::Program* program) override {
    VLOG(10) << "[op assign_value] start transcribing";
    auto op_info = this->LoopkUpOpInfo(ctx, op_desc);
    auto* op_info_concept =
        op_info.GetInterfaceImpl<dialect::OpYamlInfoInterface>();
    OpInputInfoList input_infos;
    OpAttributeInfoList attr_infos;
    OpOutputInfoList output_infos;
    std::tie(input_infos, attr_infos, output_infos, std::ignore) =
        op_info_concept->get_op_info_();

    // yaml attribute name -> attribute info (the first entry wins).
    std::unordered_map<std::string, OpAttributeInfo> attr_info_maps;
    for (const auto& info : attr_infos) {
      attr_info_maps.emplace(info.name, info);
    }

    auto& attribute_translator = AttributeTranslator::instance();
    ir::AttributeMap attribute_map;
    paddle::framework::Attribute legacy_attr;

    if (!op_desc.HasAttr("shape")) {
      IR_THROW("Op assign_value should have attribute `shape` but not find");
    }
    legacy_attr = op_desc.GetAttr("shape");
    attribute_map["shape"] =
        attribute_translator(attr_info_maps.at("shape").type_name, legacy_attr);

    if (!op_desc.HasAttr("dtype")) {
      IR_THROW("Op assign_value should have attribute `dtype` but not find");
    }
    legacy_attr = op_desc.GetAttr("dtype");
    attribute_map["dtype"] =
        attribute_translator(attr_info_maps.at("dtype").type_name, legacy_attr);

    attribute_map["place"] = dialect::PlaceAttribute::get(ctx, phi::CPUPlace());

    // Pick the legacy attribute holding the payload from the dtype code.
    const int dtype = paddle::get<int>(op_desc.GetAttr("dtype"));
    const char* values_attr_name = nullptr;
    switch (dtype) {
      case /*BOOL*/ 0:
        values_attr_name = "bool_values";
        break;
      case /*INT32*/ 2:
        values_attr_name = "int32_values";
        break;
      case /*INT64*/ 3:
        values_attr_name = "int64_values";
        break;
      case /*FP32*/ 5:
        values_attr_name = "fp32_values";
        break;
      default:
        IR_THROW(
            "Op assign_value should have attribute `**_values` but not find");
    }
    legacy_attr = op_desc.GetAttr(values_attr_name);
    attribute_map["values"] = attribute_translator(
        attr_info_maps.at("values").type_name, legacy_attr);

    VLOG(10) << "[op assign_value] attribute translation done";

    // assign_value takes no operands.
    std::vector<ir::OpResult> op_inputs = {};

    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types;
    std::tie(op_output_types, arg_to_idx) =
        this->GenerateOperationOutput(ctx, op_desc, output_infos);

    ir::Operation* operation = ir::Operation::Create(
        op_inputs, attribute_map, op_output_types, op_info);
    program->block()->push_back(operation);
    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

    VLOG(10) << "[op assign_value] translation finished";

    return operation;
  }
};

// The input `dropout_state_in` does not exist in the static (legacy)
// definition. It is synthesized here as a zero-filled `full` op with the same
// type as the OpDesc output `DropoutState`. Because `DropoutState` is an
// optional output in the static graph, a null OpResult is returned when the
// OpDesc has no variable for it.
ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx,
                                     TranslationContext* param_map,
                                     const OpDesc& op_desc,
                                     const std::string& normalized_op_name,
                                     const OpInputInfo& input_info,
                                     ir::Program* program) {
  const std::string legacy_output_name = "DropoutState";
  const std::vector<std::string> state_vars =
      op_desc.HasOutput(legacy_output_name)
          ? op_desc.Output(legacy_output_name)
          : std::vector<std::string>{};

  if (state_vars.empty()) {
    VLOG(3) << "[input translating] not find output variable: DropoutState";
    return ir::OpResult(nullptr);
  }

  // `DropoutState` is a single tensor variable.
  VarDesc* dropout_state = op_desc.Block()->FindVarRecursive(state_vars[0]);
  if (dropout_state == nullptr) {
    IR_THROW("Unexpected: Rnn Op should have a non-empty DropoutState");
  }

  auto& type_translator = TypeTranslator::instance();
  ir::Type state_type =
      type_translator[dropout_state->GetType()](ctx, *dropout_state);
  IR_ENFORCE(
      state_type.isa<dialect::DenseTensorType>(),
      "Unexpected: Rnn Op's output DropoutState should be a DenseTensor");
  auto dense_type = state_type.dyn_cast<dialect::DenseTensorType>();

  // Zero-filled, on CPU, with the shape and dtype of DropoutState.
  ir::Builder builder(ctx, program->block());
  auto full_op = builder.Build<dialect::FullOp>(
      phi::vectorize(dense_type.dims()),
      0.0f,
      dialect::TransToPhiDataType(dense_type.dtype()),
      phi::CPUPlace());
  return full_op->result(0);
}

// `rnn` has an additional input in the dynamic-graph definition; that input is
// synthesized by TranslateDropOutStateIn.
struct RnnOpTranscriber : public OpTranscriber {
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    // Only `dropout_state_in` needs a custom handler.
    if (input_name == "dropout_state_in") {
      return TranslateDropOutStateIn;
    }
    return nullptr;
  }
};

struct EmbeddingGradOpTranscriber : public OpTranscriber {
  // Supplies defaults for attributes the target op expects but the legacy
  // OpDesc does not carry: padding_idx = -1, sparse = false.
  void HandleNonexistentAttribute(ir::IrContext* ctx,
                                  ir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "padding_idx") {
      (*attribute_map)[info.name] = ir::Int64Attribute::get(ctx, -1);
    } else if (info.name == "sparse") {
      (*attribute_map)[info.name] = ir::BoolAttribute::get(ctx, false);
    }
  }

  // Chooses between the sparse and dense embedding-grad ops based on the
  // legacy `is_sparse` attribute. Throws if the chosen op is not registered.
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    const bool is_sparse = paddle::get<bool>(op_desc.GetAttr("is_sparse"));
    const std::string target_op_name =
        is_sparse ? "pd.embedding_grad_sparse" : "pd.embedding_grad_dense";

    VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
            << target_op_name;
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // Both arguments are strings, so they must use `%s`.
      IR_THROW("Op %s should have corresponding OpInfo %s",
               op_desc.Type(),
               target_op_name);
    }

    return op_info;
  }
};

// `feed` is translated without operands. The fed variable's name and the
// legacy `col` attribute become attributes of the new op.
struct FeedOpTranscriber : public OpTranscriber {
  ir::AttributeMap TranslateOpAttribute(
      ir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // The first output argument names the fed variable.
    const std::string feed_var_name = op_desc.OutputArgumentNames()[0];
    const int col = op_desc.GetAttrIfExists<int>("col");

    ir::AttributeMap attrs;
    attrs["name"] = ir::StrAttribute::get(ctx, feed_var_name);
    attrs["col"] = ir::Int32Attribute::get(ctx, col);
    return attrs;
  }

  // `feed` takes no operands.
  std::vector<ir::OpResult> GenerateOperationInput(
      ir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      ir::Program* program) override {
    return std::vector<ir::OpResult>();
  }
};

// `data` reuses the operand-free input handling of `feed`. Its attributes
// come from the legacy `name` and `place` attributes.
struct DataOpTranscriber : public FeedOpTranscriber {
  ir::AttributeMap TranslateOpAttribute(
      ir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // The legacy `place` attribute is an int that is cast to
    // phi::AllocationType.
    const auto alloc_type = static_cast<phi::AllocationType>(
        paddle::get<int>(op_desc.GetAttr("place")));
    const std::string data_name =
        op_desc.GetAttrIfExists<std::string>("name");

    ir::AttributeMap attrs;
    attrs["name"] = ir::StrAttribute::get(ctx, data_name);
    attrs["index"] = ir::Int64Attribute::get(ctx, 0);
    // dtype is always FLOAT32 here, independent of the OpDesc.
    attrs["dtype"] =
        paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
    attrs["place"] =
        paddle::dialect::PlaceAttribute::get(ctx, phi::Place(alloc_type));
    return attrs;
  }
};

// Translates legacy `split` to `pd.split_with_num` when the `num` attribute is
// positive, and to `pd.split` otherwise.
struct SplitOpTranscriber : public OpTranscriber {
  // Builds the operands [Tensor x, IntArray sections, Scalar(int) axis].
  // `sections` is only produced when `num <= 0` (the `pd.split` case).
  std::vector<ir::OpResult> GenerateOperationInput(
      ir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      ir::Program* program) override {
    VLOG(10) << "[op:split][input] start";

    std::vector<ir::OpResult> op_inputs;

    // process x; it must already have been translated.
    auto x_input_vars = op_desc.Input("X");
    IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor");
    IR_ENFORCE(param_map->count(x_input_vars[0]) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               x_input_vars[0]);
    op_inputs.push_back(param_map->at(x_input_vars[0]).value);

    // process sections
    int num = paddle::get<int>(op_desc.GetAttr("num"));
    if (num <= 0) {
      if (op_desc.HasInput("SectionsTensorList") &&
          !op_desc.Input("SectionsTensorList").empty()) {
        // Sections come from the SectionsTensorList input, packed by a
        // combine op.
        auto sec_tensor_list = op_desc.Input("SectionsTensorList");
        auto* combine_op = InsertCombineOperationForTarget(
            ctx, param_map, program, sec_tensor_list);
        op_inputs.push_back(combine_op->result(0));
      } else {
        // Sections come from the `sections` attribute, materialized as an op.
        auto& attribute_translator = AttributeTranslator::instance();
        ir::Attribute new_attr = attribute_translator(
            "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections"));
        auto sec_defin_op =
            InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
        op_inputs.push_back(sec_defin_op->result(0));
      }
    }

    // process axis
    if (op_desc.HasInput("AxisTensor") &&
        !op_desc.Input("AxisTensor").empty()) {
      // Axis comes from the AxisTensor input, which must already be parsed.
      auto axis_var_list = op_desc.Input("AxisTensor");
      IR_ENFORCE(axis_var_list.size() == 1,
                 "axis tensor input of split MUST be a tensor");
      IR_ENFORCE(param_map->count(axis_var_list[0]) > 0,
                 "Expected op[%s]'s input %s has been parsed",
                 op_desc.Type(),
                 axis_var_list[0]);
      op_inputs.push_back(param_map->at(axis_var_list[0]).value);
    } else {
      // Axis comes from the `axis` attribute, materialized as an op.
      auto& attribute_translator = AttributeTranslator::instance();
      ir::Attribute new_attr =
          attribute_translator("ir::Int32Attribute", op_desc.GetAttr("axis"));

      auto axis_defin_op =
          InsertFullOperationForAttributeInput(ctx, program, new_attr);
      op_inputs.push_back(axis_defin_op->result(0));
    }

    return op_inputs;
  }

  // `pd.split_with_num` carries `num` as an attribute. `pd.split` passes
  // everything as operands and has no attributes.
  ir::AttributeMap TranslateOpAttribute(
      ir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    const int num = paddle::get<int>(op_desc.GetAttr("num"));
    if (num <= 0) {
      return {};
    }

    ir::AttributeMap attribute_map = {
        {"num", ir::Int32Attribute::get(ctx, num)},
    };
    return attribute_map;
  }

  // Throws if the selected target op is not registered.
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    const int num = paddle::get<int>(op_desc.GetAttr("num"));
    const std::string target_op_name =
        num > 0 ? "pd.split_with_num" : "pd.split";

    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW("Op %s should have corresponding OpInfo %s",
               op_desc.Type(),
               target_op_name);
    }

    return op_info;
  }
};

// `fetch_v2` becomes an op whose single result has the same type as its first
// operand. The fetched variable's name and the legacy `col` attribute become
// attributes.
struct FetchOpTranscriber : public OpTranscriber {
  ir::Operation* operator()(ir::IrContext* ctx,
                            TranslationContext* param_map,
                            const OpDesc& op_desc,
                            ir::Program* program) override {
    auto op_info = this->LoopkUpOpInfo(ctx, op_desc);

    // Only the input description is needed from the op's yaml info.
    auto* yaml_info =
        op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
    OpInputInfoList input_infos;
    std::tie(input_infos, std::ignore, std::ignore, std::ignore) =
        yaml_info->get_op_info_();

    this->InsertSliceOperationForInput(
        ctx, param_map, op_desc, input_infos, program);
    auto operands = this->GenerateOperationInput(
        ctx, param_map, op_desc, op_info.name(), input_infos, program);

    const std::string fetch_name = op_desc.InputArgumentNames()[0];
    const int col = op_desc.GetAttrIfExists<int>("col");
    ir::AttributeMap attribute_map;
    attribute_map["name"] = ir::StrAttribute::get(ctx, fetch_name);
    attribute_map["col"] = ir::Int32Attribute::get(ctx, col);

    // The result mirrors the type of the fetched value.
    OpOutputTypeList result_types;
    result_types.push_back(operands[0].type());

    ir::Operation* fetch_op =
        ir::Operation::Create(operands, attribute_map, result_types, op_info);
    program->block()->push_back(fetch_op);
    return fetch_op;
  }
};

// `shadow_output` is translated to ir::SetParameterOp. The op's single operand
// is the translated value of input `x`. The legacy `name` attribute becomes
// `parameter_name`, and the op has no results.
struct ShadowOutputOpTranscriber : public OpTranscriber {
  ir::Operation* operator()(ir::IrContext* ctx,
                            TranslationContext* param_map,
                            const OpDesc& op_desc,
                            ir::Program* program) override {
    auto op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());

    auto legacy_input_vars = op_desc.Input("x", true);
    IR_ENFORCE(!legacy_input_vars.empty(),
               "Expected op[%s]'s input x to be non-empty",
               op_desc.Type());
    const std::string& var_name = legacy_input_vars[0];
    // Check before lookup: operator[] would silently insert an empty entry.
    IR_ENFORCE(param_map->count(var_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               var_name);

    auto defining_info = param_map->at(var_name);
    if (defining_info.generated_by_vector) {
      // A value defined as a vector element must be sliced out first; the
      // slice updates param_map, so re-read the entry.
      InsertSliceOperationForTarget(
          ctx, param_map, program, defining_info, var_name);
      defining_info = param_map->at(var_name);
    }

    std::vector<ir::OpResult> op_inputs = {defining_info.value};

    ir::AttributeMap attribute_map = {
        {"parameter_name",
         ir::StrAttribute::get(ctx,
                               op_desc.GetAttrIfExists<std::string>("name"))},
    };

    ir::Operation* operation =
        ir::Operation::Create(op_inputs, attribute_map, {}, op_info);
    program->block()->push_back(operation);

    return operation;
  }
};

// NOTE: the legacy add_n op has no kernel, so for now it is mapped to a new
// op: "<mapped name>_" when the OpDesc is inplace, otherwise
// "<mapped name>_with_kernel".
struct AddNOpTranscriber : public OpTranscriber {
  // Throws if the selected target op is not registered.
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    std::string target_op_name =
        kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
    target_op_name += IsInplace(op_desc) ? "_" : "_with_kernel";

    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW("Op %s should have corresponding OpInfo %s",
               op_desc.Type(),
               target_op_name);
    }

    return op_info;
  }
};

// Legacy `tril_triu` selects the triangle with its boolean `lower` attribute.
// Here it maps to `pd.tril` when `lower` is true and to `pd.triu` otherwise.
struct TrilAndTriuOpTranscriber : public OpTranscriber {
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    const bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower"));
    const std::string target_op_name = lower ? "pd.tril" : "pd.triu";

    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW(
          "Op tril_triu should have corresponding OpInfo pd.tril or pd.triu.");
    }
    return op_info;
  }
};

// Produces the `num_classes` operand of `one_hot`. It uses the already
// translated legacy `depth_tensor` input when one is present. Otherwise it
// materializes the legacy `depth` attribute with a `full` op. Throws if neither
// source is available.
ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
                                          TranslationContext* param_map,
                                          const OpDesc& op_desc,
                                          const std::string& normalized_op_name,
                                          const OpInputInfo& input_info,
                                          ir::Program* program) {
  const std::string legacy_attr_name = "depth";
  const std::string legacy_tensor_name = "depth_tensor";

  if (op_desc.HasInput(legacy_tensor_name) &&
      !op_desc.Input(legacy_tensor_name).empty()) {
    const std::vector<std::string> legacy_vars =
        op_desc.Input(legacy_tensor_name);
    IR_ENFORCE(legacy_vars.size() == 1,
               "depth_tensor input of one hot MUST be a tensor");
    const std::string& var_name = legacy_vars[0];
    auto defining_info = param_map->find(var_name);
    IR_ENFORCE(defining_info != param_map->end(),
               "%s should be existed in one_hot_v2 as input depth_tensor.",
               var_name);
    return defining_info->second.value;
  }

  if (!op_desc.HasAttr(legacy_attr_name)) {
    IR_THROW("Op %s should have attribute %s when input %s is absent",
             op_desc.Type(),
             legacy_attr_name,
             legacy_tensor_name);
  }

  auto& attribute_translator = AttributeTranslator::instance();
  paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
  VLOG(10) << "[" << op_desc.Type() << "][attribute]"
           << " name: " << legacy_attr_name << " " << legacy_attr.index();
  ir::Attribute new_attr = attribute_translator(legacy_attr);

  ir::Operation* defining_op =
      InsertFullOperationForAttributeInput(ctx, program, new_attr);
  return defining_op->result(0);
}

// `one_hot_v2` needs a custom handler only for its `num_classes` operand.
struct OneHotTranscriber : public OpTranscriber {
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name == "num_classes") {
      return TranslateNumClassesForOneHot;
    }
    return nullptr;
  }
};

// Attribute handler for the `axis` of reduce ops. If the legacy `reduce_all`
// attribute exists and is true, it returns an empty ArrayAttribute. Otherwise
// it translates the legacy attribute that corresponds to `attr_info.name`.
ir::Attribute TranslateReduceAll(ir::IrContext* ctx,
                                 const OpDesc& op_desc,
                                 const OpAttributeInfo& attr_info) {
  const bool reduce_all = op_desc.HasAttr("reduce_all") &&
                          paddle::get<bool>(op_desc.GetAttr("reduce_all"));
  if (reduce_all) {
    return ir::ArrayAttribute::get(ctx, std::vector<ir::Attribute>{});
  }

  const auto legacy_attr_name = OpNameNormalizer::instance().GetLegacyAttrName(
      op_desc.Type(), attr_info.name);
  paddle::framework::Attribute dims = op_desc.GetAttr(legacy_attr_name);
  return AttributeTranslator::instance()(attr_info.type_name, dims);
}

// Reduce ops route their `axis` attribute through TranslateReduceAll so the
// legacy `reduce_all` flag is taken into account.
struct ReduceOpTranscriber : public OpTranscriber {
  AttributeHandlerFn GetSpecialAttributeHandlers(
      const std::string& input_name) override {
    if (input_name == "axis") {
      return TranslateReduceAll;
    }
    return nullptr;
  }
};

struct ElementwiseTranscriber : public OpTranscriber {
  std::vector<ir::OpResult> GenerateOperationInput(
      ir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      ir::Program* program) override {
    int axis = paddle::get<int>(op_desc.GetAttr("axis"));

    if (axis == -1) {
      return OpTranscriber::GenerateOperationInput(
          ctx, param_map, op_desc, normalized_op_name, input_infos, program);
    }

    auto x_names = op_desc.Input("X", true);
    IR_ENFORCE(x_names.size() == 1,
               "Expected op[%s]'s input X has only 1 variable, but got %d",
               op_desc.Type(),
               x_names.size());
    auto x_name = x_names[0];
    IR_ENFORCE(param_map->count(x_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               x_name);
    auto x_defining_info = param_map->at(x_name);
    if (x_defining_info.generated_by_vector) {
      InsertSliceOperationForTarget(
          ctx, param_map, program, x_defining_info, x_name);
      x_defining_info = param_map->at(x_name);
    }
    ir::OpResult x_value = x_defining_info.value;
    IR_ENFORCE(x_value,
               "Expected op[%s]'s input %s is not null",
               op_desc.Type(),
               x_name);
    ir::Type x_type = x_value.type();
    IR_ENFORCE(x_type.isa<dialect::DenseTensorType>(),
               "Expected op[%s]'s input %s is DenseTensor but got %s",
               op_desc.Type(),
               x_name,
               x_type);
    dialect::DenseTensorType x_tensor_type =
        x_type.dyn_cast<dialect::DenseTensorType>();
    std::vector<int64_t> x_shape = phi::vectorize(x_tensor_type.dims());

    auto y_names = op_desc.Input("Y", true);
    IR_ENFORCE(y_names.size() == 1,
               "Expected op[%s]'s input Y has only 1 variable, but got %d",
               op_desc.Type(),
               y_names.size());
    auto y_name = y_names[0];
    IR_ENFORCE(param_map->count(y_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               y_name);
    auto y_defining_info = param_map->at(y_name);
    if (y_defining_info.generated_by_vector) {
      InsertSliceOperationForTarget(
          ctx, param_map, program, y_defining_info, y_name);
      y_defining_info = param_map->at(y_name);
    }
    ir::OpResult y_value = y_defining_info.value;
    IR_ENFORCE(y_value,
               "Expected op[%s]'s input %s is not null",
               op_desc.Type(),
               y_name);
    ir::Type y_type = y_value.type();
    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
               "Expected op[%s]'s input %s is DenseTensor but got %s",
               op_desc.Type(),
               y_name,
               y_type);
    dialect::DenseTensorType y_tensor_type =
        y_type.dyn_cast<dialect::DenseTensorType>();
    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());

    if (axis < 0) {
      axis += x_shape.size();
    }

    int append_size = x_shape.size() - axis - 1 - y_shape.size();
    if (append_size < 0) {  // which means x.rank <= y.rank, mostly
                            // x.rank=y.rank
      return {x_value, y_value};
    }
    IR_ENFORCE(append_size >= 0,
               "Expected op[%s] have append size >= 0 with axis=%d but got %d",
               op_desc.Type(),
               axis,
               append_size);

    ir::Builder builder(ctx, program->block());
    ir::OpResult y_new;
    if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
      std::vector<int64_t> y_new_shape(y_shape);
      for (int i = 0; i <= append_size; i++) {
        y_new_shape.push_back(1);
      }
      dialect::ReshapeOp reshape_op =
          builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
      y_new = reshape_op.out();
      VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
              << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape);
    } else {
      auto shape_op = builder.Build<dialect::ShapeOp>(y_value);
      auto append_shape_op = builder.Build<dialect::FullIntArrayOp>(
          std::vector<int64_t>(append_size, 1),
          phi::DataType::INT64,
          phi::CPUPlace());
      auto y_true_shape_op = builder.Build<ir::CombineOp>(
          std::vector<ir::OpResult>{shape_op.out(), append_shape_op.out()});
      auto concat_op =
          builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
      auto y_new_shape = concat_op.out();
      auto reshape_op = builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
      y_new = reshape_op.out();
    }
    return {x_value, y_new};
  }
};

// `grad_add` is translated to `pd.add`. It reuses the elementwise operand
// handling for `axis != -1`.
struct GradAddOpTranscriber : public ElementwiseTranscriber {
  // Throws if `pd.add` is not registered.
  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
    const std::string target_op_name = "pd.add";
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW("Op %s should have corresponding OpInfo %s",
               op_desc.Type(),
               target_op_name);
    }

    return op_info;
  }
};

// Grad counterpart of ElementwiseTranscriber. For `axis != -1`, after the
// usual output recording, the produced Y@GRAD is reshaped to the shape of
// forward input `Y`, and param_map is pointed at the reshaped value.
struct ElementwiseGradTranscriber : public OpTranscriber {
  void RecordOpResultMapping(ir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             ir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);

    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
    if (axis == -1) {
      return;
    }

    // Y@GRAD may be missing or empty; there is nothing to reshape then.
    // Check HasOutput first because Output() throws on a missing slot.
    if (!op_desc.HasOutput("Y@GRAD")) {
      return;
    }
    const auto& y_grad_output = op_desc.Output("Y@GRAD");
    if (y_grad_output.empty()) {
      return;
    }
    IR_ENFORCE(
        y_grad_output.size() == 1,
        "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d",
        op_desc.Type(),
        y_grad_output.size());
    const auto& y_grad_var_name = y_grad_output[0];

    auto idx_iter = arg_to_idx.find(y_grad_var_name);
    if (idx_iter == arg_to_idx.end()) {
      IR_THROW("op[%s] should have got its y_grad", op_desc.Type());
    }
    auto [idx_in_op, idx_in_vec] = idx_iter->second;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << y_grad_var_name << " "
             << idx_in_op << " " << idx_in_vec;

    auto y_names = op_desc.Input("Y", true);
    IR_ENFORCE(y_names.size() == 1,
               "Expected op[%s]'s input Y has only 1 variable, but got %d",
               op_desc.Type(),
               y_names.size());
    const auto& y_name = y_names[0];
    IR_ENFORCE(param_map->count(y_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               y_name);
    auto y_defining_info = param_map->at(y_name);
    ir::OpResult y_value = y_defining_info.value;
    IR_ENFORCE(y_value,
               "Expected op[%s]'s input %s is not null",
               op_desc.Type(),
               y_name);
    ir::Type y_type = y_value.type();
    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
               "Expected op[%s]'s input %s is DenseTensor but got %s",
               op_desc.Type(),
               y_name,
               y_type);
    dialect::DenseTensorType y_tensor_type =
        y_type.dyn_cast<dialect::DenseTensorType>();
    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());

    // Append the reshape to the block containing `operation`, and record it
    // as the defining value of Y@GRAD.
    ir::OpResult value = operation->result(idx_in_op);
    ir::Builder builder(ctx, operation->GetParent());
    auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
    (*param_map)[y_grad_var_name] =
        VariableDefiningInfo(reshape_op.out(), false, -1);
  }
};

// Registers the generic transcriber plus per-op-type special transcribers,
// keyed by legacy op type.
OpTranslator::OpTranslator() {
  general_handler = OpTranscriber();

  // Data movement ops.
  special_handlers["data"] = DataOpTranscriber();
  special_handlers["feed"] = FeedOpTranscriber();
  special_handlers["fetch_v2"] = FetchOpTranscriber();
  special_handlers["shadow_output"] = ShadowOutputOpTranscriber();

  // add_n and sum share one transcriber.
  special_handlers["add_n"] = AddNOpTranscriber();
  special_handlers["sum"] = AddNOpTranscriber();

  // Reductions whose `axis` depends on the legacy `reduce_all` flag.
  special_handlers["reduce_all"] = ReduceOpTranscriber();
  special_handlers["reduce_any"] = ReduceOpTranscriber();

  // Remaining ops with op-specific translation.
  special_handlers["assign_value"] = AssignValueOpTranscriber();
  special_handlers["cast"] = CastOpTranscriber();
  special_handlers["grad_add"] = GradAddOpTranscriber();
  special_handlers["increment"] = IncrementOpTranscriber();
  special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
  special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
  special_handlers["one_hot_v2"] = OneHotTranscriber();
  special_handlers["rnn"] = RnnOpTranscriber();
  special_handlers["split"] = SplitOpTranscriber();
  special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();

  // Elementwise ops (and their "_grad" variants) need special handling when
  // axis != -1.
  // note(lyk): maybe we should do this by a pass, which seems more reasonable
  for (const std::string elementwise_op : {"elementwise_add",
                                           "elementwise_sub",
                                           "elementwise_mul",
                                           "elementwise_div",
                                           "elementwise_max",
                                           "elementwise_min",
                                           "elementwise_mod",
                                           "elementwise_floordiv"}) {
    special_handlers[elementwise_op] = ElementwiseTranscriber();
    special_handlers[elementwise_op + "_grad"] = ElementwiseGradTranscriber();
  }
}

}  // namespace translator
}  // namespace paddle
