未验证 提交 37930a69 编写于 作者: K kangguangli 提交者: GitHub

[IR] Support op attribute and refactor for new op definition (#54068)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

---------
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 4f848aa9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/translator/attribute_translator.h"
#include <string>
#include <vector>
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace translator {
class AttributeVisitor {
public:
ir::IrContext* ctx;
AttributeVisitor() { ctx = ir::IrContext::Instance(); }
~AttributeVisitor() {}
public:
virtual ir::Attribute operator()(int i) {
VLOG(10) << "translating int";
return ir::Int32_tAttribute::get(ctx, i);
}
virtual ir::Attribute operator()(float f) {
VLOG(10) << "translating float";
return ir::FloatAttribute::get(ctx, f);
}
virtual ir::Attribute operator()(bool b) {
VLOG(10) << "translating bool";
return ir::BoolAttribute::get(ctx, b);
}
virtual ir::Attribute operator()(double d) {
VLOG(10) << "translating double";
return ir::DoubleAttribute::get(ctx, d);
}
virtual ir::Attribute operator()(std::string str) {
VLOG(10) << "translating string";
return ir::StrAttribute::get(ctx, str);
}
virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
VLOG(10) << "translating scalar";
return paddle::dialect::ScalarAttribute::get(ctx, scalar);
}
virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
VLOG(10) << "translating vector<string>";
std::vector<ir::Attribute> attrs;
attrs.reserve(strs.size());
for (const auto& v : strs) {
attrs.push_back(ir::StrAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<float>& fs) {
VLOG(10) << "translating vector<float>";
std::vector<ir::Attribute> attrs;
attrs.reserve(fs.size());
for (const auto& v : fs) {
attrs.push_back(ir::FloatAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<int>& is) {
VLOG(10) << "translating vector<int>";
std::vector<ir::Attribute> attrs;
attrs.reserve(is.size());
for (const auto& v : is) {
attrs.push_back(ir::Int32_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<bool>& bs) {
VLOG(10) << "translating vector<bool>";
std::vector<ir::Attribute> attrs;
attrs.reserve(bs.size());
for (const auto& v : bs) {
attrs.push_back(ir::BoolAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
attrs.push_back(ir::Int64_tAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const std::vector<double>& ds) {
VLOG(10) << "translating vector<double>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ds.size());
for (const auto& v : ds) {
attrs.push_back(ir::DoubleAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
virtual ir::Attribute operator()(const paddle::blank& blank) {
VLOG(10) << "translating paddle::blank";
return ir::Attribute(nullptr);
}
template <typename T>
ir::Attribute operator()(T attr) {
VLOG(10) << "translating null type";
return ir::Attribute(nullptr);
}
};
class IntArrayAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(const std::vector<int>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}
ir::Attribute operator()(const std::vector<int64_t>& is) override {
VLOG(10) << "translating vector<int> to IntArray";
phi::IntArray data(is);
return paddle::dialect::IntArrayAttribute::get(ctx, data);
}
};
class ScalarAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to Scalar";
phi::Scalar data(i);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
ir::Attribute operator()(float f) override {
VLOG(10) << "translating float to Scalar";
phi::Scalar data(f);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
};
class DataTypeAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to DataType: " << i;
phi::DataType data = static_cast<phi::DataType>(i);
return paddle::dialect::DataTypeAttribute::get(ctx, data);
}
};
class PlaceAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank";
phi::Place data(phi::AllocationType::CPU);
return paddle::dialect::PlaceAttribute::get(ctx, data);
}
};
AttributeTranslator::AttributeTranslator() {
general_visitor = new AttributeVisitor();
special_visitors["paddle::dialect::IntArrayAttribute"] =
new IntArrayAttributeVisitor();
special_visitors["paddle::dialect::ScalarAttribute"] =
new ScalarAttributeVisitor();
special_visitors["paddle::dialect::DataTypeAttribute"] =
new DataTypeAttributeVisitor();
special_visitors["paddle::dialect::PlaceAttribute"] =
new PlaceAttributeVisitor();
}
ir::Attribute AttributeTranslator::operator()(
const framework::Attribute& attr) {
return paddle::visit(*general_visitor, attr);
}
ir::Attribute AttributeTranslator::operator()(
const std::string& target_type, const framework::Attribute& attr) {
if (special_visitors.find(target_type) == special_visitors.end()) {
VLOG(10) << "[" << target_type << "] not found";
return paddle::visit(*general_visitor, attr);
}
return paddle::visit(*(special_visitors.at(target_type)), attr);
}
} // namespace translator
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_context.h"
#pragma once
namespace paddle {
namespace translator {
class AttributeVisitor;
class AttributeTranslator {
private:
AttributeTranslator();
AttributeVisitor* general_visitor;
std::unordered_map<std::string, AttributeVisitor*> special_visitors;
public:
AttributeTranslator(const AttributeTranslator&) = delete;
AttributeTranslator& operator=(const AttributeTranslator&) = delete;
AttributeTranslator(AttributeTranslator&&) = delete;
AttributeTranslator& operator=(AttributeTranslator&&) = delete;
static auto& instance() {
static AttributeTranslator attribute_translator;
return attribute_translator;
}
ir::Attribute operator()(const framework::Attribute& attr);
ir::Attribute operator()(const std::string& target_type,
const framework::Attribute& attr);
};
} // namespace translator
} // namespace paddle
......@@ -14,6 +14,7 @@
import argparse
from pathlib import Path
from typing import Dict
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
......@@ -33,7 +34,7 @@ def OpNameNormalizerInitialization(
op_compat_yaml_file: str = "", output_source_file: str = ""
) -> None:
def to_phi_and_fluid_op_name(op_item):
# Templat: - op : phi_name (fluid_name)
# Template: - op : phi_name (fluid_name)
names = op_item.split('(')
if len(names) == 1:
phi_fluid_name = names[0].strip()
......@@ -46,21 +47,55 @@ def OpNameNormalizerInitialization(
with open(op_compat_yaml_file, "r") as f:
op_compat_infos = yaml.safe_load(f)
op_name_mappings = {}
op_arg_name_mappings = {}
for op_compat_item in op_compat_infos:
def insert_new_mappings(op_name_str):
def insert_new_mappings(op_name_str: str) -> str:
normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
if normalized_name == legacy_name:
return
return normalized_name, legacy_name
op_name_mappings[legacy_name] = normalized_name
return normalized_name, legacy_name
def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
if op_name is None:
return
if op_name not in op_arg_name_mappings:
op_arg_name_mappings[op_name] = {}
op_arg_name_mappings[op_name].update(arg_mapping)
insert_new_mappings(op_compat_item["op"])
_, legacy_name = insert_new_mappings(op_compat_item["op"])
legacy_backward_op_names = []
if "backward" in op_compat_item:
insert_new_mappings(op_compat_item["backward"])
backward_op_name_mapping_paris = op_compat_item["backward"].split(
","
)
for pair in backward_op_name_mapping_paris:
_, legacy_backward_op_name = insert_new_mappings(pair)
legacy_backward_op_names.append(legacy_backward_op_name)
if "inputs" in op_compat_item:
insert_new_arg_mappings(legacy_name, op_compat_item["inputs"])
for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["inputs"])
if "attrs" in op_compat_item:
insert_new_arg_mappings(legacy_name, op_compat_item["attrs"])
for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["attrs"])
if "outputs" in op_compat_item:
insert_new_arg_mappings(legacy_name, op_compat_item["outputs"])
for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
# special op mappings
op_name_mappings["fetch_v2"] = "fetch"
op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
op_compat_definition = op_name_normailzer_template.render(
op_name_paris=op_name_mappings
op_name_pairs=op_name_mappings,
op_arg_name_pairs=op_arg_name_mappings,
)
f.write(op_compat_definition)
......
......@@ -5,10 +5,22 @@ namespace translator {
OpNameNormalizer::OpNameNormalizer() {
op_name_mappings = {
{% for legacy_name, normalized_name in op_name_paris.items() %}
{% for legacy_name, normalized_name in op_name_pairs.items() %}
{ "{{legacy_name}}", "{{normalized_name}}" },
{% endfor %}
};
op_arg_name_mappings = {
{% for op_name, arg_name_mappings in op_arg_name_pairs.items() %}
{
"{{op_name}}",
{
{% for normalized_name, legacy_name in arg_name_mappings.items() %}
{ "{{normalized_name}}", "{{legacy_name}}" },
{% endfor %}
},
},
{% endfor %}
};
}
} // namespace translator
......
......@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/translator/utils.h"
#pragma once
namespace paddle {
......@@ -26,6 +29,8 @@ class OpNameNormalizer {
private:
OpNameNormalizer(); // Disallow instantiation outside of the class.
std::unordered_map<std::string, std::string> op_name_mappings;
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
op_arg_name_mappings;
public:
OpNameNormalizer(const OpNameNormalizer&) = delete;
......@@ -44,6 +49,49 @@ class OpNameNormalizer {
}
return op_name_mappings.at(op_type);
}
std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos);
bool is_grad_arg = (arg_name.find("grad") != std::string::npos);
if (is_grad_op && is_grad_arg) {
std::string target = "_grad";
std::string data = "@GRAD";
size_t first_grad_pos = arg_name.find_first_of(target);
std::string legacy_name =
this->GetLegacyArgName(op_type, arg_name.substr(0, first_grad_pos));
legacy_name += arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
}
return legacy_name;
}
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return UnderscoreToCamelCase(arg_name);
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return UnderscoreToCamelCase(arg_name);
}
return arg_mappings.at(arg_name);
}
std::string GetLegacyAttrName(const std::string& op_type,
const std::string& arg_name) {
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
VLOG(10) << "[" << op_type << "] not found";
return arg_name;
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
VLOG(10) << "[" << op_type << "][" << arg_name << "] not found";
return arg_name;
}
return arg_mappings.at(arg_name);
}
};
} // namespace translator
......
......@@ -15,19 +15,23 @@
#include "paddle/fluid/translator/op_translator.h"
#include <algorithm>
#include <cctype>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/translator/attribute_translator.h"
#include "paddle/fluid/translator/op_compat_info.h"
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
......@@ -42,11 +46,24 @@ using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;
using OpOutputTypeList = std::vector<ir::Type>;
using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
using OpInputInfo = paddle::dialect::OpInputInfo;
using OpInputInfoList = std::vector<paddle::dialect::OpInputInfo>;
using OpAttributeInfo = paddle::dialect::OpAttributeInfo;
using OpAttributeInfoList = std::vector<paddle::dialect::OpAttributeInfo>;
using OpOutputInfo = paddle::dialect::OpOutputInfo;
using OpOutputInfoList = std::vector<paddle::dialect::OpOutputInfo>;
static const char kTargetDialectPrefix[] = "pd.";
static const std::unordered_set<std::string> special_inplace_ops = {
"batch_norm",
};
inline bool IsInplace(const OpDesc& op_desc) {
bool inplace = false;
if (special_inplace_ops.count(op_desc.Type())) {
return inplace;
}
auto input_names = op_desc.InputArgumentNames();
auto output_names = op_desc.OutputArgumentNames();
......@@ -129,7 +146,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
std::vector<ir::OpResult> src_values;
std::vector<ir::Type> types_in_vec;
for (auto arg_name : args) {
for (const auto& arg_name : args) {
auto defining_info = param_map->at(arg_name);
src_values.push_back(defining_info.value);
types_in_vec.push_back(defining_info.value.type());
......@@ -141,13 +158,25 @@ inline ir::Operation* InsertCombineOperationForTarget(
return operation;
}
inline ir::Operation* InsertConstantOperationForOptionalArg(
ir::IrContext* ctx, ir::Program* program) {
std::string constant_op_name(ir::ConstantOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
ir::Type null_type = ir::Type(nullptr);
ir::Operation* operation =
ir::Operation::create({}, {}, {null_type}, op_info);
program->block()->push_back(operation);
return operation;
}
inline std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const OpDesc& op_desc) {
std::vector<ir::OpResult> op_inputs = {};
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos) {
// scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out.
for (const auto& n : op_desc.Inputs()) {
......@@ -159,7 +188,7 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
param_map->count(arg_name),
0,
platform::errors::PreconditionNotMet(
"arg %s.%s as input should be exists before prasing %d",
"arg %s.%s as input should be exists before prasing %s",
name,
arg_name,
op_desc.Type()));
......@@ -171,73 +200,116 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
}
}
for (const auto& n : op_desc.Inputs()) {
auto& name = n.first;
VLOG(10) << "[input retriving]"
<< "[" << op_desc.Type() << "]" << name;
auto& args = n.second;
std::vector<ir::OpResult> op_inputs;
auto& op_normalizer = OpNameNormalizer::instance();
// if src type is Tensor or a Vector<Tensor> with size <= 1
if (args.size() <= 1) {
for (const auto& arg_name : args) {
auto defining_info = (*param_map)[arg_name];
op_inputs.push_back(defining_info.value);
}
for (const auto& info : input_infos) {
std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if (!op_desc.HasInput(legacy_input_name)) {
PADDLE_ENFORCE(info.optional,
platform::errors::PreconditionNotMet(
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(),
legacy_input_name));
op_inputs.push_back(ir::OpResult(nullptr));
continue;
}
const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
// if src type is Tensor
if (!is_vector) {
auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value);
// if src type is Vector<Tesnor> , need an additional `CombineOp` to
// assemble them.
} else {
auto* combine_op =
InsertCombineOperationForTarget(ctx, param_map, program, args);
auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, program, legacy_input_vars);
op_inputs.push_back(combine_op->GetResultByIndex(0));
}
}
return op_inputs;
}
inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
ir::IrContext* ctx, const OpDesc& op_desc) {
ir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos) {
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};
auto& type_translator = TypeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
const BlockDesc* block = op_desc.Block();
for (const auto& n : op_desc.Outputs()) {
auto& name = n.first;
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name;
auto& args = n.second;
for (const auto& info : output_infos) {
size_t cur_output_idx = op_output_types.size();
std::string legacy_output_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
// return empty type if this arg is optional and not shown in OpDesc
// TODO(lyk): HasOutput doesnot consider variadic attribute
if (!op_desc.HasOutput(legacy_output_name)) {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "] optional " << info.name << " :"
<< info.type_name << " " << legacy_output_name;
PADDLE_ENFORCE(info.optional,
platform::errors::PreconditionNotMet(
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(),
legacy_output_name));
op_output_types.push_back(ir::Type(nullptr));
continue;
}
// if src type is Tensor or a Vector<Tensor> with size <= 1
if (args.size() <= 1) {
for (const auto& arg_name : args) {
VarDesc* var = block->FindVarRecursive(arg_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name << " " << arg_name
<< " " << var->GetType();
const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
// if src type is Tensor
if (!is_vector) {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " " << legacy_output_name;
if (legacy_output_vars.size() == 0) {
op_output_types.push_back(ir::Type(nullptr));
continue;
}
ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
auto& var_name = legacy_output_vars[0];
VarDesc* var = block->FindVarRecursive(var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " " << var_name
<< " " << var->GetType();
arg_to_idx[arg_name] = cur_output_idx;
op_output_types.push_back(translated_var_type);
}
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
arg_to_idx[var_name] = cur_output_idx;
op_output_types.push_back(translated_var_type);
// if src type is Vector<Tesnor>
} else {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " " << legacy_output_name;
std::vector<ir::Type> types;
for (const auto& arg_name : args) {
VarDesc* var = block->FindVarRecursive(arg_name);
for (const auto& var_name : legacy_output_vars) {
VarDesc* var = block->FindVarRecursive(var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name << " " << arg_name
<< "[" << op_desc.Type() << "]" << info.name << " " << var_name
<< " " << var->GetType();
ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
types.push_back(translated_var_type);
arg_to_idx[arg_name] = cur_output_idx;
arg_to_idx[var_name] = cur_output_idx;
}
ir::Type vec_type = ir::VectorType::get(ctx, types);
op_output_types.push_back(vec_type);
......@@ -246,6 +318,38 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
return {op_output_types, arg_to_idx};
}
inline ir::AttributeMap TranslateOpAttribute(
std::string normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
ir::AttributeMap attribute_map = {};
for (const auto& info : op_attr_infos) {
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
paddle::framework::Attribute legacy_attr;
if (op_desc.HasAttr(legacy_attr_name)) {
legacy_attr = op_desc.GetAttr(legacy_attr_name);
}
VLOG(10) << "attribute in " << op_desc.Type()
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
attribute_map[info.name] = new_attr;
if (!new_attr) {
VLOG(0) << "empty attribute in " << op_desc.Type()
<< " name: " << info.name;
} else {
VLOG(10) << "new attribute in " << op_desc.Type()
<< " name: " << info.name << " " << new_attr.storage();
}
}
return attribute_map;
}
inline void RecordOpResultMapping(TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
......@@ -274,15 +378,34 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const OpDesc& op_desc) {
auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
OpInputInfoList input_infos;
OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) =
op_info_concept->get_op_info_();
auto op_inputs = GenerateOperationInput(
ctx, param_map, program, op_desc, op_info.name(), input_infos);
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc);
OpOutputTypeList op_output_types;
std::tie(op_output_types, arg_to_idx) =
GenerateOperationOutput(ctx, op_desc, output_infos);
auto attribute_map =
TranslateOpAttribute(op_info.name(), attr_infos, op_desc);
VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end.";
program->block()->push_back(operation);
VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end.";
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
return operation;
......@@ -292,14 +415,28 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const OpDesc& op_desc) {
std::vector<ir::OpResult> op_inputs = {};
auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
OpInputInfoList input_infos;
OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) =
op_info_concept->get_op_info_();
std::vector<ir::OpResult> op_inputs;
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc);
OpOutputTypeList op_output_types;
std::tie(op_output_types, arg_to_idx) =
GenerateOperationOutput(ctx, op_desc, output_infos);
ir::AttributeMap attribute_map = {
{"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])},
};
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
......@@ -310,12 +447,26 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const OpDesc& op_desc) {
auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
OpOutputTypeList op_output_types = {};
auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>();
OpInputInfoList input_infos;
OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos;
std::tie(input_infos, attr_infos, output_infos, std::ignore) =
op_info_concept->get_op_info_();
auto op_inputs = GenerateOperationInput(
ctx, param_map, program, op_desc, op_info.name(), input_infos);
OpOutputTypeList op_output_types;
ir::AttributeMap attribute_map = {
{"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
};
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
ir::Operation::create(op_inputs, attribute_map, op_output_types, op_info);
program->block()->push_back(operation);
return operation;
......
......@@ -76,7 +76,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{var->Name(), ir::StrAttribute::get(ctx, var->Name())},
{"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create(
......
......@@ -39,9 +39,9 @@ struct VariableDefiningInfo {
ir::OpResult value;
bool generated_by_vector =
false; // true if target variabe is generated by Vector<Tensor>
false; // true if target variable is generated by Vector<Tensor>
int idx_in_vector =
-1; // positive if target variabe is generated by Vector<Tensor>
-1; // positive if target variable is generated by Vector<Tensor>
};
using TranslationContext =
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <string_view>
namespace paddle {
namespace translator {
static std::string UnderscoreToCamelCase(std::string str) {
std::string camel_case;
bool next_upper = true;
for (char c : str) {
if (c == '_') {
next_upper = true;
} else {
if (next_upper) {
camel_case += toupper(c);
next_upper = false;
} else {
camel_case += c;
}
}
}
return camel_case;
}
} // namespace translator
} // namespace paddle
......@@ -47,17 +47,17 @@ ProgramDesc load_from_file(const std::string &file_name) {
}
TEST(PaddleDialectTest, Translator) {
LOG(WARNING) << "TODO";
// auto p = load_from_file("restnet50_main.prog");
// EXPECT_EQ(p.Size(), 1u);
// ir::IrContext *ctx = ir::IrContext::Instance();
// ctx->GetOrRegisterDialect<PaddleDialect>();
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto program = paddle::TranslateLegacyProgramToProgram(p);
// size_t op_size = program->block()->size();
// // ops.size() = op size in BlockDesc + get_parameter_op + combine op
// EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
// VLOG(0) << *program;
auto p = load_from_file("restnet50_main.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
std::cout << *program << std::endl;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册