未验证 提交 eef38db1 编写于 作者: H hong 提交者: GitHub

Refactor build attribute (#54968)

* update

* refactor build context

* fix bug

* polish code

* change func name
上级 7353e9e9
......@@ -969,7 +969,12 @@ void BuildOpFuncList(
VLOG(6) << "op name" << op_func_node.phi_op_name_;
dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());
::ir::BuildInferMetaContext((*it),
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it),
value_2_name_map,
scope,
op_yaml_info_parser,
......@@ -990,7 +995,11 @@ void BuildOpFuncList(
true,
"not found kernel for [%s]",
kernel_name);
::ir::BuildPhiKernelContext((*it),
::ir::BuildPhiContext<phi::KernelContext,
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
true>((*it),
value_2_name_map,
scope,
op_yaml_info_parser,
......
......@@ -39,21 +39,21 @@ size_t OpYamlInfoParser::InputTensorNumber() const {
const std::string& OpYamlInfoParser::AttrTypeName(
const std::string& name) const {
auto it = map_attr_info_.find(name);
auto it = attr_info_.find(name);
PADDLE_ENFORCE_NE(
it,
map_attr_info_.end(),
attr_info_.end(),
phi::errors::NotFound("Not found [%s] in attribute map", name));
return it->second.type_name;
}
const std::string& OpYamlInfoParser::TensorAttrTypeName(
const std::string& name) const {
auto it = map_input_info_.find(name);
auto it = input_info_.find(name);
PADDLE_ENFORCE_NE(it,
map_input_info_.end(),
input_info_.end(),
phi::errors::NotFound("Not found [%s] in input map", name));
PADDLE_ENFORCE_EQ(
......@@ -63,18 +63,21 @@ const std::string& OpYamlInfoParser::TensorAttrTypeName(
return it->second.type_name;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams()
const {
return vec_infer_meta_tensor_params_;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaAttrParams() const {
return vec_infer_meta_attr_params_;
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnTensorParams() const {
return vec_kernel_fn_tensor_params_;
const std::vector<std::string>& OpYamlInfoParser::TensorParams(
bool is_kernel) const {
if (is_kernel) {
return kernel_fn_tensor_params_;
} else {
return infer_meta_tensor_params_;
}
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const {
return vec_kernel_fn_attr_params_;
const std::vector<std::string>& OpYamlInfoParser::AttrParams(
bool is_kernel) const {
if (is_kernel) {
return kernel_fn_attr_params_;
} else {
return infer_meta_attr_params_;
}
}
const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
......@@ -82,7 +85,7 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
}
const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
return map_name2id_;
return name2id_;
}
void OpYamlInfoParser::parse() {
......@@ -91,43 +94,41 @@ void OpYamlInfoParser::parse() {
int start_index = 0;
for (size_t i = 0; i < input_info.size(); ++i) {
map_name2id_[input_info[i].name] = start_index++;
name2id_[input_info[i].name] = start_index++;
if (!input_info[i].is_mutable_attribute) {
input_tensor_number_++;
}
map_input_info_[input_info[i].name] = input_info[i];
input_info_[input_info[i].name] = input_info[i];
}
auto attribute_info = std::get<1>(op_info_tuple_);
for (size_t i = 0; i < attribute_info.size(); ++i) {
map_attr_info_[attribute_info[i].name] = attribute_info[i];
attr_info_[attribute_info[i].name] = attribute_info[i];
}
auto output_info = std::get<2>(op_info_tuple_);
for (size_t i = 0; i < output_info.size(); ++i) {
map_output_info_[output_info[i].name] = output_info[i];
output_info_[output_info[i].name] = output_info[i];
}
auto runtime_info = std::get<3>(op_info_tuple_);
for (auto& name : runtime_info.infer_meta_param) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_infer_meta_tensor_params_.push_back(name);
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
infer_meta_tensor_params_.push_back(name);
} else {
vec_infer_meta_attr_params_.push_back(name);
infer_meta_attr_params_.push_back(name);
}
}
for (auto& name : runtime_info.kernel_param) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_kernel_fn_tensor_params_.push_back(name);
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
kernel_fn_tensor_params_.push_back(name);
} else {
vec_kernel_fn_attr_params_.push_back(name);
kernel_fn_attr_params_.push_back(name);
}
}
}
......
......@@ -31,10 +31,8 @@ class OpYamlInfoParser {
const std::string& AttrTypeName(const std::string& name) const;
const std::string& TensorAttrTypeName(const std::string& name) const;
const std::vector<std::string>& InferMetaTensorParams() const;
const std::vector<std::string>& InferMetaAttrParams() const;
const std::vector<std::string>& KernelFnTensorParams() const;
const std::vector<std::string>& KernelFnAttrParams() const;
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
const OpRunTimeInfo& OpRuntimeInfo() const;
const std::map<std::string, int>& Name2Id() const;
......@@ -46,16 +44,16 @@ class OpYamlInfoParser {
OpInfoTuple op_info_tuple_;
std::map<std::string, int> map_name2id_;
std::map<std::string, int> name2id_;
std::map<std::string, OpInputInfo> map_input_info_;
std::map<std::string, OpAttributeInfo> map_attr_info_;
std::map<std::string, OpOutputInfo> map_output_info_;
std::map<std::string, OpInputInfo> input_info_;
std::map<std::string, OpAttributeInfo> attr_info_;
std::map<std::string, OpOutputInfo> output_info_;
std::vector<std::string> vec_infer_meta_tensor_params_;
std::vector<std::string> vec_infer_meta_attr_params_;
std::vector<std::string> vec_kernel_fn_tensor_params_;
std::vector<std::string> vec_kernel_fn_attr_params_;
std::vector<std::string> infer_meta_tensor_params_;
std::vector<std::string> infer_meta_attr_params_;
std::vector<std::string> kernel_fn_tensor_params_;
std::vector<std::string> kernel_fn_attr_params_;
int input_tensor_number_{0};
};
......
......@@ -82,8 +82,12 @@ class PhiKernelAdaptor {
phi::InferMetaContext ctx;
paddle::dialect::OpYamlInfoParser op_yaml_info_parser(yaml_info);
ir::BuildInferMetaContext(
(*it), name_map, scope_, op_yaml_info_parser, &ctx);
ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it), name_map, scope_, op_yaml_info_parser, &ctx);
infer_meta_impl->infer_meta_(&ctx);
......@@ -98,7 +102,11 @@ class PhiKernelAdaptor {
phi::KernelContext kernel_ctx(dev_ctx);
ir::BuildPhiKernelContext(
ir::BuildPhiContext<phi::KernelContext,
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
true>(
(*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
kernel_fn(&kernel_ctx);
......
......@@ -175,261 +175,4 @@ void BuildScope(ir::Block* block,
}
}
void BuildInferMetaContext(
ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope,
const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::InferMetaContext* ctx) {
// inputs include input and mutable attributes
auto attr_map = op->attributes();
auto& vec_infer_meta_tensor_params = op_yaml_info.InferMetaTensorParams();
auto& name2id = op_yaml_info.Name2Id();
for (auto& t : vec_infer_meta_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.Name2Id().at(t);
ir::Value ptr = op->operand(index);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
auto var = scope->Var(in_var_name);
if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(const_cast<phi::TensorBase*>(tensor_in));
} else if (var->IsType<paddle::framework::TensorRefArray>()) {
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize> inputs;
auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(std::move(phi::MetaTensor(*tensor_array[i])));
}
ctx->EmplaceBackInputs(std::move(inputs));
} else {
PADDLE_THROW(phi::errors::Unimplemented("Not support var type [%d] ",
var->Type()));
}
}
auto& vec_infer_meta_attr_params = op_yaml_info.InferMetaAttrParams();
for (auto& t : vec_infer_meta_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
auto in_var_name = name_map.at(ptr);
auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
}
continue;
}
auto& attr_type_name = op_yaml_info.AttrTypeName(t);
if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data());
} else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (attr_type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
attr_type_name));
}
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}
// TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = scope->Var("fetch");
auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
ctx->EmplaceBackOutput(out_tensor);
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
ctx->EmplaceBackOutput(scope->Var(name)->Get<phi::DenseTensor>());
}
}
}
void BuildPhiKernelContext(
ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope,
const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::KernelContext* ctx,
std::map<std::string, std::vector<int>>* input_map,
std::map<std::string, std::vector<int>>* output_map) {
// inputs include input and mutable attributes
auto attr_map = op->attributes();
auto& vec_kernel_fn_tensor_params = op_yaml_info.KernelFnTensorParams();
auto& name2id = op_yaml_info.Name2Id();
for (auto& t : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.Name2Id().at(t);
ir::Value ptr = op->operand(index);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
phi::errors::PreconditionNotMet(
"can not find var[%s] in scope", in_var_name));
auto var = scope->Var(in_var_name);
if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(tensor_in);
} else if (var->IsType<paddle::framework::TensorRefArray>()) {
paddle::small_vector<const phi::TensorBase*> inputs;
auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(tensor_array[i]);
}
std::cerr << "is tensor ref " << std::endl;
ctx->EmplaceBackInputs(std::move(inputs));
} else if (var->IsType<paddle::framework::FeedList>()) {
auto feed_list = var->Get<paddle::framework::FeedList>();
auto* in_tensor = &(PADDLE_GET(phi::DenseTensor, feed_list.at(0)));
ctx->EmplaceBackOutput(in_tensor);
} else {
PADDLE_THROW(phi::errors::Unimplemented("Not support var type [%d] ",
var->Type()));
}
}
auto& vec_kernel_fn_attr_params = op_yaml_info.KernelFnAttrParams();
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
auto in_var_name = name_map.at(ptr);
if (input_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(in_var_name.substr(4, 100).c_str());
(*input_map)[std::to_string(name2id.at(t))].push_back(tmp_id);
}
auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
}
continue;
}
auto& attr_type_name = op_yaml_info.AttrTypeName(t);
if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data());
} else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (attr_type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
attr_type_name));
}
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}
// TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = scope->Var("fetch");
auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
ctx->EmplaceBackOutput(out_tensor);
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
ctx->EmplaceBackOutput(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>())));
if (output_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(name.substr(4, 100).c_str());
(*output_map)["out"].push_back(tmp_id);
}
}
}
}
} // namespace ir
......@@ -44,20 +44,150 @@ void BuildScope(ir::Block* block,
paddle::framework::Scope* scope,
std::unordered_map<ir::Value, std::string>* name_map);
void BuildInferMetaContext(
template <typename Context,
typename InType,
typename OutType,
typename ListType,
bool is_kernel>
void BuildPhiContext(
ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope,
const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::InferMetaContext* ctx);
void BuildPhiKernelContext(
ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope,
const paddle::dialect::OpYamlInfoParser& op_yaml_info,
phi::KernelContext* ctx,
Context* ctx,
std::map<std::string, std::vector<int>>* input_map = nullptr,
std::map<std::string, std::vector<int>>* output_map = nullptr);
std::map<std::string, std::vector<int>>* output_map = nullptr) {
// inputs include input and mutable attributes
auto attr_map = op->attributes();
auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(is_kernel);
auto& name2id = op_yaml_info.Name2Id();
for (auto& t : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.Name2Id().at(t);
ir::Value ptr = op->operand(index);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
phi::errors::PreconditionNotMet(
"can not find var[%s] in scope", in_var_name));
auto var = scope->Var(in_var_name);
if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(InType(tensor_in));
} else if (var->IsType<paddle::framework::TensorRefArray>()) {
ListType inputs;
auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(InType(tensor_array[i]));
}
ctx->EmplaceBackInputs(inputs);
} else {
PADDLE_THROW(phi::errors::Unimplemented("Not support var type [%d] ",
var->Type()));
}
}
auto& vec_kernel_fn_attr_params = op_yaml_info.AttrParams(is_kernel);
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
auto in_var_name = name_map.at(ptr);
if (input_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(in_var_name.substr(4, 100).c_str());
(*input_map)[std::to_string(name2id.at(t))].push_back(tmp_id);
}
auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
}
continue;
}
auto& attr_type_name = op_yaml_info.AttrTypeName(t);
if (attr_type_name == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::IntArrayAttribute>().data());
} else if (attr_type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (attr_type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
attr_type_name));
}
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}
// TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = scope->Var("fetch");
auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
ctx->EmplaceBackOutput(out_tensor);
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
if (output_map != nullptr) {
// only deal with single input for now, [todo] need support multi input
// like concat
// TODO(phlrain): OpFuncNode need input_index and output_index,
// construct input_index and output_here, should remove input_index and
// output_index from OpFuncNode Each in_var_name named "inner_var_" +
// index, len("inner_var_") = 10
size_t tmp_id = std::atol(name.substr(4, 100).c_str());
(*output_map)["out"].push_back(tmp_id);
}
}
}
}
} // namespace ir
......@@ -45,6 +45,8 @@ class MetaTensor {
MetaTensor(TensorBase* tensor) : tensor_(tensor) {} // NOLINT
MetaTensor(const TensorBase& tensor) // NOLINT
: tensor_(const_cast<TensorBase*>(&tensor)) {}
MetaTensor(const TensorBase* tensor) // NOLINT
: tensor_(const_cast<TensorBase*>(tensor)) {}
MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT
MetaTensor(MetaTensor&&) = default;
......
......@@ -55,16 +55,6 @@ TEST(ir_op_info_test, op_op_info_test) {
paddle::dialect::OpYamlInfoParser op_yaml_info_parser(op_info_res);
auto infer_meta_tensor_param = op_yaml_info_parser.InferMetaTensorParams();
auto infer_meta_attr_param = op_yaml_info_parser.InferMetaAttrParams();
auto kernel_fn_tensor_param = op_yaml_info_parser.KernelFnTensorParams();
auto kernel_fn_attr_param = op_yaml_info_parser.KernelFnAttrParams();
EXPECT_EQ(infer_meta_tensor_param.size(), 0u);
EXPECT_EQ(infer_meta_attr_param.size(), 2u);
EXPECT_EQ(kernel_fn_tensor_param.size(), 0u);
EXPECT_EQ(kernel_fn_attr_param.size(), 5u);
EXPECT_EQ((op_yaml_info_parser.AttrTypeName("seed") == "ir::Int32Attribute"),
true);
EXPECT_EQ(op_yaml_info_parser.IsTensorAttribute(0), true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册