Unverified commit a2903920, authored by hong, committed by GitHub

[NewIR] Fix null value and support some attributes (#55100)

* support optional input in new_ir

* polish code

* add coverage test

* update

* update

* add unit test

* remove duplicated code

* set test timeout
Parent: df6c74c3
@@ -954,7 +954,9 @@ void BuildOpFuncList(
   auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
   op_func_node.phi_op_name_ = op_name;
-  if (op_name == "builtin.combine" || op_name == "pd.feed") {
+  if (op_name == "builtin.combine" || op_name == "pd.feed" ||
+      op_name == "builtin.set_parameter" ||
+      op_name == "builtin.get_parameter") {
     VLOG(6) << "skip process " << op_name;
     continue;
   }
......
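Editorial note: builtin.set_parameter and builtin.get_parameter need no OpFuncNode (the BuildScope hunks below handle them at scope-building time), so they join the skip list. As an illustration only, not the committed code, the same check written as a set lookup:

#include <string>
#include <unordered_set>

// Illustration: the same skip decision as a set lookup, so the next
// structural op is a one-line addition instead of a longer if-chain.
bool SkipOpFuncBuild(const std::string& op_name) {
  static const std::unordered_set<std::string> kSkipOps = {
      "builtin.combine", "pd.feed", "builtin.set_parameter",
      "builtin.get_parameter"};
  return kSkipOps.count(op_name) > 0;
}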
@@ -61,7 +61,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
     execution_config.create_local_scope = false;
     execution_config.skip_gc_vars = job->SkipGcVars();
-    if (FLAGS_enable_new_ir_in_executor) {
+    if (FLAGS_enable_new_ir_in_executor && platform::is_cpu_place(place)) {
       VLOG(6) << "begin to translate" << std::endl;
       auto base_program = paddle::TranslateLegacyProgramToProgram(*program);
       auto kernel_program =
......
@@ -48,8 +48,7 @@ void BuildScope(ir::Block* block,
                 std::unordered_map<ir::Value, std::string>* name_map) {
   std::unordered_map<ir::Value, int> map_test;
-  // int count = name_map->size();
-  int count = 0;
+  int count = name_map->size();
   for (auto it = block->begin(); it != block->end(); ++it) {
     size_t input_num = (*it)->num_operands();
     auto attr_map = (*it)->attributes();
@@ -69,6 +68,35 @@ void BuildScope(ir::Block* block,
       continue;
     }
+    if (op_name == "builtin.set_parameter") {
+      auto param_name = (*it)
+                            ->attributes()
+                            .at("parameter_name")
+                            .dyn_cast<ir::StrAttribute>()
+                            .data();
+      auto in_ptr = (*it)->operand(0);
+      // change operand name to param_name
+      auto orig_name = name_map->at(in_ptr);
+      (*name_map)[in_ptr] = param_name;
+      scope->Rename(orig_name, param_name);
+      continue;
+    }
+    if (op_name == "builtin.get_parameter") {
+      auto param_name = (*it)
+                            ->attributes()
+                            .at("parameter_name")
+                            .dyn_cast<ir::StrAttribute>()
+                            .data();
+      auto out_ptr = (*it)->result(0);
+      name_map->emplace(out_ptr, param_name);
+      continue;
+    }
     if (op_name == "pd.feed") {
       auto ptr = (*it)->result(0);
       std::string name = "inner_var_" + std::to_string(count++);
@@ -123,14 +151,14 @@ void BuildScope(ir::Block* block,
     if (input_num > 0) {
       for (size_t i = 0; i < input_num; ++i) {
         auto ptr = (*it)->operand(i);
-        std::string name;
-        if (name_map->find(ptr) != name_map->end()) {
-          name = name_map->at(ptr);
-        } else {
-          PADDLE_THROW(phi::errors::PreconditionNotMet(
-              "input should in name map, [%d] 'th input of [%s] op",
-              i,
-              op_name));
+        if (ptr) {
+          PADDLE_ENFORCE_NE(
+              name_map->find(ptr),
+              name_map->end(),
+              phi::errors::PreconditionNotMet(
+                  "input should in name map, [%d] 'th input of [%s] op",
+                  i,
+                  op_name));
         }
       }
     }
@@ -149,7 +177,6 @@ void BuildScope(ir::Block* block,
     }
     auto var = scope->Var(name);
     // Only support DenseTensor or Vector<DenseTensor>
     if (!ptr.type()) {
       var->GetMutable<phi::DenseTensor>();
     } else if (ptr.type()
......
@@ -71,6 +71,12 @@ void BuildPhiContext(
           phi::errors::NotFound("param [%s] MUST in name2id map", t));
       auto index = op_yaml_info.Name2Id().at(t);
       ir::Value ptr = op->operand(index);
+      if (!ptr) {
+        phi::DenseTensor* ptr = nullptr;
+        OutType in_ptr(ptr);
+        ctx->EmplaceBackInput(in_ptr);
+        continue;
+      }
       auto in_var_name = name_map.at(ptr);
       VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
@@ -142,10 +148,14 @@ void BuildPhiContext(
           attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
     } else if (attr_type_name == "ir::Int32Attribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
+    } else if (attr_type_name == "ir::Int64Attribute") {
+      ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int64Attribute>().data());
     } else if (attr_type_name == "ir::FloatAttribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
+    } else if (attr_type_name == "ir::BoolAttribute") {
+      ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
     } else if (attr_type_name == "ir::StrAttribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
     } else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
       auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
       std::vector<int32_t> vec_res;
@@ -160,6 +170,44 @@ void BuildPhiContext(
               array_list[i].dyn_cast<ir::Int32Attribute>().data());
         }
       }
+    } else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") {
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      std::vector<float> vec_res;
+      if (array_list.size() > 0) {
+        if (array_list[0].isa<ir::FloatAttribute>()) {
+          for (size_t i = 0; i < array_list.size(); ++i) {
+            vec_res.push_back(
+                array_list[i].dyn_cast<ir::FloatAttribute>().data());
+          }
+        } else {
+          PADDLE_THROW(phi::errors::Unimplemented(
+              "attr type not support [%s] ", attr_type_name));
+        }
+      }
+      ctx->EmplaceBackAttr(vec_res);
+    } else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      std::vector<int64_t> vec_res;
+      if (array_list.size() > 0) {
+        PADDLE_ENFORCE_EQ(
+            array_list[0].isa<ir::Int64Attribute>(),
+            true,
+            phi::errors::PreconditionNotMet(
+                "Element in array list MUST be ir::Int64Attribute "));
+        for (size_t i = 0; i < array_list.size(); ++i) {
+          vec_res.push_back(
+              array_list[i].dyn_cast<ir::Int64Attribute>().data());
+        }
+      }
+      ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
       ctx->EmplaceBackAttr(
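Editorial note: both new array branches follow one recipe: check the element type on the first entry, then unpack the whole ir::ArrayAttribute into a typed std::vector before emplacing it. A compact standalone version of that recipe, with std::variant standing in for the attribute types:

#include <cstdint>
#include <stdexcept>
#include <variant>
#include <vector>

// Toy attribute; a real ir::Attribute is dispatched by type name instead.
using ToyAttr = std::variant<int64_t, float, bool>;

// Unpack a homogeneous attribute array into a typed vector, rejecting
// mixed-type arrays the same way the hunk above does.
std::vector<int64_t> ToInt64Vector(const std::vector<ToyAttr>& array_list) {
  std::vector<int64_t> vec_res;
  vec_res.reserve(array_list.size());
  for (const ToyAttr& a : array_list) {
    if (!std::holds_alternative<int64_t>(a)) {
      throw std::runtime_error("element in array list must be int64");
    }
    vec_res.push_back(std::get<int64_t>(a));
  }
  return vec_res;
}

int main() {
  std::vector<ToyAttr> axis = {int64_t{0}, int64_t{2}};
  return ToInt64Vector(axis).size() == 2 ? 0 : 1;
}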
......
@@ -132,6 +132,9 @@ phi::KernelKey GetKernelKey(
     }
     auto input_tmp = op->operand(i);
+    if (!input_tmp) {
+      continue;
+    }
     auto new_input_tmp = map_value_pair.at(input_tmp);
     auto input_type = new_input_tmp.type();
@@ -264,6 +267,15 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
     if ((*it)->num_operands() > 0) {
       for (size_t i = 0; i < (*it)->num_operands(); ++i) {
         auto cur_in = (*it)->operand(i);
+        if (!cur_in) {
+          vec_inputs.push_back(ir::OpResult());
+          continue;
+        }
+        PADDLE_ENFORCE_EQ(
+            map_value_pair.count(cur_in),
+            true,
+            phi::errors::PreconditionNotMet(
+                "[%d]'s input of [%s] op MUST in map pair", i, (*it)->name()));
         auto new_in = map_value_pair.at(cur_in);
         auto new_in_type = new_in.type();
......
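Editorial note: the null-operand checks here and in GetKernelKey above share a convention: an absent ir::Value is falsy, and the lowering pass pushes a default-constructed ir::OpResult so operand slots stay positionally aligned. A toy of that convention (ToyOpResult is a stand-in, not the ir:: type):

#include <vector>

// Stand-in for ir::OpResult: default-constructed means "no value".
struct ToyOpResult {
  const void* impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

int main() {
  std::vector<ToyOpResult> vec_inputs;
  ToyOpResult cur_in;  // absent operand
  if (!cur_in) {
    vec_inputs.push_back(ToyOpResult());  // placeholder keeps the slot aligned
  }
  return vec_inputs.size() == 1 ? 0 : 1;
}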
@@ -132,10 +132,6 @@
       { axis : dim, keepdim : keep_dim}
   outputs:
     out : Out
-  int_array:
-    axis :
-      data_type : int
-      support_tensor : true
   manual_signature : [all]
   extra :
     attrs : [bool use_mkldnn = false]
@@ -163,10 +159,6 @@
       { axis : dim, keepdim : keep_dim }
   extra :
     attrs : [bool use_mkldnn = false]
-  int_array:
-    axis :
-      data_type : int
-      support_tensor : true
   get_expected_kernel_type :
     amax_grad : GetReduceGradExpectedKernelType
   manual_signature : [amax]
@@ -181,10 +173,6 @@
       { axis : dim, keepdim : keep_dim }
   extra :
     attrs : [bool use_mkldnn = false]
-  int_array:
-    axis :
-      data_type : int
-      support_tensor : true
   get_expected_kernel_type :
     amin_grad : GetReduceGradExpectedKernelType
   manual_signature : [amin]
@@ -207,10 +195,6 @@
       { axis : dim, keepdim : keep_dim }
   extra :
     attrs : [bool use_mkldnn = false]
-  int_array:
-    axis :
-      data_type : int
-      support_tensor : true
   get_expected_kernel_type :
     any : GetReduceOpUseInputPlaceExpectedKernelType
   manual_signature : [any]
......
@@ -1314,6 +1314,16 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
     FLAGS_new_executor_static_build=true)
 endforeach()
+set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2)
+foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS})
+  py_test_modules(
+    ${NEW_IR_COVERAGE_TEST}_new_ir MODULES ${NEW_IR_COVERAGE_TEST} ENVS
+    FLAGS_enable_new_ir_in_executor=true)
+endforeach()
+set_tests_properties(test_instance_norm_op_v2_new_ir PROPERTIES TIMEOUT 120)
 set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT
                                                             120)
 set_tests_properties(test_fuse_bn_act_pass_static_build PROPERTIES TIMEOUT 120)
......
@@ -380,4 +380,5 @@ class TestPrimForwardAndBackward(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -153,4 +153,5 @@ class TestNumelAPI(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -378,4 +378,5 @@ class TestRandomValue(unittest.TestCase):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -475,4 +475,5 @@ class TestYolov3LossStatic(unittest.TestCase):
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()