Unverified · Commit 37930a69, authored by kangguangli, committed by GitHub

[IR] Support op attribute and refactor for new op definition (#54068)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

---------
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Parent: 4f848aa9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/translator/attribute_translator.h"
#include <string>
#include <vector>
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace translator {
class AttributeVisitor {
public:
ir::IrContext* ctx;
AttributeVisitor() { ctx = ir::IrContext::Instance(); }
~AttributeVisitor() {}
public:
virtual ir::Attribute operator()(int i) {
VLOG(10) << "translating int";
return ir::Int32_tAttribute::get(ctx, i);
}
virtual ir::Attribute operator()(float f) {
VLOG(10) << "translating float";
return ir::FloatAttribute::get(ctx, f);
}
virtual ir::Attribute operator()(bool b) {
VLOG(10) << "translating bool";
return ir::BoolAttribute::get(ctx, b);
}
virtual ir::Attribute operator()(double d) {
VLOG(10) << "translating double";
return ir::DoubleAttribute::get(ctx, d);
}
virtual ir::Attribute operator()(std::string str) {
VLOG(10) << "translating string";
return ir::StrAttribute::get(ctx, str);
}
virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
VLOG(10) << "translating scalar";
return paddle::dialect::ScalarAttribute::get(ctx, scalar);
}
virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
VLOG(10) << "translating vector<string>";
std::vector<ir::Attribute> attrs;
attrs.reserve(strs.size());
for (const auto& v : strs) {
attrs.push_back(ir::StrAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<float>& fs) {
VLOG(10) << "translating vector<float>";
std::vector<ir::Attribute> attrs;
attrs.reserve(fs.size());
for (const auto& v : fs) {
attrs.push_back(ir::FloatAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<int>& is) {
VLOG(10) << "translating vector<int>";
std::vector<ir::Attribute> attrs;
attrs.reserve(is.size());
for (const auto& v : is) {
attrs.push_back(ir::Int32_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<bool>& bs) {
VLOG(10) << "translating vector<bool>";
std::vector<ir::Attribute> attrs;
attrs.reserve(bs.size());
for (const auto& v : bs) {
attrs.push_back(ir::BoolAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
attrs.push_back(ir::Int64_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<double>& ds) {
VLOG(10) << "translating vector<double>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ds.size());
for (const auto& v : ds) {
attrs.push_back(ir::DoubleAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const paddle::blank& blank) {
VLOG(10) << "translating paddle::blank";
return ir::Attribute(nullptr);
}
template <typename T>
ir::Attribute operator()(T attr) {
VLOG(10) << "translating null type";
return ir::Attribute(nullptr);
}
};
class IntArrayAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(const std::vector<int>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}
ir::Attribute operator()(const std::vector<int64_t>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}
};
class ScalarAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to Scalar";
phi::Scalar data(i);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
ir::Attribute operator()(float f) override {
VLOG(10) << "translating float to Scalar";
phi::Scalar data(f);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
};
class DataTypeAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to DataType: " << i;
phi::DataType data = static_cast<phi::DataType>(i);
return paddle::dialect::DataTypeAttribute::get(ctx, data);
}
};
class PlaceAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank";
phi::Place data(phi::AllocationType::CPU);
return paddle::dialect::PlaceAttribute::get(ctx, data);
}
};
AttributeTranslator::AttributeTranslator() {
general_visitor = new AttributeVisitor();
special_visitors["paddle::dialect::IntArrayAttribute"] =
new IntArrayAttributeVisitor();
special_visitors["paddle::dialect::ScalarAttribute"] =
new ScalarAttributeVisitor();
special_visitors["paddle::dialect::DataTypeAttribute"] =
new DataTypeAttributeVisitor();
special_visitors["paddle::dialect::PlaceAttribute"] =
new PlaceAttributeVisitor();
}
ir::Attribute AttributeTranslator::operator()(
const framework::Attribute& attr) {
return paddle::visit(*general_visitor, attr);
}
ir::Attribute AttributeTranslator::operator()(
const std::string& target_type, const framework::Attribute& attr) {
if (special_visitors.find(target_type) == special_visitors.end()) {
VLOG(10) << "[" << target_type << "] not found";
return paddle::visit(*general_visitor, attr);
}
return paddle::visit(*(special_visitors.at(target_type)), attr);
}
} // namespace translator
} // namespace paddle
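For orientation, here is a minimal usage sketch of the two call paths exposed above; the attribute value and the target type string are illustrative assumptions, not code from this patch:

#include <vector>
#include "paddle/fluid/translator/attribute_translator.h"

void Demo() {
  // One legacy attribute value, translated two ways (sketch only).
  paddle::framework::Attribute axis_attr = std::vector<int>{0, 1};

  auto& translator = paddle::translator::AttributeTranslator::instance();

  // Generic path: vector<int> -> ir::ArrayAttribute of ir::Int32_tAttribute.
  ir::Attribute generic = translator(axis_attr);

  // Targeted path: when the op definition declares an IntArrayAttribute, the
  // IntArrayAttributeVisitor packs the same value into a phi::IntArray.
  ir::Attribute int_array =
      translator("paddle::dialect::IntArrayAttribute", axis_attr);
  (void)generic;
  (void)int_array;
}

The targeted overload falls back to the general visitor whenever the requested type has no registered special visitor, so callers can pass any type name safely.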
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"

#pragma once

namespace paddle {
namespace translator {

class AttributeVisitor;

class AttributeTranslator {
 private:
  AttributeTranslator();
  AttributeVisitor* general_visitor;
  std::unordered_map<std::string, AttributeVisitor*> special_visitors;

 public:
  AttributeTranslator(const AttributeTranslator&) = delete;
  AttributeTranslator& operator=(const AttributeTranslator&) = delete;
  AttributeTranslator(AttributeTranslator&&) = delete;
  AttributeTranslator& operator=(AttributeTranslator&&) = delete;

  static auto& instance() {
    static AttributeTranslator attribute_translator;
    return attribute_translator;
  }

  ir::Attribute operator()(const framework::Attribute& attr);
  ir::Attribute operator()(const std::string& target_type,
                           const framework::Attribute& attr);
};

}  // namespace translator
}  // namespace paddle
@@ -14,6 +14,7 @@
 import argparse
 from pathlib import Path
+from typing import Dict, Tuple

 import yaml
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -33,7 +34,7 @@ def OpNameNormalizerInitialization(
     op_compat_yaml_file: str = "", output_source_file: str = ""
 ) -> None:
     def to_phi_and_fluid_op_name(op_item):
-        # Templat: - op : phi_name (fluid_name)
+        # Template: - op : phi_name (fluid_name)
         names = op_item.split('(')
         if len(names) == 1:
             phi_fluid_name = names[0].strip()
@@ -46,21 +47,55 @@ def OpNameNormalizerInitialization(
     with open(op_compat_yaml_file, "r") as f:
         op_compat_infos = yaml.safe_load(f)
     op_name_mappings = {}
+    op_arg_name_mappings = {}
     for op_compat_item in op_compat_infos:

-        def insert_new_mappings(op_name_str):
+        def insert_new_mappings(op_name_str: str) -> Tuple[str, str]:
             normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
             if normalized_name == legacy_name:
-                return
+                return normalized_name, legacy_name
             op_name_mappings[legacy_name] = normalized_name
+            return normalized_name, legacy_name
+
+        def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
+            if op_name is None:
+                return
+            if op_name not in op_arg_name_mappings:
+                op_arg_name_mappings[op_name] = {}
+            op_arg_name_mappings[op_name].update(arg_mapping)

-        insert_new_mappings(op_compat_item["op"])
+        _, legacy_name = insert_new_mappings(op_compat_item["op"])
+
+        legacy_backward_op_names = []
         if "backward" in op_compat_item:
-            insert_new_mappings(op_compat_item["backward"])
+            backward_op_name_mapping_pairs = op_compat_item["backward"].split(
+                ","
+            )
+            for pair in backward_op_name_mapping_pairs:
+                _, legacy_backward_op_name = insert_new_mappings(pair)
+                legacy_backward_op_names.append(legacy_backward_op_name)
+
+        if "inputs" in op_compat_item:
+            insert_new_arg_mappings(legacy_name, op_compat_item["inputs"])
+            for backward_op in legacy_backward_op_names:
+                insert_new_arg_mappings(backward_op, op_compat_item["inputs"])
+
+        if "attrs" in op_compat_item:
+            insert_new_arg_mappings(legacy_name, op_compat_item["attrs"])
+            for backward_op in legacy_backward_op_names:
+                insert_new_arg_mappings(backward_op, op_compat_item["attrs"])
+
+        if "outputs" in op_compat_item:
+            insert_new_arg_mappings(legacy_name, op_compat_item["outputs"])
+            for backward_op in legacy_backward_op_names:
+                insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
+
+    # special op mappings
+    op_name_mappings["fetch_v2"] = "fetch"

     op_name_normalizer_template = env.get_template("op_compat_info.cc.j2")
     with open(output_source_file, 'wt') as f:
         op_compat_definition = op_name_normalizer_template.render(
-            op_name_paris=op_name_mappings
+            op_name_pairs=op_name_mappings,
+            op_arg_name_pairs=op_arg_name_mappings,
         )
         f.write(op_compat_definition)
...
@@ -5,10 +5,22 @@ namespace translator {
 OpNameNormalizer::OpNameNormalizer() {
   op_name_mappings = {
-    {% for legacy_name, normalized_name in op_name_paris.items() %}
+    {% for legacy_name, normalized_name in op_name_pairs.items() %}
     { "{{legacy_name}}", "{{normalized_name}}" },
     {% endfor %}
   };
+  op_arg_name_mappings = {
+    {% for op_name, arg_name_mappings in op_arg_name_pairs.items() %}
+    {
+      "{{op_name}}",
+      {
+        {% for normalized_name, legacy_name in arg_name_mappings.items() %}
+        { "{{normalized_name}}", "{{legacy_name}}" },
+        {% endfor %}
+      },
+    },
+    {% endfor %}
+  };
 }
 }  // namespace translator
...
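Taking a single illustrative op_compat.yaml entry, say `- op : matmul (matmul_v2)` with `inputs : {x : X, y : Y}`, the rendered constructor would come out roughly as follows; this is a hand-written sketch of the generator output, not a capture of the real generated file:

// Hypothetical rendering for one op; the real file enumerates every entry
// in op_compat.yaml.
OpNameNormalizer::OpNameNormalizer() {
  op_name_mappings = {
      {"matmul_v2", "matmul"},
  };
  op_arg_name_mappings = {
      {
          "matmul_v2",
          {
              {"x", "X"},
              {"y", "Y"},
          },
      },
  };
}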
@@ -12,11 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <functional>
 #include <string>
 #include <unordered_map>

 #include "glog/logging.h"

+#include "paddle/fluid/translator/utils.h"
+
 #pragma once

 namespace paddle {
@@ -26,6 +29,8 @@ class OpNameNormalizer {
  private:
   OpNameNormalizer();  // Disallow instantiation outside of the class.
   std::unordered_map<std::string, std::string> op_name_mappings;
+  std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
+      op_arg_name_mappings;

  public:
   OpNameNormalizer(const OpNameNormalizer&) = delete;
@@ -44,6 +49,49 @@ class OpNameNormalizer {
     }
     return op_name_mappings.at(op_type);
   }
+
+  std::string GetLegacyArgName(const std::string& op_type,
+                               const std::string& arg_name) {
+    bool is_grad_op = (op_type.find("grad") != std::string::npos);
+    bool is_grad_arg = (arg_name.find("grad") != std::string::npos);
+    if (is_grad_op && is_grad_arg) {
+      std::string target = "_grad";
+      std::string data = "@GRAD";
+      size_t first_grad_pos = arg_name.find(target);
+      std::string legacy_name =
+          this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos));
+      legacy_name += arg_name.substr(first_grad_pos);
+      for (size_t pos = 0;
+           legacy_name.npos != (pos = legacy_name.find(target, pos));
+           pos += data.length()) {
+        legacy_name.replace(pos, target.length(), data);
+      }
+      return legacy_name;
+    }
+    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
+      return UnderscoreToCamelCase(arg_name);
+    }
+    auto& arg_mappings = op_arg_name_mappings[op_type];
+    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
+      return UnderscoreToCamelCase(arg_name);
+    }
+    return arg_mappings.at(arg_name);
+  }
+
+  std::string GetLegacyAttrName(const std::string& op_type,
+                                const std::string& arg_name) {
+    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
+      VLOG(10) << "[" << op_type << "] not found";
+      return arg_name;
+    }
+    auto& arg_mappings = op_arg_name_mappings[op_type];
+    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
+      VLOG(10) << "[" << op_type << "][" << arg_name << "] not found";
+      return arg_name;
+    }
+    return arg_mappings.at(arg_name);
+  }
 };

 }  // namespace translator
...
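To make the `_grad` handling concrete, here are a few results this logic should produce, assuming hypothetical op names with no explicit mapping registered, so the camel-case fallback applies:

#include <string>
#include "paddle/fluid/translator/op_compat_info.h"

void Demo() {
  auto& normalizer = paddle::translator::OpNameNormalizer::instance();

  // No mapping registered for the (hypothetical) op: camel-case fallback.
  std::string out = normalizer.GetLegacyArgName("my_op", "out");  // "Out"

  // Grad op + grad arg: translate the base name ("x" -> "X"), then rewrite
  // the "_grad" suffix into the legacy "@GRAD" marker.
  std::string x_grad =
      normalizer.GetLegacyArgName("my_op_grad", "x_grad");  // "X@GRAD"

  // Attribute names pass through unchanged when nothing is registered.
  std::string attr = normalizer.GetLegacyAttrName("my_op", "trans_x");
  (void)out;
  (void)x_grad;
  (void)attr;
}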
@@ -15,19 +15,23 @@
 #include "paddle/fluid/translator/op_translator.h"

 #include <algorithm>
+#include <cctype>
 #include <numeric>
 #include <string>
 #include <tuple>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>

+#include "paddle/fluid/dialect/pd_interface.h"
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/translator/attribute_translator.h"
 #include "paddle/fluid/translator/op_compat_info.h"
 #include "paddle/fluid/translator/program_translator.h"
 #include "paddle/fluid/translator/type_translator.h"
 #include "paddle/ir/core/builtin_op.h"
 #include "paddle/ir/core/builtin_type.h"
 #include "paddle/ir/core/ir_context.h"
+#include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/core/enforce.h"
@@ -42,11 +46,24 @@ using BlockDesc = paddle::framework::BlockDesc;
 using VarDesc = paddle::framework::VarDesc;
 using OpOutputTypeList = std::vector<ir::Type>;
 using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
+using OpInputInfo = paddle::dialect::OpInputInfo;
+using OpInputInfoList = std::vector<paddle::dialect::OpInputInfo>;
+using OpAttributeInfo = paddle::dialect::OpAttributeInfo;
+using OpAttributeInfoList = std::vector<paddle::dialect::OpAttributeInfo>;
+using OpOutputInfo = paddle::dialect::OpOutputInfo;
+using OpOutputInfoList = std::vector<paddle::dialect::OpOutputInfo>;

 static const char kTargetDialectPrefix[] = "pd.";

+static const std::unordered_set<std::string> special_inplace_ops = {
+    "batch_norm",
+};
+
 inline bool IsInplace(const OpDesc& op_desc) {
   bool inplace = false;
+  if (special_inplace_ops.count(op_desc.Type())) {
+    return inplace;
+  }
   auto input_names = op_desc.InputArgumentNames();
   auto output_names = op_desc.OutputArgumentNames();
@@ -129,7 +146,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
   std::vector<ir::OpResult> src_values;
   std::vector<ir::Type> types_in_vec;
-  for (auto arg_name : args) {
+  for (const auto& arg_name : args) {
     auto defining_info = param_map->at(arg_name);
     src_values.push_back(defining_info.value);
     types_in_vec.push_back(defining_info.value.type());
@@ -141,13 +158,25 @@ inline ir::Operation* InsertCombineOperationForTarget(
   return operation;
 }

+inline ir::Operation* InsertConstantOperationForOptionalArg(
+    ir::IrContext* ctx, ir::Program* program) {
+  std::string constant_op_name(ir::ConstantOp::name());
+  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
+
+  ir::Type null_type = ir::Type(nullptr);
+  ir::Operation* operation =
+      ir::Operation::create({}, {}, {null_type}, op_info);
+  program->block()->push_back(operation);
+  return operation;
+}
+
 inline std::vector<ir::OpResult> GenerateOperationInput(
     ir::IrContext* ctx,
     TranslationContext* param_map,
     ir::Program* program,
-    const OpDesc& op_desc) {
-  std::vector<ir::OpResult> op_inputs = {};
+    const OpDesc& op_desc,
+    const std::string& normalized_op_name,
+    const OpInputInfoList& input_infos) {
   // scan all inputs to see if any of them is generated as a vector<Tensor>
   // so need an additional `SliceOp` to take it out.
   for (const auto& n : op_desc.Inputs()) {
@@ -159,7 +188,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
         param_map->count(arg_name),
         0,
         platform::errors::PreconditionNotMet(
-            "arg %s.%s as input should be exists before prasing %d",
+            "arg %s.%s as input should exist before parsing %s",
             name,
             arg_name,
             op_desc.Type()));
@@ -171,73 +200,116 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
     }
   }

-  for (const auto& n : op_desc.Inputs()) {
-    auto& name = n.first;
-    VLOG(10) << "[input retriving]"
-             << "[" << op_desc.Type() << "]" << name;
-    auto& args = n.second;
-
-    // if src type is Tensor or a Vector<Tensor> with size <= 1
-    if (args.size() <= 1) {
-      for (const auto& arg_name : args) {
-        auto defining_info = (*param_map)[arg_name];
-        op_inputs.push_back(defining_info.value);
-      }
+  std::vector<ir::OpResult> op_inputs;
+  auto& op_normalizer = OpNameNormalizer::instance();
+
+  for (const auto& info : input_infos) {
+    std::string legacy_input_name =
+        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
+
+    // return empty OpResult if this arg is optional and not shown in OpDesc
+    // TODO(lyk): HasInput does not consider variadic attribute
+    if (!op_desc.HasInput(legacy_input_name)) {
+      PADDLE_ENFORCE(info.optional,
+                     platform::errors::PreconditionNotMet(
+                         "Op %s arg %s should be optional if it can be empty",
+                         op_desc.Type(),
+                         legacy_input_name));
+      op_inputs.push_back(ir::OpResult(nullptr));
+      continue;
+    }
+
+    const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
+    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
+
+    // if src type is Tensor
+    if (!is_vector) {
+      auto defining_info = (*param_map)[legacy_input_vars[0]];
+      op_inputs.push_back(defining_info.value);

     // if src type is Vector<Tensor>, need an additional `CombineOp` to
     // assemble them.
     } else {
-      auto* combine_op =
-          InsertCombineOperationForTarget(ctx, param_map, program, args);
+      auto* combine_op = InsertCombineOperationForTarget(
+          ctx, param_map, program, legacy_input_vars);
       op_inputs.push_back(combine_op->GetResultByIndex(0));
     }
   }
   return op_inputs;
 }

 inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
-    ir::IrContext* ctx, const OpDesc& op_desc) {
+    ir::IrContext* ctx,
+    const OpDesc& op_desc,
+    const OpOutputInfoList& output_infos) {
   OpOutputMapping arg_to_idx;
   OpOutputTypeList op_output_types = {};
   auto& type_translator = TypeTranslator::instance();
+  auto& op_normalizer = OpNameNormalizer::instance();

   const BlockDesc* block = op_desc.Block();

-  for (const auto& n : op_desc.Outputs()) {
-    auto& name = n.first;
-    VLOG(10) << "[output translating]"
-             << "[" << op_desc.Type() << "]" << name;
-    auto& args = n.second;
+  for (const auto& info : output_infos) {
     size_t cur_output_idx = op_output_types.size();
+    std::string legacy_output_name =
+        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
+
+    // return empty type if this arg is optional and not shown in OpDesc
+    // TODO(lyk): HasOutput does not consider variadic attribute
+    if (!op_desc.HasOutput(legacy_output_name)) {
+      VLOG(10) << "[output translating]"
+               << "[" << op_desc.Type() << "] optional " << info.name << " :"
+               << info.type_name << " " << legacy_output_name;
+      PADDLE_ENFORCE(info.optional,
+                     platform::errors::PreconditionNotMet(
+                         "Op %s arg %s should be optional if it can be empty",
+                         op_desc.Type(),
+                         legacy_output_name));
+      op_output_types.push_back(ir::Type(nullptr));
+      continue;
+    }

-    // if src type is Tensor or a Vector<Tensor> with size <= 1
-    if (args.size() <= 1) {
-      for (const auto& arg_name : args) {
-        VarDesc* var = block->FindVarRecursive(arg_name);
-        VLOG(10) << "[output translating]"
-                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
-                 << " " << var->GetType();
-
-        ir::Type translated_var_type =
-            type_translator[var->GetType()](ctx, *var);
-
-        arg_to_idx[arg_name] = cur_output_idx;
-        op_output_types.push_back(translated_var_type);
-      }
+    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
+    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
+
+    // if src type is Tensor
+    if (!is_vector) {
+      VLOG(10) << "[output translating]"
+               << "[" << op_desc.Type() << "]" << info.name << " :"
+               << info.type_name << " " << legacy_output_name;
+      if (legacy_output_vars.size() == 0) {
+        op_output_types.push_back(ir::Type(nullptr));
+        continue;
+      }
+
+      auto& var_name = legacy_output_vars[0];
+      VarDesc* var = block->FindVarRecursive(var_name);
+      VLOG(10) << "[output translating]"
+               << "[" << op_desc.Type() << "]" << info.name << " " << var_name
+               << " " << var->GetType();
+
+      ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
+
+      arg_to_idx[var_name] = cur_output_idx;
+      op_output_types.push_back(translated_var_type);

     // if src type is Vector<Tensor>
     } else {
+      VLOG(10) << "[output translating]"
+               << "[" << op_desc.Type() << "]" << info.name << " :"
+               << info.type_name << " " << legacy_output_name;
+
       std::vector<ir::Type> types;
-      for (const auto& arg_name : args) {
-        VarDesc* var = block->FindVarRecursive(arg_name);
-        VLOG(10) << "[output translating]"
-                 << "[" << op_desc.Type() << "]" << name << " " << arg_name
-                 << " " << var->GetType();
-        ir::Type translated_var_type =
-            type_translator[var->GetType()](ctx, *var);
-        types.push_back(translated_var_type);
-        arg_to_idx[arg_name] = cur_output_idx;
-      }
+      for (const auto& var_name : legacy_output_vars) {
+        VarDesc* var = block->FindVarRecursive(var_name);
+        VLOG(10) << "[output translating]"
                 << "[" << op_desc.Type() << "]" << info.name << " " << var_name
                 << " " << var->GetType();
+        ir::Type translated_var_type =
+            type_translator[var->GetType()](ctx, *var);
+        types.push_back(translated_var_type);
+        arg_to_idx[var_name] = cur_output_idx;
+      }
       ir::Type vec_type = ir::VectorType::get(ctx, types);
       op_output_types.push_back(vec_type);
@@ -246,6 +318,38 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
   return {op_output_types, arg_to_idx};
 }

+inline ir::AttributeMap TranslateOpAttribute(
+    std::string normalized_op_name,
+    const OpAttributeInfoList& op_attr_infos,
+    const OpDesc& op_desc) {
+  auto& attribute_translator = AttributeTranslator::instance();
+  auto& op_normalizer = OpNameNormalizer::instance();
+  ir::AttributeMap attribute_map = {};
+
+  for (const auto& info : op_attr_infos) {
+    auto legacy_attr_name =
+        op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
+
+    paddle::framework::Attribute legacy_attr;
+    if (op_desc.HasAttr(legacy_attr_name)) {
+      legacy_attr = op_desc.GetAttr(legacy_attr_name);
+    }
+    VLOG(10) << "attribute in " << op_desc.Type()
+             << " name: " << legacy_attr_name << " " << legacy_attr.index();
+    ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
+    attribute_map[info.name] = new_attr;
+    if (!new_attr) {
+      VLOG(0) << "empty attribute in " << op_desc.Type()
+              << " name: " << info.name;
+    } else {
+      VLOG(10) << "new attribute in " << op_desc.Type()
+               << " name: " << info.name << " " << new_attr.storage();
+    }
+  }
+
+  return attribute_map;
+}
+
 inline void RecordOpResultMapping(TranslationContext* param_map,
                                   const OpDesc& op_desc,
                                   ir::Operation* operation,
@@ -274,15 +378,34 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
                                 TranslationContext* param_map,
                                 ir::Program* program,
                                 const OpDesc& op_desc) {
-  auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
+  auto op_info = LoopkUpOpInfo(ctx, op_desc);
+  auto* op_info_concept =
+      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
+
+  OpInputInfoList input_infos;
+  OpAttributeInfoList attr_infos;
+  OpOutputInfoList output_infos;
+  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+      op_info_concept->get_op_info_();
+
+  auto op_inputs = GenerateOperationInput(
+      ctx, param_map, program, op_desc, op_info.name(), input_infos);

   OpOutputMapping arg_to_idx;
-  OpOutputTypeList op_output_types = {};
-  std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
-
-  auto op_info = LoopkUpOpInfo(ctx, op_desc);
+  OpOutputTypeList op_output_types;
+  std::tie(op_output_types, arg_to_idx) =
+      GenerateOperationOutput(ctx, op_desc, output_infos);
+
+  auto attribute_map =
+      TranslateOpAttribute(op_info.name(), attr_infos, op_desc);
+  VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";

   ir::Operation* operation =
-      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
+      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
+  VLOG(4) << "[general op][" << op_desc.Type() << "] operation creation end.";
   program->block()->push_back(operation);
+  VLOG(4) << "[general op][" << op_desc.Type() << "] operation insertion end.";
   RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);

   return operation;
@@ -292,14 +415,28 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
                              TranslationContext* param_map,
                              ir::Program* program,
                              const OpDesc& op_desc) {
-  std::vector<ir::OpResult> op_inputs = {};
+  auto op_info = LoopkUpOpInfo(ctx, op_desc);
+  auto* op_info_concept =
+      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
+
+  OpInputInfoList input_infos;
+  OpAttributeInfoList attr_infos;
+  OpOutputInfoList output_infos;
+  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+      op_info_concept->get_op_info_();
+
+  std::vector<ir::OpResult> op_inputs;

   OpOutputMapping arg_to_idx;
-  OpOutputTypeList op_output_types = {};
-  std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
-
-  auto op_info = LoopkUpOpInfo(ctx, op_desc);
+  OpOutputTypeList op_output_types;
+  std::tie(op_output_types, arg_to_idx) =
+      GenerateOperationOutput(ctx, op_desc, output_infos);
+  ir::AttributeMap attribute_map = {
+      {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])},
+  };

   ir::Operation* operation =
-      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
+      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
   program->block()->push_back(operation);
   RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
@@ -310,12 +447,26 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
                               TranslationContext* param_map,
                               ir::Program* program,
                               const OpDesc& op_desc) {
-  auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
-
-  OpOutputTypeList op_output_types = {};
   auto op_info = LoopkUpOpInfo(ctx, op_desc);
+  auto* op_info_concept =
+      op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
+
+  OpInputInfoList input_infos;
+  OpAttributeInfoList attr_infos;
+  OpOutputInfoList output_infos;
+  std::tie(input_infos, attr_infos, output_infos, std::ignore) =
+      op_info_concept->get_op_info_();
+
+  auto op_inputs = GenerateOperationInput(
+      ctx, param_map, program, op_desc, op_info.name(), input_infos);
+
+  OpOutputTypeList op_output_types;
+  ir::AttributeMap attribute_map = {
+      {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
+  };

   ir::Operation* operation =
-      ir::Operation::create(op_inputs, {}, op_output_types, op_info);
+      ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
   program->block()->push_back(operation);

   return operation;
...
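All three handlers share one skeleton: look up the dialect OpInfo, unpack the input/attribute/output metadata from GetOpInfoInterface, then build inputs, attributes, and output types before creating the operation. Below is a sketch of how OpDescs might be routed to these handlers; the dispatch table and driver function are assumptions for illustration, since the real registration lives outside this excerpt:

#include <functional>
#include <string>
#include <unordered_map>

// Sketch: routing legacy OpDescs to the translation handlers above.
void TranslateBlock(ir::IrContext* ctx,
                    TranslationContext* param_map,
                    ir::Program* program,
                    const paddle::framework::BlockDesc& block_desc) {
  using OpHandlerFn = std::function<ir::Operation*(
      ir::IrContext*, TranslationContext*, ir::Program*, const OpDesc&)>;

  // Hypothetical handler table; feed/fetch ops need special attributes.
  const std::unordered_map<std::string, OpHandlerFn> special_handlers = {
      {"feed", FeedOpHandler},
      {"fetch_v2", FetchOpHandler},
  };

  for (auto* op_desc : block_desc.AllOps()) {
    auto it = special_handlers.find(op_desc->Type());
    OpHandlerFn handler = (it == special_handlers.end())
                              ? OpHandlerFn(GeneralOpHandler)
                              : it->second;
    handler(ctx, param_map, program, *op_desc);
  }
}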
@@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
       std::string get_parameter_op_name(ir::GetParameterOp::name());
       ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
       std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-          {var->Name(), ir::StrAttribute::get(ctx, var->Name())},
+          {"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
       };
       ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
       ir::Operation* operation = ir::Operation::create(
...
@@ -39,9 +39,9 @@ struct VariableDefiningInfo {
   ir::OpResult value;

   bool generated_by_vector =
-      false;  // true if target variabe is generated by Vector<Tensor>
+      false;  // true if target variable is generated by Vector<Tensor>
   int idx_in_vector =
-      -1;  // positive if target variabe is generated by Vector<Tensor>
+      -1;  // positive if target variable is generated by Vector<Tensor>
 };

 using TranslationContext =
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cctype>
#include <string>
#include <string_view>

namespace paddle {
namespace translator {

static std::string UnderscoreToCamelCase(std::string str) {
  std::string camel_case;
  bool next_upper = true;
  for (char c : str) {
    if (c == '_') {
      next_upper = true;
    } else {
      if (next_upper) {
        camel_case += static_cast<char>(std::toupper(c));
        next_upper = false;
      } else {
        camel_case += c;
      }
    }
  }
  return camel_case;
}

}  // namespace translator
}  // namespace paddle
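A few illustrative conversions (not from this patch's tests) showing the fallback naming rule that GetLegacyArgName relies on:

#include <cassert>
#include "paddle/fluid/translator/utils.h"

int main() {
  using paddle::translator::UnderscoreToCamelCase;
  assert(UnderscoreToCamelCase("out") == "Out");          // plain name
  assert(UnderscoreToCamelCase("x_grad") == "XGrad");     // underscores dropped
  assert(UnderscoreToCamelCase("mean_out") == "MeanOut");
  return 0;
}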
@@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) {
 }

 TEST(PaddleDialectTest, Translator) {
-  LOG(WARNING) << "TODO";
-  // auto p = load_from_file("restnet50_main.prog");
-  // EXPECT_EQ(p.Size(), 1u);
-
-  // ir::IrContext *ctx = ir::IrContext::Instance();
-  // ctx->GetOrRegisterDialect<PaddleDialect>();
-  // ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
-  // auto program = paddle::TranslateLegacyProgramToProgram(p);
-
-  // size_t op_size = program->block()->size();
-  // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
-  // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
-
-  // VLOG(0) << *program;
+  auto p = load_from_file("restnet50_main.prog");
+  EXPECT_EQ(p.Size(), 1u);
+
+  ir::IrContext *ctx = ir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<PaddleDialect>();
+  ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
+  auto program = paddle::TranslateLegacyProgramToProgram(p);
+
+  size_t op_size = program->block()->size();
+  // ops.size() = op size in BlockDesc + get_parameter_op + combine op
+  EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
+
+  std::cout << *program << std::endl;
 }