未验证 提交 f18d538b 编写于 作者: H hong 提交者: GitHub

Refactor op info parser (#54859)

* add kernel dialect

* change DenseTensorTypeStorage to DenseTensorType

* add test case`

* add first pd_op to kernel dialect

* lower pd op to kernel dialect

* update

* update

* remove useless code

* add attrite print test

* fix bug

* update

* update

* update

* update

* polish code

* fix bug

* polish  code  and add python test

* add test

* fix test error

* add env flag

* fix bug

* revert test env

* change cc_test_old to cc_test

* fix build_static bug

* fix type test error

* udpate cmake

* disable test in windows

* update

* update

* fix bug

* split file

* fix conflict

* polish code and fix conflict

* support place transformer

* finish bug

* add gpu flags

* fix with cuda macro

* add fetch kernel

* support fetch var in new ir

* fix bug

* polish code

* change array equal to np.testing

* support feed in new ir

* update

* fix bug

* try to hack combine op

* add scope guard

* revert atan2 op

* add scope guard

* update

* polish code

* update

* refactor build kernel context

* fix unitest bug

* polish code

* use original order

* remove useless code

* polish code

* fix bug
上级 b94b3ac0
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h" #include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/fluid/memory/stats.h" #include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
...@@ -951,31 +952,27 @@ void BuildOpFuncList( ...@@ -951,31 +952,27 @@ void BuildOpFuncList(
auto attr_map = (*it)->attributes(); auto attr_map = (*it)->attributes();
auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data(); auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed") { if (op_name == "builtin.combine" || op_name == "pd.feed") {
VLOG(6) << "skip process " << op_name; VLOG(6) << "skip process " << op_name;
continue; continue;
} }
op_func_node.phi_op_name_ = op_name;
::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
auto impl = auto impl =
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
auto yaml_info = impl->get_op_info_();
auto attr_info = std::get<1>(yaml_info);
op_func_node.infer_meta_interface_ = op_func_node.infer_meta_interface_ =
op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>(); op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();
VLOG(6) << "op name" << op_func_node.phi_op_name_; VLOG(6) << "op name" << op_func_node.phi_op_name_;
dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());
::ir::BuildInferMetaContext((*it), ::ir::BuildInferMetaContext((*it),
value_2_name_map, value_2_name_map,
scope, scope,
yaml_info, op_yaml_info_parser,
&(op_func_node.infer_meta_context_)); &(op_func_node.infer_meta_context_));
auto kernel_name = auto kernel_name =
...@@ -996,7 +993,7 @@ void BuildOpFuncList( ...@@ -996,7 +993,7 @@ void BuildOpFuncList(
::ir::BuildPhiKernelContext((*it), ::ir::BuildPhiKernelContext((*it),
value_2_name_map, value_2_name_map,
scope, scope,
yaml_info, op_yaml_info_parser,
&(op_func_node.kernel_context_), &(op_func_node.kernel_context_),
&(op_func_node.input_index), &(op_func_node.input_index),
&(op_func_node.output_index)); &(op_func_node.output_index));
......
...@@ -116,7 +116,9 @@ void NewIRInterpreter::RunImpl() { ...@@ -116,7 +116,9 @@ void NewIRInterpreter::RunImpl() {
// && // &&
// (sync_op_num_ == 0)) { // (sync_op_num_ == 0)) {
VLOG(4) << "Tracing Instruction List"; VLOG(4) << "Tracing Instruction List";
TraceInstructionList(vec_instruction_); TraceInstructionList(vec_instruction_);
// } else { // } else {
// VLOG(4) << "Non-tracing"; // VLOG(4) << "Non-tracing";
// // For the program that only run once, it is no need to // // For the program that only run once, it is no need to
...@@ -938,15 +940,6 @@ void NewIRInterpreter::RunOperator(const Instruction& instr_node) { ...@@ -938,15 +940,6 @@ void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
} }
void NewIRInterpreter::RunInstruction(const Instruction& instr_node) { void NewIRInterpreter::RunInstruction(const Instruction& instr_node) {
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:"
<< (instr_node.KernelType() == OpFuncType::kCpuSync
? "kCpuSync"
: (instr_node.KernelType() == OpFuncType::kGpuSync
? "kGpuSync"
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
OperatorBase* op = nullptr; OperatorBase* op = nullptr;
if (instr_node.OpBaseValid()) { if (instr_node.OpBaseValid()) {
op = instr_node.OpBase(); op = instr_node.OpBase();
...@@ -1377,8 +1370,9 @@ void NewIRInterpreter::TraceInstructionList( ...@@ -1377,8 +1370,9 @@ void NewIRInterpreter::TraceInstructionList(
} }
} }
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) { // TODO(phlrain) use orignal order for now, use better dependecy
auto instr_id = trace_execute_order_[idx]; for (size_t instr_id = 0; instr_id < vec_instruction_.size(); ++instr_id) {
/// auto instr_id = trace_execute_order_[idx];
auto& instr_node = vec_instruction_.at(instr_id); auto& instr_node = vec_instruction_.at(instr_id);
RunInstruction(instr_node); RunInstruction(instr_node);
......
...@@ -22,7 +22,7 @@ OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple) ...@@ -22,7 +22,7 @@ OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple)
parse(); parse();
} }
bool OpYamlInfoParser::IsTensorArrtibute(size_t index) const { bool OpYamlInfoParser::IsTensorAttribute(size_t index) const {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
index, index,
InputInfo().size(), InputInfo().size(),
...@@ -48,6 +48,21 @@ const std::string& OpYamlInfoParser::AttrTypeName( ...@@ -48,6 +48,21 @@ const std::string& OpYamlInfoParser::AttrTypeName(
return it->second.type_name; return it->second.type_name;
} }
const std::string& OpYamlInfoParser::TensorAttrTypeName(
const std::string& name) const {
auto it = map_input_info_.find(name);
PADDLE_ENFORCE_NE(it,
map_input_info_.end(),
phi::errors::NotFound("Not found [%s] in input map", name));
PADDLE_ENFORCE_EQ(
it->second.is_mutable_attribute,
true,
phi::errors::PreconditionNotMet("[%s] MUST be a tensor attribute", name));
return it->second.type_name;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams() const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams()
const { const {
return vec_infer_meta_tensor_params_; return vec_infer_meta_tensor_params_;
...@@ -62,6 +77,14 @@ const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const { ...@@ -62,6 +77,14 @@ const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const {
return vec_kernel_fn_attr_params_; return vec_kernel_fn_attr_params_;
} }
const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}
const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
return map_name2id_;
}
void OpYamlInfoParser::parse() { void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_); auto input_info = std::get<0>(op_info_tuple_);
...@@ -91,7 +114,8 @@ void OpYamlInfoParser::parse() { ...@@ -91,7 +114,8 @@ void OpYamlInfoParser::parse() {
auto runtime_info = std::get<3>(op_info_tuple_); auto runtime_info = std::get<3>(op_info_tuple_);
for (auto& name : runtime_info.infer_meta_param) { for (auto& name : runtime_info.infer_meta_param) {
if (map_name2id_.count(name)) { if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_infer_meta_tensor_params_.push_back(name); vec_infer_meta_tensor_params_.push_back(name);
} else { } else {
vec_infer_meta_attr_params_.push_back(name); vec_infer_meta_attr_params_.push_back(name);
...@@ -99,7 +123,8 @@ void OpYamlInfoParser::parse() { ...@@ -99,7 +123,8 @@ void OpYamlInfoParser::parse() {
} }
for (auto& name : runtime_info.kernel_param) { for (auto& name : runtime_info.kernel_param) {
if (map_name2id_.count(name)) { if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_kernel_fn_tensor_params_.push_back(name); vec_kernel_fn_tensor_params_.push_back(name);
} else { } else {
vec_kernel_fn_attr_params_.push_back(name); vec_kernel_fn_attr_params_.push_back(name);
......
...@@ -25,15 +25,18 @@ class OpYamlInfoParser { ...@@ -25,15 +25,18 @@ class OpYamlInfoParser {
explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple); explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple);
bool IsTensorArrtibute(size_t index) const; bool IsTensorAttribute(size_t index) const;
size_t InputTensorNumber() const; size_t InputTensorNumber() const;
const std::string& AttrTypeName(const std::string& name) const; const std::string& AttrTypeName(const std::string& name) const;
const std::string& TensorAttrTypeName(const std::string& name) const;
const std::vector<std::string>& InferMetaTensorParams() const; const std::vector<std::string>& InferMetaTensorParams() const;
const std::vector<std::string>& InferMetaAttrParams() const; const std::vector<std::string>& InferMetaAttrParams() const;
const std::vector<std::string>& KernelFnTensorParams() const; const std::vector<std::string>& KernelFnTensorParams() const;
const std::vector<std::string>& KernelFnAttrParams() const; const std::vector<std::string>& KernelFnAttrParams() const;
const OpRunTimeInfo& OpRuntimeInfo() const;
const std::map<std::string, int>& Name2Id() const;
private: private:
void parse(); void parse();
...@@ -41,7 +44,7 @@ class OpYamlInfoParser { ...@@ -41,7 +44,7 @@ class OpYamlInfoParser {
return std::get<0>(op_info_tuple_); return std::get<0>(op_info_tuple_);
} }
const OpInfoTuple& op_info_tuple_; OpInfoTuple op_info_tuple_;
std::map<std::string, int> map_name2id_; std::map<std::string, int> map_name2id_;
......
...@@ -4,4 +4,4 @@ file(GLOB PD_PASS_SRCS "*.cc") ...@@ -4,4 +4,4 @@ file(GLOB PD_PASS_SRCS "*.cc")
cc_library( cc_library(
pd_op_to_kernel_pass pd_op_to_kernel_pass
SRCS ${PD_PASS_SRCS} SRCS ${PD_PASS_SRCS}
DEPS ir phi_utils) DEPS ir phi_utils pd_interface)
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
...@@ -38,7 +39,8 @@ const int init_on_gpu_threashold = 1000; ...@@ -38,7 +39,8 @@ const int init_on_gpu_threashold = 1000;
phi::KernelKey GetKernelKey( phi::KernelKey GetKernelKey(
ir::Operation* op, ir::Operation* op,
const phi::Place& place, const phi::Place& place,
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) { const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
const dialect::OpYamlInfoParser* op_info_parser = nullptr) {
if (op->name() == "pd.feed") { if (op->name() == "pd.feed") {
// NOTE, for now feed op don't need a kernel, so the data type from Op // NOTE, for now feed op don't need a kernel, so the data type from Op
// Result the next op use base program datatype // Result the next op use base program datatype
...@@ -51,40 +53,18 @@ phi::KernelKey GetKernelKey( ...@@ -51,40 +53,18 @@ phi::KernelKey GetKernelKey(
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED; phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
phi::DataType kernel_data_type = phi::DataType::UNDEFINED; phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
paddle::dialect::OpYamlInfoInterface op_info_interface = if (op_info_parser != nullptr) {
op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
std::vector<paddle::dialect::OpInputInfo> input_info;
if (op_info_interface) {
auto op_info_res = op_info_interface.GetOpInfo();
input_info = std::get<0>(op_info_res);
// only suppurt non vector input for now // only suppurt non vector input for now
std::map<std::string, int> input_map; int tensor_input_number = op_info_parser->InputTensorNumber();
int index = 0;
int tensor_input_number = 0;
for (auto& t : input_info) {
// todo filter attribute tensor
input_map[t.name] = index++;
if (!t.is_mutable_attribute) {
tensor_input_number += 1;
}
}
std::map<std::string, std::string> attr_type_map;
auto attr_info = std::get<1>(op_info_res);
for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name;
}
auto runtime_info = std::get<3>(op_info_res);
auto attr_map = op->attributes(); auto attr_map = op->attributes();
auto data_type_info = runtime_info.kernel_key_dtype; auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype;
if (data_type_info.size() > 0 && data_type_info[0] != "") { if (data_type_info.size() > 0 && data_type_info[0] != "") {
// only support single input and attribute // only support single input and attribute
auto slot_name = data_type_info[0]; auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->Name2Id();
if (input_map.count(slot_name)) { if (input_map.count(slot_name)) {
// parse from input // parse from input
int in_index = input_map.at(slot_name); int in_index = input_map.at(slot_name);
...@@ -95,10 +75,16 @@ phi::KernelKey GetKernelKey( ...@@ -95,10 +75,16 @@ phi::KernelKey GetKernelKey(
.dyn_cast<paddle::dialect::DenseTensorType>(); .dyn_cast<paddle::dialect::DenseTensorType>();
kernel_data_type = TransToPhiDataType(type.dtype()); kernel_data_type = TransToPhiDataType(type.dtype());
} else { } else {
PADDLE_ENFORCE_EQ(attr_type_map.count(slot_name), PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
true, true,
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"[%s] MUST in attr map", slot_name)); "[%s] MUST in attribute map", slot_name));
auto attr_type = op_info_parser->AttrTypeName(slot_name);
PADDLE_ENFORCE_EQ(attr_type,
"paddle::dialect::DataTypeAttribute",
phi::errors::PreconditionNotMet(
"Type of [%s] should be DataType", slot_name));
kernel_data_type = attr_map.at(slot_name) kernel_data_type = attr_map.at(slot_name)
.dyn_cast<paddle::dialect::DataTypeAttribute>() .dyn_cast<paddle::dialect::DataTypeAttribute>()
.data(); .data();
...@@ -140,10 +126,11 @@ phi::KernelKey GetKernelKey( ...@@ -140,10 +126,11 @@ phi::KernelKey GetKernelKey(
paddle::experimental::detail::KernelKeyParser kernel_key_parser; paddle::experimental::detail::KernelKeyParser kernel_key_parser;
for (size_t i = 0; i < op->num_operands(); ++i) { for (size_t i = 0; i < op->num_operands(); ++i) {
// todo filter attribute tensor // NOTE, only op with OpYamlInfo can have TensorArr
if ((input_info.size() > i) && input_info[i].is_mutable_attribute) { if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
continue; continue;
} }
auto input_tmp = op->operand(i); auto input_tmp = op->operand(i);
auto new_input_tmp = map_value_pair.at(input_tmp); auto new_input_tmp = map_value_pair.at(input_tmp);
...@@ -203,13 +190,20 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -203,13 +190,20 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
std::unordered_map<ir::Operation*, ir::Operation*> map_op_pair; std::unordered_map<ir::Operation*, ir::Operation*> map_op_pair;
std::unordered_map<ir::Value, ir::OpResult> map_value_pair; std::unordered_map<ir::Value, ir::OpResult> map_value_pair;
std::string op1_name = paddle::dialect::PhiKernelOp::name(); std::string op_name = paddle::dialect::PhiKernelOp::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
VLOG(6) << "op name " << (*it)->name(); VLOG(6) << "op name " << (*it)->name();
auto kernel_key = GetKernelKey(*it, cpu_place, map_value_pair); paddle::dialect::OpYamlInfoInterface op_info_interface =
(*it)->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
OpYamlInfoParser* op_info_parser = nullptr;
if (op_info_interface) {
op_info_parser = new OpYamlInfoParser(op_info_interface.GetOpInfo());
}
auto kernel_key =
GetKernelKey(*it, cpu_place, map_value_pair, op_info_parser);
VLOG(6) << "kernel type " << kernel_key; VLOG(6) << "kernel type " << kernel_key;
// create new Op // create new Op
...@@ -256,15 +250,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -256,15 +250,9 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
// constuct input // constuct input
std::vector<ir::OpResult> vec_inputs; std::vector<ir::OpResult> vec_inputs;
paddle::dialect::OpYamlInfoInterface op_info_interface =
(*it)->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
std::string kernel_fn_str; std::string kernel_fn_str;
std::vector<paddle::dialect::OpInputInfo> input_info; if (op_info_parser != nullptr) {
if (op_info_interface) { kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func[0];
auto op_info_res = op_info_interface.GetOpInfo();
auto runtime_info = std::get<3>(op_info_res);
kernel_fn_str = runtime_info.kernel_func[0];
input_info = std::get<0>(op_info_res);
} }
if ((*it)->num_operands() > 0) { if ((*it)->num_operands() > 0) {
...@@ -284,9 +272,11 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -284,9 +272,11 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
new_in_type.dyn_cast<dialect::AllocatedDenseTensorType>() new_in_type.dyn_cast<dialect::AllocatedDenseTensorType>()
.place(); .place();
if ((i < input_info.size()) && bool need_trans =
(!input_info[i].is_mutable_attribute) && (op_info_parser != nullptr &&
(place != phi::TransToPhiPlace(kernel_key.backend()))) { !op_info_parser->IsTensorAttribute(i)) &&
(place != phi::TransToPhiPlace(kernel_key.backend()));
if (need_trans) {
if (paddle::experimental::NeedTransformPlace( if (paddle::experimental::NeedTransformPlace(
place, kernel.InputAt(i).backend, {})) { place, kernel.InputAt(i).backend, {})) {
VLOG(6) << "need trans from " << place << " to " VLOG(6) << "need trans from " << place << " to "
...@@ -294,19 +284,19 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -294,19 +284,19 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
// build memcopy op // build memcopy op
auto copy_kernel_key = kernel_key; auto copy_kernel_key = kernel_key;
copy_kernel_key.set_backend(phi::Backend::GPU); copy_kernel_key.set_backend(phi::Backend::GPU);
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op_attribute{
{"op_name", ir::StrAttribute::get(ctx, "pd.memcpy_h2d")}, {"op_name", ir::StrAttribute::get(ctx, "pd.memcpy_h2d")},
{"kernel_name", ir::StrAttribute::get(ctx, "memcpy_h2d")}, {"kernel_name", ir::StrAttribute::get(ctx, "memcpy_h2d")},
{"kernel_key", {"kernel_key",
dialect::KernelAttribute::get(ctx, copy_kernel_key)}, dialect::KernelAttribute::get(ctx, copy_kernel_key)},
{"dst_place_type", ir::Int32Attribute::get(ctx, 1)}}; {"dst_place_type", ir::Int32Attribute::get(ctx, 1)}};
ir::Operation* op1 = ir::Operation::Create( ir::Operation* op = ir::Operation::Create(
{new_in}, op1_attribute, {new_in_type}, op1_info); {new_in}, op_attribute, {new_in_type}, op_info);
program->block()->push_back(op1); program->block()->push_back(op);
new_in = op1->result(0); new_in = op->result(0);
} }
} }
} else if (new_in_type.isa<ir::VectorType>()) { } else if (new_in_type.isa<ir::VectorType>()) {
...@@ -320,7 +310,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -320,7 +310,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
} }
} }
std::unordered_map<std::string, ir::Attribute> op1_attribute{ std::unordered_map<std::string, ir::Attribute> op_attribute{
{"op_name", ir::StrAttribute::get(ctx, (*it)->name())}, {"op_name", ir::StrAttribute::get(ctx, (*it)->name())},
{"kernel_name", ir::StrAttribute::get(ctx, kernel_fn_str)}, {"kernel_name", ir::StrAttribute::get(ctx, kernel_fn_str)},
{"kernel_key", dialect::KernelAttribute::get(ctx, kernel_key)}}; {"kernel_key", dialect::KernelAttribute::get(ctx, kernel_key)}};
...@@ -328,22 +318,22 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -328,22 +318,22 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
auto op_attr_map = (*it)->attributes(); auto op_attr_map = (*it)->attributes();
for (auto it1 = op_attr_map.begin(); it1 != op_attr_map.end(); ++it1) { for (auto it1 = op_attr_map.begin(); it1 != op_attr_map.end(); ++it1) {
op1_attribute.emplace(it1->first, it1->second); op_attribute.emplace(it1->first, it1->second);
} }
ir::Operation* op1 = ir::Operation::Create( ir::Operation* op = ir::Operation::Create(
vec_inputs, op1_attribute, op_output_types, op1_info); vec_inputs, op_attribute, op_output_types, op_info);
map_op_pair[*it] = op1; map_op_pair[*it] = op;
// only deal with single output // only deal with single output
if ((*it)->num_results() > 0) { if ((*it)->num_results() > 0) {
for (size_t i = 0; i < (*it)->num_results(); ++i) { for (size_t i = 0; i < (*it)->num_results(); ++i) {
map_value_pair[(*it)->result(i)] = op1->result(i); map_value_pair[(*it)->result(i)] = op->result(i);
} }
} }
program->block()->push_back(op1); program->block()->push_back(op);
} }
return program; return program;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/infermeta.h" #include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
...@@ -80,7 +81,9 @@ class PhiKernelAdaptor { ...@@ -80,7 +81,9 @@ class PhiKernelAdaptor {
phi::InferMetaContext ctx; phi::InferMetaContext ctx;
ir::BuildInferMetaContext((*it), name_map, scope_, yaml_info, &ctx); paddle::dialect::OpYamlInfoParser op_yaml_info_parser(yaml_info);
ir::BuildInferMetaContext(
(*it), name_map, scope_, op_yaml_info_parser, &ctx);
infer_meta_impl->infer_meta_(&ctx); infer_meta_impl->infer_meta_(&ctx);
...@@ -96,7 +99,7 @@ class PhiKernelAdaptor { ...@@ -96,7 +99,7 @@ class PhiKernelAdaptor {
phi::KernelContext kernel_ctx(dev_ctx); phi::KernelContext kernel_ctx(dev_ctx);
ir::BuildPhiKernelContext( ir::BuildPhiKernelContext(
(*it), name_map, scope_, yaml_info, &kernel_ctx); (*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
kernel_fn(&kernel_ctx); kernel_fn(&kernel_ctx);
auto out_value = (*it)->result(0); auto out_value = (*it)->result(0);
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include "paddle/fluid/ir/dialect/kernel_attribute.h" #include "paddle/fluid/ir/dialect/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/kernel_type.h" #include "paddle/fluid/ir/dialect/kernel_type.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "glog/logging.h" #include "glog/logging.h"
...@@ -178,70 +179,29 @@ void BuildInferMetaContext( ...@@ -178,70 +179,29 @@ void BuildInferMetaContext(
ir::Operation* op, ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map, const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
const OpInfoTuple& op_yaml_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::InferMetaContext* ctx) { phi::InferMetaContext* ctx) {
// inputs include input and mutable attributes // inputs include input and mutable attributes
auto input_info = std::get<0>(op_yaml_info);
std::map<std::string, size_t> input_index_map;
std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name;
input_index_map[t.name] = input_index++;
if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
}
auto attr_info = std::get<1>(op_yaml_info);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name;
}
auto attr_map = op->attributes(); auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_yaml_info); auto& vec_infer_meta_tensor_params = op_yaml_info.InferMetaTensorParams();
// int input_index = 0; auto& name2id = op_yaml_info.Name2Id();
for (auto& t : vec_infer_meta_tensor_params) {
std::vector<std::string> vec_param_list = runtime_info.infer_meta_param; PADDLE_ENFORCE_EQ(
name2id.count(t),
for (size_t input_index = 0; input_index < vec_param_list.size(); true,
input_index++) { phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto& t = vec_param_list[input_index]; auto index = op_yaml_info.Name2Id().at(t);
if (input_index_map.count(t)) { ir::Value ptr = op->operand(index);
// get information from input
ir::Value ptr = op->operand(input_index_map[t]);
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else {
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
auto var = scope->Var(in_var_name); auto var = scope->Var(in_var_name);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>()); const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(const_cast<phi::TensorBase*>(tensor_in)); ctx->EmplaceBackInput(const_cast<phi::TensorBase*>(tensor_in));
} else if (var->IsType<paddle::framework::TensorRefArray>()) { } else if (var->IsType<paddle::framework::TensorRefArray>()) {
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize> paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize> inputs;
inputs;
auto& tensor_array = var->Get<paddle::framework::TensorRefArray>(); auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < tensor_array.size(); ++i) { for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(std::move(phi::MetaTensor(*tensor_array[i]))); inputs.emplace_back(std::move(phi::MetaTensor(*tensor_array[i])));
...@@ -253,35 +213,60 @@ void BuildInferMetaContext( ...@@ -253,35 +213,60 @@ void BuildInferMetaContext(
var->Type())); var->Type()));
} }
} }
auto& vec_infer_meta_attr_params = op_yaml_info.InferMetaAttrParams();
for (auto& t : vec_infer_meta_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
auto in_var_name = name_map.at(ptr);
auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
}
continue;
} }
if (attr_type_map.count(t)) { auto& attr_type_name = op_yaml_info.AttrTypeName(t);
auto type_name = attr_type_map[t];
if (type_name == "paddle::dialect::IntArrayAttribute") { if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data());
} else if (type_name == "paddle::dialect::DataTypeAttribute") { } else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32Attribute") { } else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (type_name == "ir::FloatAttribute") { } else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (type_name == "ir::BoolAttribute") { } else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") { } else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") { } else if (attr_type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name)); attr_type_name));
} }
VLOG(6) << "ctx->EmplaceBackAttr: " << t; VLOG(6) << "ctx->EmplaceBackAttr: " << t;
} }
}
// TODO(phlrain): use var type instead of op name // TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") && if (op->attributes().count("op_name") &&
...@@ -305,77 +290,30 @@ void BuildPhiKernelContext( ...@@ -305,77 +290,30 @@ void BuildPhiKernelContext(
ir::Operation* op, ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map, const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
const OpInfoTuple& op_yaml_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::KernelContext* ctx, phi::KernelContext* ctx,
std::map<std::string, std::vector<int>>* input_map, std::map<std::string, std::vector<int>>* input_map,
std::map<std::string, std::vector<int>>* output_map) { std::map<std::string, std::vector<int>>* output_map) {
// inputs include input and mutable attributes // inputs include input and mutable attributes
auto input_info = std::get<0>(op_yaml_info);
std::map<std::string, size_t> input_index_map;
std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name;
input_index_map[t.name] = input_index++;
if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
}
auto attr_info = std::get<1>(op_yaml_info);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name;
}
auto attr_map = op->attributes(); auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_yaml_info);
// int input_index = 0;
std::vector<std::string> vec_param_list = runtime_info.kernel_param;
for (auto& t : vec_param_list) {
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->operand(input_index_map[t]);
auto in_var_name = name_map.at(ptr);
if (input_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(in_var_name.substr(4, 100).c_str());
(*input_map)[std::to_string(input_index_map.at(t))].push_back(tmp_id);
}
if (mutable_attr_type_map.count(t)) { auto& vec_kernel_fn_tensor_params = op_yaml_info.KernelFnTensorParams();
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
phi::Attribute r1 = phi::TensorRef(
&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else { auto& name2id = op_yaml_info.Name2Id();
for (auto& t : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.Name2Id().at(t);
ir::Value ptr = op->operand(index);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
scope->FindLocalVar(in_var_name), phi::errors::PreconditionNotMet(
phi::errors::PreconditionNotMet("can not find var[%s] in scope", "can not find var[%s] in scope", in_var_name));
in_var_name));
auto var = scope->Var(in_var_name); auto var = scope->Var(in_var_name);
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
...@@ -387,7 +325,7 @@ void BuildPhiKernelContext( ...@@ -387,7 +325,7 @@ void BuildPhiKernelContext(
for (size_t i = 0; i < tensor_array.size(); ++i) { for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(tensor_array[i]); inputs.emplace_back(tensor_array[i]);
} }
std::cerr << "is tensor ref " << std::endl;
ctx->EmplaceBackInputs(std::move(inputs)); ctx->EmplaceBackInputs(std::move(inputs));
} else if (var->IsType<paddle::framework::FeedList>()) { } else if (var->IsType<paddle::framework::FeedList>()) {
auto feed_list = var->Get<paddle::framework::FeedList>(); auto feed_list = var->Get<paddle::framework::FeedList>();
...@@ -398,35 +336,70 @@ void BuildPhiKernelContext( ...@@ -398,35 +336,70 @@ void BuildPhiKernelContext(
var->Type())); var->Type()));
} }
} }
auto& vec_kernel_fn_attr_params = op_yaml_info.KernelFnAttrParams();
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
auto in_var_name = name_map.at(ptr);
if (input_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(in_var_name.substr(4, 100).c_str());
(*input_map)[std::to_string(name2id.at(t))].push_back(tmp_id);
} }
if (attr_type_map.count(t)) { auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
auto type_name = attr_type_map[t]; VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (type_name == "paddle::dialect::IntArrayAttribute") { if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
}
continue;
}
auto& attr_type_name = op_yaml_info.AttrTypeName(t);
if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data());
} else if (type_name == "paddle::dialect::DataTypeAttribute") { } else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32Attribute") { } else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (type_name == "ir::FloatAttribute") { } else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (type_name == "ir::BoolAttribute") { } else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") { } else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") { } else if (attr_type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name)); attr_type_name));
} }
VLOG(6) << "ctx->EmplaceBackAttr: " << t; VLOG(6) << "ctx->EmplaceBackAttr: " << t;
} }
}
// TODO(phlrain): use var type instead of op name // TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") && if (op->attributes().count("op_name") &&
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "paddle/fluid/ir/dialect/kernel_attribute.h" #include "paddle/fluid/ir/dialect/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "glog/logging.h" #include "glog/logging.h"
...@@ -47,14 +48,14 @@ void BuildInferMetaContext( ...@@ -47,14 +48,14 @@ void BuildInferMetaContext(
ir::Operation* op, ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map, const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
const OpInfoTuple& op_yaml_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::InferMetaContext* ctx); phi::InferMetaContext* ctx);
void BuildPhiKernelContext( void BuildPhiKernelContext(
ir::Operation* op, ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map, const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
const OpInfoTuple& op_yaml_info, const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::KernelContext* ctx, phi::KernelContext* ctx,
std::map<std::string, std::vector<int>>* input_map = nullptr, std::map<std::string, std::vector<int>>* input_map = nullptr,
std::map<std::string, std::vector<int>>* output_map = nullptr); std::map<std::string, std::vector<int>>* output_map = nullptr);
......
...@@ -118,7 +118,7 @@ void TransDataBackend(const phi::SelectedRows* tensor, ...@@ -118,7 +118,7 @@ void TransDataBackend(const phi::SelectedRows* tensor,
Backend target_backend, Backend target_backend,
phi::SelectedRows* out); phi::SelectedRows* out);
inline bool NeedTransformPlace(const phi::Place& input, inline bool NeedTransformPlace(const phi::Place& src_place,
const Backend& target, const Backend& target,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
// NOTE(dev): The default value of TransformFlag is True, if it is set with // NOTE(dev): The default value of TransformFlag is True, if it is set with
...@@ -128,9 +128,9 @@ inline bool NeedTransformPlace(const phi::Place& input, ...@@ -128,9 +128,9 @@ inline bool NeedTransformPlace(const phi::Place& input,
if (!transform_flag.need_trans_backend()) { if (!transform_flag.need_trans_backend()) {
return false; return false;
} }
bool ret = input.GetType() == AllocationType::GPUPINNED || bool ret = src_place.GetType() == AllocationType::GPUPINNED ||
(target != Backend::ALL_BACKEND && (target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) != phi::TransToPhiBackend(src_place) !=
(target != Backend::GPUDNN ? target : Backend::GPU)); (target != Backend::GPUDNN ? target : Backend::GPU));
return ret; return ret;
} }
......
...@@ -60,14 +60,14 @@ TEST(ir_op_info_test, op_op_info_test) { ...@@ -60,14 +60,14 @@ TEST(ir_op_info_test, op_op_info_test) {
auto kernel_fn_tensor_param = op_yaml_info_parser.KernelFnTensorParams(); auto kernel_fn_tensor_param = op_yaml_info_parser.KernelFnTensorParams();
auto kernel_fn_attr_param = op_yaml_info_parser.KernelFnAttrParams(); auto kernel_fn_attr_param = op_yaml_info_parser.KernelFnAttrParams();
EXPECT_EQ(infer_meta_tensor_param.size(), 1u); EXPECT_EQ(infer_meta_tensor_param.size(), 0u);
EXPECT_EQ(infer_meta_attr_param.size(), 1u); EXPECT_EQ(infer_meta_attr_param.size(), 2u);
EXPECT_EQ(kernel_fn_tensor_param.size(), 3u); EXPECT_EQ(kernel_fn_tensor_param.size(), 0u);
EXPECT_EQ(kernel_fn_attr_param.size(), 2u); EXPECT_EQ(kernel_fn_attr_param.size(), 5u);
EXPECT_EQ((op_yaml_info_parser.AttrTypeName("seed") == "ir::Int32Attribute"), EXPECT_EQ((op_yaml_info_parser.AttrTypeName("seed") == "ir::Int32Attribute"),
true); true);
EXPECT_EQ(op_yaml_info_parser.IsTensorArrtibute(0), true); EXPECT_EQ(op_yaml_info_parser.IsTensorAttribute(0), true);
EXPECT_EQ(op_yaml_info_parser.InputTensorNumber(), 0u); EXPECT_EQ(op_yaml_info_parser.InputTensorNumber(), 0u);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册