未验证 提交 a2903920 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix null value and support some attribute (#55100)

* suport optional input in new_ir

* polish code

* add coverate test

* update

* update

* add unitest

* remove reduplicate code

* set test timeout
上级 df6c74c3
...@@ -954,7 +954,9 @@ void BuildOpFuncList( ...@@ -954,7 +954,9 @@ void BuildOpFuncList(
auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data(); auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_func_node.phi_op_name_ = op_name; op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed") { if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter") {
VLOG(6) << "skip process " << op_name; VLOG(6) << "skip process " << op_name;
continue; continue;
} }
......
...@@ -61,7 +61,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -61,7 +61,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
execution_config.create_local_scope = false; execution_config.create_local_scope = false;
execution_config.skip_gc_vars = job->SkipGcVars(); execution_config.skip_gc_vars = job->SkipGcVars();
if (FLAGS_enable_new_ir_in_executor) { if (FLAGS_enable_new_ir_in_executor && platform::is_cpu_place(place)) {
VLOG(6) << "begin to translate" << std::endl; VLOG(6) << "begin to translate" << std::endl;
auto base_program = paddle::TranslateLegacyProgramToProgram(*program); auto base_program = paddle::TranslateLegacyProgramToProgram(*program);
auto kernel_program = auto kernel_program =
......
...@@ -48,8 +48,7 @@ void BuildScope(ir::Block* block, ...@@ -48,8 +48,7 @@ void BuildScope(ir::Block* block,
std::unordered_map<ir::Value, std::string>* name_map) { std::unordered_map<ir::Value, std::string>* name_map) {
std::unordered_map<ir::Value, int> map_test; std::unordered_map<ir::Value, int> map_test;
// int count = name_map->size(); int count = name_map->size();
int count = 0;
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
size_t input_num = (*it)->num_operands(); size_t input_num = (*it)->num_operands();
auto attr_map = (*it)->attributes(); auto attr_map = (*it)->attributes();
...@@ -69,6 +68,35 @@ void BuildScope(ir::Block* block, ...@@ -69,6 +68,35 @@ void BuildScope(ir::Block* block,
continue; continue;
} }
if (op_name == "builtin.set_parameter") {
auto param_name = (*it)
->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto in_ptr = (*it)->operand(0);
// change opreand name to param_name
auto orig_name = name_map->at(in_ptr);
(*name_map)[in_ptr] = param_name;
scope->Rename(orig_name, param_name);
continue;
}
if (op_name == "builtin.get_parameter") {
auto param_name = (*it)
->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto out_ptr = (*it)->result(0);
name_map->emplace(out_ptr, param_name);
continue;
}
if (op_name == "pd.feed") { if (op_name == "pd.feed") {
auto ptr = (*it)->result(0); auto ptr = (*it)->result(0);
std::string name = "inner_var_" + std::to_string(count++); std::string name = "inner_var_" + std::to_string(count++);
...@@ -123,14 +151,14 @@ void BuildScope(ir::Block* block, ...@@ -123,14 +151,14 @@ void BuildScope(ir::Block* block,
if (input_num > 0) { if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) { for (size_t i = 0; i < input_num; ++i) {
auto ptr = (*it)->operand(i); auto ptr = (*it)->operand(i);
std::string name; if (ptr) {
if (name_map->find(ptr) != name_map->end()) { PADDLE_ENFORCE_NE(
name = name_map->at(ptr); name_map->find(ptr),
} else { name_map->end(),
PADDLE_THROW(phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op", "input should in name map, [%d] 'th input of [%s] op",
i, i,
op_name)); op_name));
} }
} }
} }
...@@ -149,7 +177,6 @@ void BuildScope(ir::Block* block, ...@@ -149,7 +177,6 @@ void BuildScope(ir::Block* block,
} }
auto var = scope->Var(name); auto var = scope->Var(name);
// Only support DenseTensor or Vector<DenseTensor> // Only support DenseTensor or Vector<DenseTensor>
if (!ptr.type()) { if (!ptr.type()) {
var->GetMutable<phi::DenseTensor>(); var->GetMutable<phi::DenseTensor>();
} else if (ptr.type() } else if (ptr.type()
......
...@@ -71,6 +71,12 @@ void BuildPhiContext( ...@@ -71,6 +71,12 @@ void BuildPhiContext(
phi::errors::NotFound("param [%s] MUST in name2id map", t)); phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.Name2Id().at(t); auto index = op_yaml_info.Name2Id().at(t);
ir::Value ptr = op->operand(index); ir::Value ptr = op->operand(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
OutType in_ptr(ptr);
ctx->EmplaceBackInput(in_ptr);
continue;
}
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
...@@ -142,10 +148,14 @@ void BuildPhiContext( ...@@ -142,10 +148,14 @@ void BuildPhiContext(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (attr_type_name == "ir::Int32Attribute") { } else if (attr_type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (attr_type_name == "ir::Int64Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int64Attribute>().data());
} else if (attr_type_name == "ir::FloatAttribute") { } else if (attr_type_name == "ir::FloatAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::FloatAttribute>().data());
} else if (attr_type_name == "ir::BoolAttribute") { } else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::StrAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") { } else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<int32_t> vec_res; std::vector<int32_t> vec_res;
...@@ -160,6 +170,44 @@ void BuildPhiContext( ...@@ -160,6 +170,44 @@ void BuildPhiContext(
array_list[i].dyn_cast<ir::Int32Attribute>().data()); array_list[i].dyn_cast<ir::Int32Attribute>().data());
} }
} }
} else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<float> vec_res;
if (array_list.size() > 0) {
if (array_list[0].isa<ir::FloatAttribute>()) {
for (size_t i = 0; i < array_list.size(); ++i) {
vec_res.push_back(
array_list[i].dyn_cast<ir::FloatAttribute>().data());
}
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
attr_type_name));
}
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
std::cerr << "int64 array" << std::endl;
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::cerr << "len " << array_list.size() << std::endl;
std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
array_list[0].isa<ir::Int64Attribute>(),
true,
phi::errors::PreconditionNotMet(
"Element in array list MUST be ir::Int64Attribute "));
std::cerr << "int 64" << std::endl;
for (size_t i = 0; i < array_list.size(); ++i) {
std::cerr << "i " << i << "\t"
<< array_list[i].dyn_cast<ir::Int64Attribute>().data()
<< std::endl;
vec_res.push_back(
array_list[i].dyn_cast<ir::Int64Attribute>().data());
}
}
ctx->EmplaceBackAttr(vec_res); ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "paddle::dialect::PlaceAttribute") { } else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
......
...@@ -132,6 +132,9 @@ phi::KernelKey GetKernelKey( ...@@ -132,6 +132,9 @@ phi::KernelKey GetKernelKey(
} }
auto input_tmp = op->operand(i); auto input_tmp = op->operand(i);
if (!input_tmp) {
continue;
}
auto new_input_tmp = map_value_pair.at(input_tmp); auto new_input_tmp = map_value_pair.at(input_tmp);
auto input_type = new_input_tmp.type(); auto input_type = new_input_tmp.type();
...@@ -264,6 +267,15 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { ...@@ -264,6 +267,15 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
if ((*it)->num_operands() > 0) { if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) { for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i); auto cur_in = (*it)->operand(i);
if (!cur_in) {
vec_inputs.push_back(ir::OpResult());
continue;
}
PADDLE_ENFORCE_EQ(
map_value_pair.count(cur_in),
true,
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair", i, (*it)->name()));
auto new_in = map_value_pair.at(cur_in); auto new_in = map_value_pair.at(cur_in);
auto new_in_type = new_in.type(); auto new_in_type = new_in.type();
......
...@@ -132,10 +132,6 @@ ...@@ -132,10 +132,6 @@
{ axis : dim, keepdim : keep_dim} { axis : dim, keepdim : keep_dim}
outputs: outputs:
out : Out out : Out
int_array:
axis :
data_type : int
support_tensor : true
manual_signature : [all] manual_signature : [all]
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
...@@ -163,10 +159,6 @@ ...@@ -163,10 +159,6 @@
{ axis : dim, keepdim : keep_dim } { axis : dim, keepdim : keep_dim }
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
int_array:
axis :
data_type : int
support_tensor : true
get_expected_kernel_type : get_expected_kernel_type :
amax_grad : GetReduceGradExpectedKernelType amax_grad : GetReduceGradExpectedKernelType
manual_signature : [amax] manual_signature : [amax]
...@@ -181,10 +173,6 @@ ...@@ -181,10 +173,6 @@
{ axis : dim, keepdim : keep_dim } { axis : dim, keepdim : keep_dim }
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
int_array:
axis :
data_type : int
support_tensor : true
get_expected_kernel_type : get_expected_kernel_type :
amin_grad : GetReduceGradExpectedKernelType amin_grad : GetReduceGradExpectedKernelType
manual_signature : [amin] manual_signature : [amin]
...@@ -207,10 +195,6 @@ ...@@ -207,10 +195,6 @@
{ axis : dim, keepdim : keep_dim } { axis : dim, keepdim : keep_dim }
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
int_array:
axis :
data_type : int
support_tensor : true
get_expected_kernel_type : get_expected_kernel_type :
any : GetReduceOpUseInputPlaceExpectedKernelType any : GetReduceOpUseInputPlaceExpectedKernelType
manual_signature : [any] manual_signature : [any]
......
...@@ -1314,6 +1314,16 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS}) ...@@ -1314,6 +1314,16 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
FLAGS_new_executor_static_build=true) FLAGS_new_executor_static_build=true)
endforeach() endforeach()
set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2)
foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS})
py_test_modules(
${NEW_IR_COVERAGE_TEST}_new_ir MODULES ${NEW_IR_COVERAGE_TEST} ENVS
FLAGS_enable_new_ir_in_executor=true)
endforeach()
set_tests_properties(test_instance_norm_op_v2_new_ir PROPERTIES TIMEOUT 120)
set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT
120) 120)
set_tests_properties(test_fuse_bn_act_pass_static_build PROPERTIES TIMEOUT 120) set_tests_properties(test_fuse_bn_act_pass_static_build PROPERTIES TIMEOUT 120)
......
...@@ -380,4 +380,5 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -380,4 +380,5 @@ class TestPrimForwardAndBackward(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -153,4 +153,5 @@ class TestNumelAPI(unittest.TestCase): ...@@ -153,4 +153,5 @@ class TestNumelAPI(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -378,4 +378,5 @@ class TestRandomValue(unittest.TestCase): ...@@ -378,4 +378,5 @@ class TestRandomValue(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
...@@ -475,4 +475,5 @@ class TestYolov3LossStatic(unittest.TestCase): ...@@ -475,4 +475,5 @@ class TestYolov3LossStatic(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册