op_translator.cc 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/translator/op_translator.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
25
#include "paddle/fluid/translator/op_compat_info.h"
26 27
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h"
28 29 30 31
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/value.h"
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace translator {

namespace {

// Short aliases for the legacy framework description types.
using OpDesc = paddle::framework::OpDesc;
using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;

// Index of a result of a translated ir::Operation.
using ResultIdx = size_t;
// Result types of a translated operation, in result order.
using OpOutputTypeList = std::vector<ir::Type>;
// Output argument name -> index of the operation result that holds it.
using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;

// Dialect prefix prepended to the normalized op name when looking up OpInfo.
constexpr char kTargetDialectPrefix[] = "pd.";

// Returns true when at least one variable name appears both among the op's
// input arguments and among its output arguments. The shared names are
// logged at VLOG(4).
//
// `std::set_intersection` is only correct on sorted ranges, and the argument
// name lists are not guaranteed to be sorted, so both lists are sorted (and
// de-duplicated, so a name used in several slots is reported once) first.
inline bool IsInplace(const OpDesc& op_desc) {
  auto input_names = op_desc.InputArgumentNames();
  auto output_names = op_desc.OutputArgumentNames();

  auto sort_unique = [](auto& names) {
    std::sort(names.begin(), names.end());
    names.erase(std::unique(names.begin(), names.end()), names.end());
  };
  sort_unique(input_names);
  sort_unique(output_names);

  std::vector<std::string> name_intersection;
  std::set_intersection(input_names.begin(),
                        input_names.end(),
                        output_names.begin(),
                        output_names.end(),
                        std::back_inserter(name_intersection));

  if (name_intersection.empty()) {
    return false;
  }

  // Join the shared names with "," purely for the diagnostic log below.
  std::string redundant_variables = std::accumulate(
      std::next(name_intersection.begin()),
      name_intersection.end(),
      name_intersection[0],
      [](const std::string& a, const std::string& b) { return a + "," + b; });
  VLOG(4) << "Following variables occur both in inputs and outputs: "
          << redundant_variables;
  return true;
}

74 75 76 77 78
// Maps a legacy op type name to its normalized name through the
// OpNameNormalizer singleton.
inline std::string OpNamecompatibleMapping(std::string op_name) {
  return OpNameNormalizer::instance()[op_name];
}

79
// Resolves the registered ir::OpInfo that `op_desc` translates to.
//
// The target name is `kTargetDialectPrefix` followed by the normalized legacy
// op type. A trailing "_" is appended for ops that IsInplace reports as
// in-place, which assumes the in-place variant is registered under that name.
// Throws PreconditionNotMet when `ctx` has no op registered under the name.
inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
  std::string target_op_name =
      kTargetDialectPrefix + OpNamecompatibleMapping(op_desc.Type());
  if (IsInplace(op_desc)) {
    target_op_name += "_";
  }

  VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to "
          << target_op_name;

  auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
  if (!op_info) {
    // Both arguments are strings, so they are formatted with %s.
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "Op %s should have corresponding OpInfo %s",
        op_desc.Type(),
        target_op_name));
  }

  return op_info;
}

98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Emits an ir::SliceOp that extracts element `defining_info.idx_in_vector`
// from the vector-typed value `defining_info.value`, and appends it to the
// program's block. `arg_name` is then rebound in `param_map` to the slice's
// first result, and the new operation is returned.
inline ir::Operation* InsertSliceOperationForTarget(
    ir::IrContext* ctx,
    TranslationContext* param_map,
    ir::Program* program,
    const VariableDefiningInfo& defining_info,
    const std::string& arg_name) {
  const std::string slice_name(ir::SliceOp::name());
  ir::OpInfo slice_info = ctx->GetRegisteredOpInfo(slice_name);

  // The element position is carried to the SliceOp as its "index" attribute.
  std::unordered_map<std::string, ir::Attribute> attrs;
  attrs.emplace("index",
                ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector));

  ir::VectorType vec_type =
      defining_info.value.type().dyn_cast<ir::VectorType>();
  auto element_type = vec_type[defining_info.idx_in_vector];

  ir::Operation* slice_op = ir::Operation::create(
      {defining_info.value}, attrs, {element_type}, slice_info);
  program->block()->push_back(slice_op);

  ir::OpResult slice_result = slice_op->GetResultByIndex(0);
  (*param_map)[arg_name] = VariableDefiningInfo(slice_result);
  return slice_op;
}

// Emits an ir::CombineOp that packs the current values of `args` (in order)
// into a single ir::VectorType value, appends it to the program's block, and
// returns it. Every name in `args` must already have an entry in `param_map`;
// it is looked up with `at`, so a missing name is never default-inserted.
inline ir::Operation* InsertCombineOperationForTarget(
    ir::IrContext* ctx,
    TranslationContext* param_map,
    ir::Program* program,
    const std::vector<std::string>& args) {
  std::string combine_op_name(ir::CombineOp::name());
  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);

  std::vector<ir::OpResult> src_values;
  std::vector<ir::Type> types_in_vec;
  src_values.reserve(args.size());
  types_in_vec.reserve(args.size());

  // Bind by reference: neither the name nor the defining info needs copying.
  for (const auto& arg_name : args) {
    auto& defining_info = param_map->at(arg_name);
    src_values.push_back(defining_info.value);
    types_in_vec.push_back(defining_info.value.type());
  }

  ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
  ir::Operation* operation =
      ir::Operation::create(src_values, {}, {target_vec_type}, op_info);
  program->block()->push_back(operation);

  return operation;
}

144
inline std::vector<ir::OpResult> GenerateOperationInput(
145 146 147
    ir::IrContext* ctx,
    TranslationContext* param_map,
    ir::Program* program,
148 149 150
    const OpDesc& op_desc) {
  std::vector<ir::OpResult> op_inputs = {};

151 152
  // scan all inputs to see if any of them is generated as a vector<Tensor>
  // so need an additional `SliceOp` to take it out.
153 154 155
  for (const auto& n : op_desc.Inputs()) {
    auto& name = n.first;
    auto& args = n.second;
156

157 158 159 160 161
    for (const auto& arg_name : args) {
      PADDLE_ENFORCE_NE(
          param_map->count(arg_name),
          0,
          platform::errors::PreconditionNotMet(
162
              "arg %s.%s as input should be exists before prasing %d",
163
              name,
164 165
              arg_name,
              op_desc.Type()));
166 167 168 169 170 171 172 173
      auto defining_info = (*param_map)[arg_name];
      if (defining_info.generated_by_vector) {
        InsertSliceOperationForTarget(
            ctx, param_map, program, defining_info, arg_name);
      }
    }
  }

174 175 176 177 178
  for (const auto& n : op_desc.Inputs()) {
    auto& name = n.first;
    VLOG(10) << "[input retriving]"
             << "[" << op_desc.Type() << "]" << name;
    auto& args = n.second;
179

180 181 182 183 184 185
    // if src type is Tensor or a Vector<Tensor> with size <= 1
    if (args.size() <= 1) {
      for (const auto& arg_name : args) {
        auto defining_info = (*param_map)[arg_name];
        op_inputs.push_back(defining_info.value);
      }
186 187 188 189

      // if src type is Vector<Tesnor> , need an additional `CombineOp` to
      // assemble them.
    } else {
190 191
      auto* combine_op =
          InsertCombineOperationForTarget(ctx, param_map, program, args);
192
      op_inputs.push_back(combine_op->GetResultByIndex(0));
193 194 195 196 197 198
    }
  }
  return op_inputs;
}

inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
199
    ir::IrContext* ctx, const OpDesc& op_desc) {
200 201 202 203 204 205
  OpOutputMapping arg_to_idx;
  OpOutputTypeList op_output_types = {};

  auto& type_translator = TypeTranslator::instance();

  const BlockDesc* block = op_desc.Block();
206 207 208 209 210
  for (const auto& n : op_desc.Outputs()) {
    auto& name = n.first;
    VLOG(10) << "[output translating]"
             << "[" << op_desc.Type() << "]" << name;
    auto& args = n.second;
211

212 213
    size_t cur_output_idx = op_output_types.size();

214 215 216 217 218 219 220
    // if src type is Tensor or a Vector<Tensor> with size <= 1
    if (args.size() <= 1) {
      for (const auto& arg_name : args) {
        VarDesc* var = block->FindVarRecursive(arg_name);
        VLOG(10) << "[output translating]"
                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
                 << " " << var->GetType();
221

222 223
        ir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
224

225 226 227
        arg_to_idx[arg_name] = cur_output_idx;
        op_output_types.push_back(translated_var_type);
      }
228 229 230 231

      // if src type is Vector<Tesnor>
    } else {
      std::vector<ir::Type> types;
232 233
      for (const auto& arg_name : args) {
        VarDesc* var = block->FindVarRecursive(arg_name);
234
        VLOG(10) << "[output translating]"
235
                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
236 237 238 239
                 << " " << var->GetType();
        ir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
        types.push_back(translated_var_type);
240
        arg_to_idx[arg_name] = cur_output_idx;
241 242 243
      }
      ir::Type vec_type = ir::VectorType::get(ctx, types);
      op_output_types.push_back(vec_type);
244 245 246 247 248 249 250 251 252 253 254 255 256 257
    }
  }
  return {op_output_types, arg_to_idx};
}

// Records in `param_map` which result of `operation` now defines each output
// argument of `op_desc`, using the name -> result-index mapping `arg_to_idx`.
// When that result is an ir::VectorType, the argument's position within its
// output slot is stored as the vector index. Otherwise the index is -1.
inline void RecordOpResultMapping(TranslationContext* param_map,
                                  const OpDesc& op_desc,
                                  ir::Operation* operation,
                                  const OpOutputMapping& arg_to_idx) {
  for (const auto& n : op_desc.Outputs()) {
    const auto& name = n.first;
    const auto& args = n.second;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << name;

    // Signed, so the "-1 = not part of a vector" sentinel below does not go
    // through an unsigned wrap-around (size_t vs. -1 in the conditional).
    int idx_in_vector = 0;
    for (const auto& arg_name : args) {
      auto idx = arg_to_idx.at(arg_name);
      VLOG(10) << "[output recording]"
               << "[" << op_desc.Type() << "]" << arg_name << " " << idx;

      ir::OpResult value = operation->GetResultByIndex(idx);
      bool generated_by_vector = value.type().isa<ir::VectorType>();
      (*param_map)[arg_name] = VariableDefiningInfo(
          value, generated_by_vector, generated_by_vector ? idx_in_vector : -1);
      ++idx_in_vector;
    }
  }
}

// Default translation for a legacy op. In order, it:
//   1. materializes the operands, inserting Slice/Combine ops as needed;
//   2. derives the result types and argument -> result-index mapping;
//   3. resolves the target OpInfo;
//   4. creates the operation and appends it to the program's block;
//   5. records which result defines each output argument.
ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
                                TranslationContext* param_map,
                                ir::Program* program,
                                const OpDesc& op_desc) {
  std::vector<ir::OpResult> operands =
      GenerateOperationInput(ctx, param_map, program, op_desc);

  auto outputs = GenerateOperationOutput(ctx, op_desc);
  OpOutputTypeList& result_types = std::get<0>(outputs);
  OpOutputMapping& result_index_of = std::get<1>(outputs);

  ir::OpInfo info = LoopkUpOpInfo(ctx, op_desc);

  ir::Operation* op = ir::Operation::create(operands, {}, result_types, info);
  program->block()->push_back(op);

  RecordOpResultMapping(param_map, op_desc, op, result_index_of);
  return op;
}

// Translates a `feed` op. It is created with no operands: only its result
// types are derived from the outputs. The operation is then appended to the
// program's block and its results are recorded in `param_map`.
ir::Operation* FeedOpHandler(ir::IrContext* ctx,
                             TranslationContext* param_map,
                             ir::Program* program,
                             const OpDesc& op_desc) {
  std::vector<ir::OpResult> no_operands;

  auto outputs = GenerateOperationOutput(ctx, op_desc);
  OpOutputTypeList& result_types = std::get<0>(outputs);
  OpOutputMapping& result_index_of = std::get<1>(outputs);

  ir::OpInfo info = LoopkUpOpInfo(ctx, op_desc);

  ir::Operation* op =
      ir::Operation::create(no_operands, {}, result_types, info);
  program->block()->push_back(op);

  RecordOpResultMapping(param_map, op_desc, op, result_index_of);
  return op;
}

// Translates a `fetch_v2` op. Its operands are materialized as usual, but the
// operation is created with no results, so nothing is recorded in
// `param_map` for its outputs.
ir::Operation* FetchOpHandler(ir::IrContext* ctx,
                              TranslationContext* param_map,
                              ir::Program* program,
                              const OpDesc& op_desc) {
  std::vector<ir::OpResult> operands =
      GenerateOperationInput(ctx, param_map, program, op_desc);

  OpOutputTypeList no_results;
  ir::OpInfo info = LoopkUpOpInfo(ctx, op_desc);

  ir::Operation* op = ir::Operation::create(operands, {}, no_results, info);
  program->block()->push_back(op);
  return op;
}
}  // namespace

// Every op type is translated by GeneralOpHandler unless it has an entry in
// `special_handlers`. `fetch_v2` (created without results) and `feed`
// (created without operands) use dedicated handlers.
OpTranslator::OpTranslator() : general_handler(GeneralOpHandler) {
  special_handlers["fetch_v2"] = FetchOpHandler;
  special_handlers["feed"] = FeedOpHandler;
}

}  // namespace translator
}  // namespace paddle