From 343a9e9519586ac838f9e4e3b38050d1bd6f6e9e Mon Sep 17 00:00:00 2001 From: kangguangli Date: Sat, 3 Jun 2023 19:54:36 +0800 Subject: [PATCH] Revert "[IR] Support op attribute and refactor for new op definition (#54068)" This reverts commit 37930a69f8f823e31b527c0c155d76a256835644. --- .../fluid/translator/attribute_translator.cc | 231 ---------------- .../fluid/translator/attribute_translator.h | 54 ---- paddle/fluid/translator/op_compat_gen.py | 47 +--- paddle/fluid/translator/op_compat_info.cc.j2 | 14 +- paddle/fluid/translator/op_compat_info.h | 48 ---- paddle/fluid/translator/op_translator.cc | 259 ++++-------------- paddle/fluid/translator/program_translator.cc | 2 +- paddle/fluid/translator/program_translator.h | 4 +- paddle/fluid/translator/utils.h | 42 --- test/cpp/ir/core/program_translator_test.cc | 26 +- 10 files changed, 77 insertions(+), 650 deletions(-) delete mode 100644 paddle/fluid/translator/attribute_translator.cc delete mode 100644 paddle/fluid/translator/attribute_translator.h delete mode 100644 paddle/fluid/translator/utils.h diff --git a/paddle/fluid/translator/attribute_translator.cc b/paddle/fluid/translator/attribute_translator.cc deleted file mode 100644 index 4390ba68616..00000000000 --- a/paddle/fluid/translator/attribute_translator.cc +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/translator/attribute_translator.h" - -#include -#include - -#include "paddle/fluid/dialect/pd_attribute.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/int_array.h" -#include "paddle/phi/common/layout.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/common/scalar.h" -#include "paddle/utils/variant.h" - -namespace paddle { -namespace translator { - -class AttributeVisitor { - public: - ir::IrContext* ctx; - AttributeVisitor() { ctx = ir::IrContext::Instance(); } - ~AttributeVisitor() {} - - public: - virtual ir::Attribute operator()(int i) { - VLOG(10) << "translating int"; - return ir::Int32_tAttribute::get(ctx, i); - } - - virtual ir::Attribute operator()(float f) { - VLOG(10) << "translating float"; - return ir::FloatAttribute::get(ctx, f); - } - - virtual ir::Attribute operator()(bool b) { - VLOG(10) << "translating bool"; - return ir::BoolAttribute::get(ctx, b); - } - - virtual ir::Attribute operator()(double d) { - VLOG(10) << "translating double"; - return ir::DoubleAttribute::get(ctx, d); - } - - virtual ir::Attribute operator()(std::string str) { - VLOG(10) << "translating string"; - return ir::StrAttribute::get(ctx, str); - } - - virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { - VLOG(10) << "translating scalar"; - return paddle::dialect::ScalarAttribute::get(ctx, scalar); - } - - virtual ir::Attribute operator()(const std::vector& strs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(strs.size()); - for (const auto& v : strs) { - attrs.push_back(ir::StrAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& fs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(fs.size()); - for (const auto& v : fs) { - attrs.push_back(ir::FloatAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& is) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(is.size()); - for (const auto& v : is) { - attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& bs) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(bs.size()); - for (const auto& v : bs) { - attrs.push_back(ir::BoolAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& i64s) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(i64s.size()); - for (const auto& v : i64s) { - attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const std::vector& ds) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(ds.size()); - for (const auto& v : ds) { - attrs.push_back(ir::DoubleAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()( - const std::vector& ss) { - VLOG(10) << "translating vector"; - std::vector attrs; - attrs.reserve(ss.size()); - for (const auto& v : ss) { - attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); - } - return ir::ArrayAttribute::get(ctx, attrs); - } - - virtual ir::Attribute operator()(const paddle::blank& blank) { - VLOG(10) << "translating paddle::blank"; - return ir::Attribute(nullptr); - } - - template - ir::Attribute operator()(T attr) { - VLOG(10) << "translating null type"; - return ir::Attribute(nullptr); - } -}; - -class IntArrayAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(const std::vector& is) override { - VLOG(10) << "translating vector to IntArray"; - phi::IntArray data(is); - return paddle::dialect::IntArrayAttribute::get(ctx, data); - } - - ir::Attribute operator()(const std::vector& is) override { - VLOG(10) << "translating vector to IntArray"; - phi::IntArray data(is); - return paddle::dialect::IntArrayAttribute::get(ctx, data); - } -}; - -class ScalarAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int i) override { - VLOG(10) << "translating int to Scalar"; - phi::Scalar data(i); - return paddle::dialect::ScalarAttribute::get(ctx, data); - } - - ir::Attribute operator()(float f) override { - VLOG(10) << "translating float to Scalar"; - phi::Scalar data(f); - return paddle::dialect::ScalarAttribute::get(ctx, data); - } -}; - -class DataTypeAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int i) override { - VLOG(10) << "translating int to DataType: " << i; - phi::DataType data = static_cast(i); - return paddle::dialect::DataTypeAttribute::get(ctx, data); - } -}; - -class PlaceAttributeVisitor : public AttributeVisitor { - public: - using AttributeVisitor::AttributeVisitor; - - ir::Attribute operator()(const paddle::blank& blank) override { - VLOG(10) << "translating paddle::blank"; - phi::Place data(phi::AllocationType::CPU); - return paddle::dialect::PlaceAttribute::get(ctx, data); - } -}; - -AttributeTranslator::AttributeTranslator() { - general_visitor = new AttributeVisitor(); - special_visitors["paddle::dialect::IntArrayAttribute"] = - new IntArrayAttributeVisitor(); - special_visitors["paddle::dialect::ScalarAttribute"] = - new ScalarAttributeVisitor(); - special_visitors["paddle::dialect::DataTypeAttribute"] = - new DataTypeAttributeVisitor(); - special_visitors["paddle::dialect::PlaceAttribute"] = - new PlaceAttributeVisitor(); -} - -ir::Attribute AttributeTranslator::operator()( - const framework::Attribute& attr) { - return paddle::visit(*general_visitor, attr); -} - -ir::Attribute AttributeTranslator::operator()( - const std::string& target_type, const framework::Attribute& attr) { - if (special_visitors.find(target_type) == special_visitors.end()) { - VLOG(10) << "[" << target_type << "] not found"; - return paddle::visit(*general_visitor, attr); - } - return paddle::visit(*(special_visitors.at(target_type)), attr); -} - -} // namespace translator -} // namespace paddle diff --git a/paddle/fluid/translator/attribute_translator.h b/paddle/fluid/translator/attribute_translator.h deleted file mode 100644 index ea509c7e346..00000000000 --- a/paddle/fluid/translator/attribute_translator.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/attribute.h" -#include "paddle/fluid/framework/type_defs.h" -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_context.h" - -#pragma once - -namespace paddle { -namespace translator { - -class AttributeVisitor; - -class AttributeTranslator { - private: - AttributeTranslator(); - AttributeVisitor* general_visitor; - std::unordered_map special_visitors; - - public: - AttributeTranslator(const AttributeTranslator&) = delete; - AttributeTranslator& operator=(const AttributeTranslator&) = delete; - AttributeTranslator(AttributeTranslator&&) = delete; - AttributeTranslator& operator=(AttributeTranslator&&) = delete; - - static auto& instance() { - static AttributeTranslator attribute_translator; - return attribute_translator; - } - - ir::Attribute operator()(const framework::Attribute& attr); - ir::Attribute operator()(const std::string& target_type, - const framework::Attribute& attr); -}; - -} // namespace translator -} // namespace paddle diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py index 5bc9df7ee8b..d6aeeeaf8ee 100644 --- a/paddle/fluid/translator/op_compat_gen.py +++ b/paddle/fluid/translator/op_compat_gen.py @@ -14,7 +14,6 @@ import argparse from pathlib import Path -from typing import Dict import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined @@ -34,7 +33,7 @@ def OpNameNormalizerInitialization( op_compat_yaml_file: str = "", output_source_file: str = "" ) -> None: def to_phi_and_fluid_op_name(op_item): - # Template: - op : phi_name (fluid_name) + # Templat: - op : phi_name (fluid_name) names = op_item.split('(') if len(names) == 1: phi_fluid_name = names[0].strip() @@ -47,55 +46,21 @@ def OpNameNormalizerInitialization( with open(op_compat_yaml_file, "r") as f: op_compat_infos = yaml.safe_load(f) op_name_mappings = {} - op_arg_name_mappings = {} for op_compat_item in op_compat_infos: - def insert_new_mappings(op_name_str: str) -> str: + def insert_new_mappings(op_name_str): normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str) if normalized_name == legacy_name: - return normalized_name, legacy_name - op_name_mappings[legacy_name] = normalized_name - return normalized_name, legacy_name - - def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): - if op_name is None: return - if op_name not in op_arg_name_mappings: - op_arg_name_mappings[op_name] = {} - op_arg_name_mappings[op_name].update(arg_mapping) + op_name_mappings[legacy_name] = normalized_name - _, legacy_name = insert_new_mappings(op_compat_item["op"]) - legacy_backward_op_names = [] + insert_new_mappings(op_compat_item["op"]) if "backward" in op_compat_item: - backward_op_name_mapping_paris = op_compat_item["backward"].split( - "," - ) - for pair in backward_op_name_mapping_paris: - _, legacy_backward_op_name = insert_new_mappings(pair) - legacy_backward_op_names.append(legacy_backward_op_name) - - if "inputs" in op_compat_item: - insert_new_arg_mappings(legacy_name, op_compat_item["inputs"]) - for backward_op in legacy_backward_op_names: - insert_new_arg_mappings(backward_op, op_compat_item["inputs"]) - - if "attrs" in op_compat_item: - insert_new_arg_mappings(legacy_name, op_compat_item["attrs"]) - for backward_op in legacy_backward_op_names: - insert_new_arg_mappings(backward_op, op_compat_item["attrs"]) - if "outputs" in op_compat_item: - insert_new_arg_mappings(legacy_name, op_compat_item["outputs"]) - for backward_op in legacy_backward_op_names: - insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) - - # special op mappings - op_name_mappings["fetch_v2"] = "fetch" - + insert_new_mappings(op_compat_item["backward"]) op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: op_compat_definition = op_name_normailzer_template.render( - op_name_pairs=op_name_mappings, - op_arg_name_pairs=op_arg_name_mappings, + op_name_paris=op_name_mappings ) f.write(op_compat_definition) diff --git a/paddle/fluid/translator/op_compat_info.cc.j2 b/paddle/fluid/translator/op_compat_info.cc.j2 index a44941595fb..af42cf9b8ab 100644 --- a/paddle/fluid/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/translator/op_compat_info.cc.j2 @@ -5,22 +5,10 @@ namespace translator { OpNameNormalizer::OpNameNormalizer() { op_name_mappings = { - {% for legacy_name, normalized_name in op_name_pairs.items() %} + {% for legacy_name, normalized_name in op_name_paris.items() %} { "{{legacy_name}}", "{{normalized_name}}" }, {% endfor %} }; - op_arg_name_mappings = { - {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %} - { - "{{op_name}}", - { - {% for normalized_name, legacy_name in arg_name_mappings.items() %} - { "{{normalized_name}}", "{{legacy_name}}" }, - {% endfor %} - }, - }, - {% endfor %} - }; } } // namespace translator diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h index 654329dbbe3..86acafe7a0f 100644 --- a/paddle/fluid/translator/op_compat_info.h +++ b/paddle/fluid/translator/op_compat_info.h @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include "glog/logging.h" -#include "paddle/fluid/translator/utils.h" - #pragma once namespace paddle { @@ -29,8 +26,6 @@ class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. std::unordered_map op_name_mappings; - std::unordered_map> - op_arg_name_mappings; public: OpNameNormalizer(const OpNameNormalizer&) = delete; @@ -49,49 +44,6 @@ class OpNameNormalizer { } return op_name_mappings.at(op_type); } - - std::string GetLegacyArgName(const std::string& op_type, - const std::string& arg_name) { - bool is_grad_op = (op_type.find("grad") != std::string::npos); - bool is_grad_arg = (arg_name.find("grad") != std::string::npos); - if (is_grad_op && is_grad_arg) { - std::string target = "_grad"; - std::string data = "@GRAD"; - - size_t first_grad_pos = arg_name.find_first_of(target); - std::string legacy_name = - this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos)); - legacy_name += arg_name.substr(first_grad_pos); - for (size_t pos = 0; - legacy_name.npos != (pos = legacy_name.find(target, pos)); - pos += data.length()) { - legacy_name.replace(pos, target.length(), data); - } - return legacy_name; - } - if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { - return UnderscoreToCamelCase(arg_name); - } - auto& arg_mappings = op_arg_name_mappings[op_type]; - if (arg_mappings.find(arg_name) == arg_mappings.end()) { - return UnderscoreToCamelCase(arg_name); - } - return arg_mappings.at(arg_name); - } - - std::string GetLegacyAttrName(const std::string& op_type, - const std::string& arg_name) { - if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) { - VLOG(10) << "[" << op_type << "] not found"; - return arg_name; - } - auto& arg_mappings = op_arg_name_mappings[op_type]; - if (arg_mappings.find(arg_name) == arg_mappings.end()) { - VLOG(10) << "[" << op_type << "][" << arg_name << "] not found"; - return arg_name; - } - return arg_mappings.at(arg_name); - } }; } // namespace translator diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc index 5dca36db1c8..7d917825859 100644 --- a/paddle/fluid/translator/op_translator.cc +++ b/paddle/fluid/translator/op_translator.cc @@ -15,23 +15,19 @@ #include "paddle/fluid/translator/op_translator.h" #include -#include #include #include #include #include #include -#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/framework/op_desc.h" -#include "paddle/fluid/translator/attribute_translator.h" #include "paddle/fluid/translator/op_compat_info.h" #include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/type_translator.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/operation.h" #include "paddle/ir/core/value.h" #include "paddle/phi/core/enforce.h" @@ -46,24 +42,11 @@ using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; -using OpInputInfo = paddle::dialect::OpInputInfo; -using OpInputInfoList = std::vector; -using OpAttributeInfo = paddle::dialect::OpAttributeInfo; -using OpAttributeInfoList = std::vector; -using OpOutputInfo = paddle::dialect::OpOutputInfo; -using OpOutputInfoList = std::vector; static const char kTargetDialectPrefix[] = "pd."; -static const std::unordered_set special_inplace_ops = { - "batch_norm", -}; - inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; - if (special_inplace_ops.count(op_desc.Type())) { - return inplace; - } auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); @@ -146,7 +129,7 @@ inline ir::Operation* InsertCombineOperationForTarget( std::vector src_values; std::vector types_in_vec; - for (const auto& arg_name : args) { + for (auto arg_name : args) { auto defining_info = param_map->at(arg_name); src_values.push_back(defining_info.value); types_in_vec.push_back(defining_info.value.type()); @@ -158,25 +141,13 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } -inline ir::Operation* InsertConstantOperationForOptionalArg( - ir::IrContext* ctx, ir::Program* program) { - std::string constant_op_name(ir::ConstantOp::name()); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); - - ir::Type null_type = ir::Type(nullptr); - ir::Operation* operation = - ir::Operation::create({}, {}, {null_type}, op_info); - program->block()->push_back(operation); - return operation; -} - inline std::vector GenerateOperationInput( ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, - const OpDesc& op_desc, - const std::string& normalized_op_name, - const OpInputInfoList& input_infos) { + const OpDesc& op_desc) { + std::vector op_inputs = {}; + // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { @@ -188,7 +159,7 @@ inline std::vector GenerateOperationInput( param_map->count(arg_name), 0, platform::errors::PreconditionNotMet( - "arg %s.%s as input should be exists before prasing %s", + "arg %s.%s as input should be exists before prasing %d", name, arg_name, op_desc.Type())); @@ -200,116 +171,73 @@ inline std::vector GenerateOperationInput( } } - std::vector op_inputs; - auto& op_normalizer = OpNameNormalizer::instance(); - - for (const auto& info : input_infos) { - std::string legacy_input_name = - op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - - // return empty OpResult if this arg is optional and not shown in OpDesc - // TODO(lyk): HasInput doesnot consider variadic attribute - if (!op_desc.HasInput(legacy_input_name)) { - PADDLE_ENFORCE(info.optional, - platform::errors::PreconditionNotMet( - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_input_name)); - op_inputs.push_back(ir::OpResult(nullptr)); - continue; - } - - const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true); - bool is_vector = (info.type_name.find("VectorType") != std::string::npos); + for (const auto& n : op_desc.Inputs()) { + auto& name = n.first; + VLOG(10) << "[input retriving]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; - // if src type is Tensor - if (!is_vector) { - auto defining_info = (*param_map)[legacy_input_vars[0]]; - op_inputs.push_back(defining_info.value); + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + auto defining_info = (*param_map)[arg_name]; + op_inputs.push_back(defining_info.value); + } // if src type is Vector , need an additional `CombineOp` to // assemble them. } else { - auto* combine_op = InsertCombineOperationForTarget( - ctx, param_map, program, legacy_input_vars); + auto* combine_op = + InsertCombineOperationForTarget(ctx, param_map, program, args); op_inputs.push_back(combine_op->GetResultByIndex(0)); } } - return op_inputs; } inline std::tuple GenerateOperationOutput( - ir::IrContext* ctx, - const OpDesc& op_desc, - const OpOutputInfoList& output_infos) { + ir::IrContext* ctx, const OpDesc& op_desc) { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; auto& type_translator = TypeTranslator::instance(); - auto& op_normalizer = OpNameNormalizer::instance(); const BlockDesc* block = op_desc.Block(); + for (const auto& n : op_desc.Outputs()) { + auto& name = n.first; + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; - for (const auto& info : output_infos) { size_t cur_output_idx = op_output_types.size(); - std::string legacy_output_name = - op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); - - // return empty type if this arg is optional and not shown in OpDesc - // TODO(lyk): HasOutput doesnot consider variadic attribute - if (!op_desc.HasOutput(legacy_output_name)) { - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "] optional " << info.name << " :" - << info.type_name << " " << legacy_output_name; - PADDLE_ENFORCE(info.optional, - platform::errors::PreconditionNotMet( - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_output_name)); - op_output_types.push_back(ir::Type(nullptr)); - continue; - } - - const auto& legacy_output_vars = op_desc.Output(legacy_output_name); - bool is_vector = (info.type_name.find("VectorType") != std::string::npos); - - // if src type is Tensor - if (!is_vector) { - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << info.name << " :" - << info.type_name << " " << legacy_output_name; - if (legacy_output_vars.size() == 0) { - op_output_types.push_back(ir::Type(nullptr)); - continue; - } - auto& var_name = legacy_output_vars[0]; - VarDesc* var = block->FindVarRecursive(var_name); - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << info.name << " " << var_name - << " " << var->GetType(); + // if src type is Tensor or a Vector with size <= 1 + if (args.size() <= 1) { + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name + << " " << var->GetType(); - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + ir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); - arg_to_idx[var_name] = cur_output_idx; - op_output_types.push_back(translated_var_type); + arg_to_idx[arg_name] = cur_output_idx; + op_output_types.push_back(translated_var_type); + } // if src type is Vector } else { - VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << info.name << " :" - << info.type_name << " " << legacy_output_name; std::vector types; - for (const auto& var_name : legacy_output_vars) { - VarDesc* var = block->FindVarRecursive(var_name); + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); VLOG(10) << "[output translating]" - << "[" << op_desc.Type() << "]" << info.name << " " << var_name + << "[" << op_desc.Type() << "]" << name << " " << arg_name << " " << var->GetType(); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); types.push_back(translated_var_type); - arg_to_idx[var_name] = cur_output_idx; + arg_to_idx[arg_name] = cur_output_idx; } ir::Type vec_type = ir::VectorType::get(ctx, types); op_output_types.push_back(vec_type); @@ -318,38 +246,6 @@ inline std::tuple GenerateOperationOutput( return {op_output_types, arg_to_idx}; } -inline ir::AttributeMap TranslateOpAttribute( - std::string normalized_op_name, - const OpAttributeInfoList& op_attr_infos, - const OpDesc& op_desc) { - auto& attribute_translator = AttributeTranslator::instance(); - auto& op_normalizer = OpNameNormalizer::instance(); - ir::AttributeMap attribute_map = {}; - - for (const auto& info : op_attr_infos) { - auto legacy_attr_name = - op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); - - paddle::framework::Attribute legacy_attr; - if (op_desc.HasAttr(legacy_attr_name)) { - legacy_attr = op_desc.GetAttr(legacy_attr_name); - } - VLOG(10) << "attribute in " << op_desc.Type() - << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); - attribute_map[info.name] = new_attr; - if (!new_attr) { - VLOG(0) << "empty attribute in " << op_desc.Type() - << " name: " << info.name; - } else { - VLOG(10) << "new attribute in " << op_desc.Type() - << " name: " << info.name << " " << new_attr.storage(); - } - } - - return attribute_map; -} - inline void RecordOpResultMapping(TranslationContext* param_map, const OpDesc& op_desc, ir::Operation* operation, @@ -378,34 +274,15 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_info = LoopkUpOpInfo(ctx, op_desc); - auto* op_info_concept = - op_info.GetInterfaceImpl(); - - OpInputInfoList input_infos; - OpAttributeInfoList attr_infos; - OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos, std::ignore) = - op_info_concept->get_op_info_(); - - auto op_inputs = GenerateOperationInput( - ctx, param_map, program, op_desc, op_info.name(), input_infos); + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); OpOutputMapping arg_to_idx; - OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = - GenerateOperationOutput(ctx, op_desc, output_infos); - - auto attribute_map = - TranslateOpAttribute(op_info.name(), attr_infos, op_desc); - VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; - + OpOutputTypeList op_output_types = {}; + std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); - VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->block()->push_back(operation); - - VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end."; RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); return operation; @@ -415,28 +292,14 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_info = LoopkUpOpInfo(ctx, op_desc); - - auto* op_info_concept = - op_info.GetInterfaceImpl(); - OpInputInfoList input_infos; - OpAttributeInfoList attr_infos; - OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos, std::ignore) = - op_info_concept->get_op_info_(); - - std::vector op_inputs; + std::vector op_inputs = {}; OpOutputMapping arg_to_idx; - OpOutputTypeList op_output_types; - std::tie(op_output_types, arg_to_idx) = - GenerateOperationOutput(ctx, op_desc, output_infos); - ir::AttributeMap attribute_map = { - {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, - }; - + OpOutputTypeList op_output_types = {}; + std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->block()->push_back(operation); RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); @@ -447,26 +310,12 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, TranslationContext* param_map, ir::Program* program, const OpDesc& op_desc) { - auto op_info = LoopkUpOpInfo(ctx, op_desc); - - auto* op_info_concept = - op_info.GetInterfaceImpl(); - OpInputInfoList input_infos; - OpAttributeInfoList attr_infos; - OpOutputInfoList output_infos; - std::tie(input_infos, attr_infos, output_infos, std::ignore) = - op_info_concept->get_op_info_(); - - auto op_inputs = GenerateOperationInput( - ctx, param_map, program, op_desc, op_info.name(), input_infos); - - OpOutputTypeList op_output_types; - ir::AttributeMap attribute_map = { - {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])}, - }; + auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc); + OpOutputTypeList op_output_types = {}; + auto op_info = LoopkUpOpInfo(ctx, op_desc); ir::Operation* operation = - ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info); + ir::Operation::create(op_inputs, {}, op_output_types, op_info); program->block()->push_back(operation); return operation; diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc index 8f0ea2cfb39..2b98e4e11cf 100644 --- a/paddle/fluid/translator/program_translator.cc +++ b/paddle/fluid/translator/program_translator.cc @@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock( std::string get_parameter_op_name(ir::GetParameterOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); std::unordered_map op_attribute_map = { - {"parameter_name", ir::StrAttribute::get(ctx, var->Name())}, + {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, }; ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Operation* operation = ir::Operation::create( diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h index 3012dfc84c6..148ac9b0563 100644 --- a/paddle/fluid/translator/program_translator.h +++ b/paddle/fluid/translator/program_translator.h @@ -39,9 +39,9 @@ struct VariableDefiningInfo { ir::OpResult value; bool generated_by_vector = - false; // true if target variable is generated by Vector + false; // true if target variabe is generated by Vector int idx_in_vector = - -1; // positive if target variable is generated by Vector + -1; // positive if target variabe is generated by Vector }; using TranslationContext = diff --git a/paddle/fluid/translator/utils.h b/paddle/fluid/translator/utils.h deleted file mode 100644 index 7065f46992c..00000000000 --- a/paddle/fluid/translator/utils.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -namespace paddle { -namespace translator { - -static std::string UnderscoreToCamelCase(std::string str) { - std::string camel_case; - bool next_upper = true; - for (char c : str) { - if (c == '_') { - next_upper = true; - } else { - if (next_upper) { - camel_case += toupper(c); - next_upper = false; - } else { - camel_case += c; - } - } - } - return camel_case; -} - -} // namespace translator -} // namespace paddle diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index b98bedd016f..92adc02a5e0 100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) { } TEST(PaddleDialectTest, Translator) { - auto p = load_from_file("restnet50_main.prog"); - EXPECT_EQ(p.Size(), 1u); - - ir::IrContext *ctx = ir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - auto program = paddle::TranslateLegacyProgramToProgram(p); - - size_t op_size = program->block()->size(); - // ops.size() = op size in BlockDesc + get_parameter_op + combine op - EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); - - std::cout << *program << std::endl; + LOG(WARNING) << "TODO"; + // auto p = load_from_file("restnet50_main.prog"); + // EXPECT_EQ(p.Size(), 1u); + + // ir::IrContext *ctx = ir::IrContext::Instance(); + // ctx->GetOrRegisterDialect(); + // ctx->GetOrRegisterDialect(); + // auto program = paddle::TranslateLegacyProgramToProgram(p); + + // size_t op_size = program->block()->size(); + // // ops.size() = op size in BlockDesc + get_parameter_op + combine op + // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20); + // VLOG(0) << *program; } -- GitLab