diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc new file mode 100644 index 0000000000000000000000000000000000000000..4390ba68616d5652e48b3dc06a575000a5eea799 --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/attribute_translator.h" + +#include +#include + +#include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/utils/variant.h" + +namespace paddle { +namespace translator { + +class AttributeVisitor { + public: + ir::IrContext* ctx; + AttributeVisitor() { ctx = ir::IrContext::Instance(); } + ~AttributeVisitor() {} + + public: + virtual ir::Attribute operator()(int i) { + VLOG(10) << "translating int"; + return ir::Int32_tAttribute::get(ctx, i); + } + + virtual ir::Attribute operator()(float f) { + VLOG(10) << "translating float"; + return ir::FloatAttribute::get(ctx, f); + } + + virtual ir::Attribute operator()(bool b) { + VLOG(10) << "translating bool"; + return ir::BoolAttribute::get(ctx, b); + } + + virtual ir::Attribute operator()(double d) { + VLOG(10) << "translating double"; + return ir::DoubleAttribute::get(ctx, d); + } + + virtual ir::Attribute operator()(std::string str) { + VLOG(10) << "translating string"; + return ir::StrAttribute::get(ctx, str); + } + + virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + VLOG(10) << "translating scalar"; + return paddle::dialect::ScalarAttribute::get(ctx, scalar); + } + + virtual ir::Attribute operator()(const std::vector& strs) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(strs.size()); + for (const auto& v : strs) { + attrs.push_back(ir::StrAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const std::vector& fs) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(fs.size()); + for (const auto& v : fs) { + attrs.push_back(ir::FloatAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const std::vector& is) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(is.size()); + for (const auto& v : is) { + attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const std::vector& bs) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(bs.size()); + for (const auto& v : bs) { + attrs.push_back(ir::BoolAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const std::vector& i64s) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(i64s.size()); + for (const auto& v : i64s) { + attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const std::vector& ds) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(ds.size()); + for (const auto& v : ds) { + attrs.push_back(ir::DoubleAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()( + const std::vector& ss) { + VLOG(10) << "translating vector"; + std::vector attrs; + attrs.reserve(ss.size()); + for (const auto& v : ss) { + attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); + } + return ir::ArrayAttribute::get(ctx, attrs); + } + + virtual ir::Attribute operator()(const paddle::blank& blank) { + VLOG(10) << "translating paddle::blank"; + return ir::Attribute(nullptr); + } + + template + ir::Attribute operator()(T attr) { + VLOG(10) << "translating null type"; + return ir::Attribute(nullptr); + } +}; + +class IntArrayAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(const std::vector& is) override { + VLOG(10) << "translating vector to IntArray"; + phi::IntArray data(is); + return paddle::dialect::IntArrayAttribute::get(ctx, data); + } + + ir::Attribute operator()(const std::vector& is) override { + VLOG(10) << "translating vector to IntArray"; + phi::IntArray data(is); + return paddle::dialect::IntArrayAttribute::get(ctx, data); + } +}; + +class ScalarAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(int i) override { + VLOG(10) << "translating int to Scalar"; + phi::Scalar data(i); + return paddle::dialect::ScalarAttribute::get(ctx, data); + } + + ir::Attribute operator()(float f) override { + VLOG(10) << "translating float to Scalar"; + phi::Scalar data(f); + return paddle::dialect::ScalarAttribute::get(ctx, data); + } +}; + +class DataTypeAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + ir::Attribute operator()(int i) override { + VLOG(10) << "translating int to DataType: " << i; + phi::DataType data = static_cast(i); + return paddle::dialect::DataTypeAttribute::get(ctx, data); + } +}; + +class PlaceAttributeVisitor : public AttributeVisitor { + public: + using AttributeVisitor::AttributeVisitor; + + ir::Attribute operator()(const paddle::blank& blank) override { + VLOG(10) << "translating paddle::blank"; + phi::Place data(phi::AllocationType::CPU); + return paddle::dialect::PlaceAttribute::get(ctx, data); + } +}; + +AttributeTranslator::AttributeTranslator() { + general_visitor = new AttributeVisitor(); + special_visitors["paddle::dialect::IntArrayAttribute"] = + new IntArrayAttributeVisitor(); + special_visitors["paddle::dialect::ScalarAttribute"] = + new ScalarAttributeVisitor(); + special_visitors["paddle::dialect::DataTypeAttribute"] = + new DataTypeAttributeVisitor(); + special_visitors["paddle::dialect::PlaceAttribute"] = + new PlaceAttributeVisitor(); +} + +ir::Attribute AttributeTranslator::operator()( + const framework::Attribute& attr) { + return paddle::visit(*general_visitor, attr); +} + +ir::Attribute AttributeTranslator::operator()( + const std::string& target_type, const framework::Attribute& attr) { + if (special_visitors.find(target_type) == special_visitors.end()) { + VLOG(10) << "[" << target_type << "] not found"; + return paddle::visit(*general_visitor, attr); + } + return paddle::visit(*(special_visitors.at(target_type)), attr); +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h new file mode 100644 index 0000000000000000000000000000000000000000..ea509c7e34673625e8a28c04a72acc1c8e895c8d --- /dev/null +++ b/paddle/fluid/translator/attribute_translator.h @@ -0,0 +1,54 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/ir_context.h" + +#pragma once + +namespace paddle { +namespace translator { + +class AttributeVisitor; + +class AttributeTranslator { + private: + AttributeTranslator(); + AttributeVisitor* general_visitor; + std::unordered_map special_visitors; + + public: + AttributeTranslator(const AttributeTranslator&) = delete; + AttributeTranslator& operator=(const AttributeTranslator&) = delete; + AttributeTranslator(AttributeTranslator&&) = delete; + AttributeTranslator& operator=(AttributeTranslator&&) = delete; + + static auto& instance() { + static AttributeTranslator attribute_translator; + return attribute_translator; + } + + ir::Attribute operator()(const framework::Attribute& attr); + ir::Attribute operator()(const std::string& target_type, + const framework::Attribute& attr); +}; + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index d6aeeeaf8ee4bf9f206b075ada81296d8c181e64..5bc9df7ee8b34b40482492baae03e664ba14d788 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +from typing import Dict import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined @@ -33,7 +34,7 @@ def OpNameNormalizerInitialization( op_compat_yaml_file: str = "", output_source_file: str = "" ) -> None: def to_phi_and_fluid_op_name(op_item): - # Templat: - op : phi_name (fluid_name) + # Template: - op : phi_name (fluid_name) names = op_item.split('(') if len(names) == 1: phi_fluid_name = names[0].strip() @@ -46,21 +47,55 @@ def OpNameNormalizerInitialization( with open(op_compat_yaml_file, "r") as f: op_compat_infos = yaml.safe_load(f) op_name_mappings = {} + op_arg_name_mappings = {} for op_compat_item in op_compat_infos: - def insert_new_mappings(op_name_str): + def insert_new_mappings(op_name_str: str) -> str: normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str) if normalized_name == legacy_name: - return + return normalized_name, legacy_name op_name_mappings[legacy_name] = normalized_name + return normalized_name, legacy_name + + def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): + if op_name is None: + return + if op_name not in op_arg_name_mappings: + op_arg_name_mappings[op_name] = {} + op_arg_name_mappings[op_name].update(arg_mapping) - insert_new_mappings(op_compat_item["op"]) + _, legacy_name = insert_new_mappings(op_compat_item["op"]) + legacy_backward_op_names = [] if "backward" in op_compat_item: - insert_new_mappings(op_compat_item["backward"]) + backward_op_name_mapping_paris = op_compat_item["backward"].split( + "," + ) + for pair in backward_op_name_mapping_paris: + _, legacy_backward_op_name = insert_new_mappings(pair) + legacy_backward_op_names.append(legacy_backward_op_name) + + if "inputs" in op_compat_item: + insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["inputs"]) + + if "attrs" in op_compat_item: + insert_new_arg_mappings(legacy_name, op_compat_item["attrs"]) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["attrs"]) + if "outputs" in op_compat_item: + insert_new_arg_mappings(legacy_name, op_compat_item["outputs"]) + for backward_op in legacy_backward_op_names: + insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) + + # special op mappings + op_name_mappings["fetch_v2"] = "fetch" + op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( - op_name_paris=op_name_mappings + op_name_pairs=op_name_mappings, + op_arg_name_pairs=op_arg_name_mappings, ) f.write(op_compat_definition) diff --git a/paddle/fluid/translator/op_compat_info.cc.j2 b/paddle/fluid/translator/op_compat_info.cc.j2 index af42cf9b8abdc391dcdeebb27af5bb42d1cb9c4e..a44941595fbb8d1f5c4d984e2ee232ecd91f4976 100644 --- a/paddle/fluid/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/translator/op_compat_info.cc.j2 @@ -5,10 +5,22 @@ namespace translator { OpNameNormalizer::OpNameNormalizer() { op_name_mappings = { - {% for legacy_name, normalized_name in op_name_paris.items() %} + {% for legacy_name, normalized_name in op_name_pairs.items() %} { "{{legacy_name}}", "{{normalized_name}}" }, {% endfor %} }; + op_arg_name_mappings = { + {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %} + { + "{{op_name}}", + { + {% for normalized_name, legacy_name in arg_name_mappings.items() %} + { "{{normalized_name}}", "{{legacy_name}}" }, + {% endfor %} + }, + }, + {% endfor %} + }; } } // namespace translator diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 86acafe7a0f1a5cb480b8881ba99cc27c41bdc11..654329dbbe357eaf22229003ee631a1e823ba060 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "glog/logging.h" +#include "paddle/fluid/translator/utils.h" + #pragma once namespace paddle { @@ -26,6 +29,8 @@ class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. std::unordered_map op_name_mappings; + std::unordered_map> + op_arg_name_mappings; public: OpNameNormalizer(const OpNameNormalizer&) = delete; @@ -44,6 +49,49 @@ class OpNameNormalizer { } return op_name_mappings.at(op_type); } + + std::string GetLegacyArgName(const std::string& op_type, + const std::string& arg_name) { + bool is_grad_op = (op_type.find("grad") != std::string::npos); + bool is_grad_arg = (arg_name.find("grad") != std::string::npos); + if (is_grad_op && is_grad_arg) { + std::string target = "_grad"; + std::string data = "@GRAD"; + + size_t first_grad_pos = arg_name.find_first_of(target); + std::string legacy_name = + this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos)); + legacy_name += arg_name.substr(first_grad_pos); + for (size_t pos = 0; + legacy_name.npos != (pos = legacy_name.find(target, pos)); + pos += data.length()) { + legacy_name.replace(pos, target.length(), data); + } + return legacy_name; + } + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + return UnderscoreToCamelCase(arg_name); + } + return arg_mappings.at(arg_name); + } + + std::string GetLegacyAttrName(const std::string& op_type, + const std::string& arg_name) { + if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { + VLOG(10) << "[" << op_type << "] not found"; + return arg_name; + } + auto& arg_mappings = op_arg_name_mappings[op_type]; + if (arg_mappings.find(arg_name) == arg_mappings.end()) { + VLOG(10) << "[" << op_type << "][" << arg_name << "] not found"; + return arg_name; + } + return arg_mappings.at(arg_name); + } }; } // namespace translator diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 7d917825859b8ead0fd825415be97bab76dc28e7..5dca36db1c8b5910343f8afe3fb228af0b9f913d 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -15,19 +15,23 @@ #include "paddle/fluid/translator/op_translator.h" #include +#include #include #include #include #include #include +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/op_compat_info.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/operation.h" #include "paddle/ir/core/value.h" #include "paddle/phi/core/enforce.h" @@ -42,11 +46,24 @@ using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; +using OpInputInfo = paddle::dialect::OpInputInfo; +using OpInputInfoList = std::vector; +using OpAttributeInfo = paddle::dialect::OpAttributeInfo; +using OpAttributeInfoList = std::vector; +using OpOutputInfo = paddle::dialect::OpOutputInfo; +using OpOutputInfoList = std::vector; static const char kTargetDialectPrefix[] = "pd."; +static const std::unordered_set special_inplace_ops = { + "batch_norm", +}; + inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; + if (special_inplace_ops.count(op_desc.Type())) { + return inplace; + } auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); @@ -129,7 +146,7 @@ inline ir::Operation* InsertCombineOperationForTarget( std::vector src_values; std::vector types_in_vec; - for (auto arg_name : args) { + for (const auto& arg_name : args) { auto defining_info = param_map->at(arg_name); src_values.push_back(defining_info.value); types_in_vec.push_back(defining_info.value.type()); @@ -141,13 +158,25 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } +inline ir::Operation* InsertConstantOperationForOptionalArg( + ir::IrContext* ctx, ir::Program* program) { + std::string constant_op_name(ir::ConstantOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); + + ir::Type null_type = ir::Type(nullptr); + ir::Operation* operation = + ir::Operation::create({}, {}, {null_type}, op_info); + program->block()->push_back(operation); + return operation; +} + inline std::vector GenerateOperationInput( ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, - const OpDesc& op_desc) { - std::vector op_inputs = {}; - + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfoList& input_infos) { // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { @@ -159,7 +188,7 @@ inline std::vector GenerateOperationInput( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s.%s as input should be exists before prasing %d", + "arg %s.%s as input should be exists before prasing %s", name, arg_name, op_desc.Type())); @@ -171,73 +200,116 @@ inline std::vector GenerateOperationInput( } } - for (const auto& n : op_desc.Inputs()) { - auto& name = n.first; - VLOG(10) << "[input retriving]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + std::vector op_inputs; + auto& op_normalizer = OpNameNormalizer::instance(); - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - auto defining_info = (*param_map)[arg_name]; - op_inputs.push_back(defining_info.value); - } + for (const auto& info : input_infos) { + std::string legacy_input_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); + + // return empty OpResult if this arg is optional and not shown in OpDesc + // TODO(lyk): HasInput doesnot consider variadic attribute + if (!op_desc.HasInput(legacy_input_name)) { + PADDLE_ENFORCE(info.optional, + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_input_name)); + op_inputs.push_back(ir::OpResult(nullptr)); + continue; + } + + const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); + bool is_vector = (info.type_name.find("VectorType") != std::string::npos); + + // if src type is Tensor + if (!is_vector) { + auto defining_info = (*param_map)[legacy_input_vars[0]]; + op_inputs.push_back(defining_info.value); // if src type is Vector , need an additional `CombineOp` to // assemble them. } else { - auto* combine_op = - InsertCombineOperationForTarget(ctx, param_map, program, args); + auto* combine_op = InsertCombineOperationForTarget( + ctx, param_map, program, legacy_input_vars); op_inputs.push_back(combine_op->GetResultByIndex(0)); } } + return op_inputs; } inline std::tuple GenerateOperationOutput( - ir::IrContext* ctx, const OpDesc& op_desc) { + ir::IrContext* ctx, + const OpDesc& op_desc, + const OpOutputInfoList& output_infos) { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; auto& type_translator = TypeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); const BlockDesc* block = op_desc.Block(); - for (const auto& n : op_desc.Outputs()) { - auto& name = n.first; - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name; - auto& args = n.second; + for (const auto& info : output_infos) { size_t cur_output_idx = op_output_types.size(); + std::string legacy_output_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); + + // return empty type if this arg is optional and not shown in OpDesc + // TODO(lyk): HasOutput doesnot consider variadic attribute + if (!op_desc.HasOutput(legacy_output_name)) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "] optional " << info.name << " :" + << info.type_name << " " << legacy_output_name; + PADDLE_ENFORCE(info.optional, + platform::errors::PreconditionNotMet( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name)); + op_output_types.push_back(ir::Type(nullptr)); + continue; + } - // if src type is Tensor or a Vector with size <= 1 - if (args.size() <= 1) { - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name - << " " << var->GetType(); + const auto& legacy_output_vars = op_desc.Output(legacy_output_name); + bool is_vector = (info.type_name.find("VectorType") != std::string::npos); + + // if src type is Tensor + if (!is_vector) { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; + if (legacy_output_vars.size() == 0) { + op_output_types.push_back(ir::Type(nullptr)); + continue; + } - ir::Type translated_var_type = - type_translator[var->GetType()](ctx, *var); + auto& var_name = legacy_output_vars[0]; + VarDesc* var = block->FindVarRecursive(var_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " " << var_name + << " " << var->GetType(); - arg_to_idx[arg_name] = cur_output_idx; - op_output_types.push_back(translated_var_type); - } + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + + arg_to_idx[var_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); // if src type is Vector } else { + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << info.name << " :" + << info.type_name << " " << legacy_output_name; std::vector types; - for (const auto& arg_name : args) { - VarDesc* var = block->FindVarRecursive(arg_name); + for (const auto& var_name : legacy_output_vars) { + VarDesc* var = block->FindVarRecursive(var_name); VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << name << " " << arg_name + << "[" << op_desc.Type() << "]" << info.name << " " << var_name << " " << var->GetType(); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); types.push_back(translated_var_type); - arg_to_idx[arg_name] = cur_output_idx; + arg_to_idx[var_name] = cur_output_idx; } ir::Type vec_type = ir::VectorType::get(ctx, types); op_output_types.push_back(vec_type); @@ -246,6 +318,38 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } +inline ir::AttributeMap TranslateOpAttribute( + std::string normalized_op_name, + const OpAttributeInfoList& op_attr_infos, + const OpDesc& op_desc) { + auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); + ir::AttributeMap attribute_map = {}; + + for (const auto& info : op_attr_infos) { + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); + + paddle::framework::Attribute legacy_attr; + if (op_desc.HasAttr(legacy_attr_name)) { + legacy_attr = op_desc.GetAttr(legacy_attr_name); + } + VLOG(10) << "attribute in " << op_desc.Type() + << " name: " << legacy_attr_name << " " << legacy_attr.index(); + ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); + attribute_map[info.name] = new_attr; + if (!new_attr) { + VLOG(0) << "empty attribute in " << op_desc.Type() + << " name: " << info.name; + } else { + VLOG(10) << "new attribute in " << op_desc.Type() + << " name: " << info.name << " " << new_attr.storage(); + } + } + + return attribute_map; +} + inline void RecordOpResultMapping(TranslationContext* param_map, const OpDesc& op_desc, ir::Operation* operation, @@ -274,15 +378,34 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + auto* op_info_concept = + op_info.GetInterfaceImpl(); + + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos, std::ignore) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); OpOutputMapping arg_to_idx; - OpOutputTypeList op_output_types = {}; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); + OpOutputTypeList op_output_types; + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); + + auto attribute_map = + TranslateOpAttribute(op_info.name(), attr_infos, op_desc); + VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; + ir::Operation* operation = - ir::Operation::create(op_inputs, {}, op_output_types, op_info); + ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; program->block()->push_back(operation); + + VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end."; RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -292,14 +415,28 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - std::vector op_inputs = {}; + auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos, std::ignore) = + op_info_concept->get_op_info_(); + + std::vector op_inputs; OpOutputMapping arg_to_idx; - OpOutputTypeList op_output_types = {}; - std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); - auto op_info = LoopkUpOpInfo(ctx, op_desc); + OpOutputTypeList op_output_types; + std::tie(op_output_types, arg_to_idx) = + GenerateOperationOutput(ctx, op_desc, output_infos); + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, + }; + ir::Operation* operation = - ir::Operation::create(op_inputs, {}, op_output_types, op_info); + ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); program->block()->push_back(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -310,12 +447,26 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); - - OpOutputTypeList op_output_types = {}; auto op_info = LoopkUpOpInfo(ctx, op_desc); + + auto* op_info_concept = + op_info.GetInterfaceImpl(); + OpInputInfoList input_infos; + OpAttributeInfoList attr_infos; + OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos, std::ignore) = + op_info_concept->get_op_info_(); + + auto op_inputs = GenerateOperationInput( + ctx, param_map, program, op_desc, op_info.name(), input_infos); + + OpOutputTypeList op_output_types; + ir::AttributeMap attribute_map = { + {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, + }; + ir::Operation* operation = - ir::Operation::create(op_inputs, {}, op_output_types, op_info); + ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); program->block()->push_back(operation); return operation; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 2b98e4e11cf559dece61ee857b92f94c04ebf947..8f0ea2cfb39b522ba388c48351b264089b5cb7d5 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { - {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, + {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, }; ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 148ac9b0563f19d8821c7f0c03270219fb32726d..3012dfc84c6f3aa21367b31ed8b9adb0a3cc1398 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -39,9 +39,9 @@ struct VariableDefiningInfo { ir::OpResult value; bool generated_by_vector = - false; // true if target variabe is generated by Vector + false; // true if target variable is generated by Vector int idx_in_vector = - -1; // positive if target variabe is generated by Vector + -1; // positive if target variable is generated by Vector }; using TranslationContext = diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7065f46992c6aabcd77efdeba5f804b45a8e7eef --- /dev/null +++ b/paddle/fluid/translator/utils.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace paddle { +namespace translator { + +static std::string UnderscoreToCamelCase(std::string str) { + std::string camel_case; + bool next_upper = true; + for (char c : str) { + if (c == '_') { + next_upper = true; + } else { + if (next_upper) { + camel_case += toupper(c); + next_upper = false; + } else { + camel_case += c; + } + } + } + return camel_case; +} + +} // namespace translator +} // namespace paddle diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index 92adc02a5e0484dd9e50ade981c3757eec70ccec..b98bedd016fca8f84eb6f93753fd70d48529f9da 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) { } TEST(PaddleDialectTest, Translator) { - LOG(WARNING) << "TODO"; - // auto p = load_from_file("restnet50_main.prog"); - // EXPECT_EQ(p.Size(), 1u); - - // ir::IrContext *ctx = ir::IrContext::Instance(); - // ctx->GetOrRegisterDialect(); - // ctx->GetOrRegisterDialect(); - // auto program = paddle::TranslateLegacyProgramToProgram(p); - - // size_t op_size = program->block()->size(); - // // ops.size() = op size in BlockDesc + get_parameter_op + combine op - // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); - // VLOG(0) << *program; + auto p = load_from_file("restnet50_main.prog"); + EXPECT_EQ(p.Size(), 1u); + + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + size_t op_size = program->block()->size(); + // ops.size() = op size in BlockDesc + get_parameter_op + combine op + EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); + + std::cout << *program << std::endl; }