未验证 提交 2194e4c1 编写于 作者: H hong 提交者: GitHub

[NewIR]fix new ir edit distance bug (#55294)

* fix edit distance bug

* add op define kernel data type

* fix bug

* update

* add header

* add op test to cmake
上级 6f7ceca0
...@@ -165,6 +165,7 @@ void HandleForSpecialOp(ir::Operation* op, ...@@ -165,6 +165,7 @@ void HandleForSpecialOp(ir::Operation* op,
auto feed_list = feed_var->Get<paddle::framework::FeedList>(); auto feed_list = feed_var->Get<paddle::framework::FeedList>();
auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index))); auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
out_tensor->ShareDataWith(in_tensor); out_tensor->ShareDataWith(in_tensor);
out_tensor->set_lod(in_tensor.lod());
} }
if (op_name == "builtin.combine") { if (op_name == "builtin.combine") {
......
...@@ -32,12 +32,23 @@ ...@@ -32,12 +32,23 @@
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
const int init_on_gpu_threashold = 1000; const int init_on_gpu_threashold = 1000;
std::unordered_map<std::string, phi::DataType> Str2PhiDataType = {
{"DataType::FLOAT16", phi::DataType::FLOAT16},
{"DataType::BFLOAT16", phi::DataType::BFLOAT16},
{"DataType::FLOAT32", phi::DataType::FLOAT32},
{"DataType::FLOAT64", phi::DataType::FLOAT64},
{"DataType::INT16", phi::DataType::INT16},
{"DataType::INT32", phi::DataType::INT32},
{"DataType::INT64", phi::DataType::INT64},
{"DataType::INT8", phi::DataType::INT8},
{"DataType::BOOL", phi::DataType::BOOL},
};
phi::KernelKey GetKernelKey( phi::KernelKey GetKernelKey(
ir::Operation* op, ir::Operation* op,
const phi::Place& place, const phi::Place& place,
...@@ -67,7 +78,10 @@ phi::KernelKey GetKernelKey( ...@@ -67,7 +78,10 @@ phi::KernelKey GetKernelKey(
auto slot_name = data_type_info[0]; auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id(); auto& input_map = op_info_parser->InputName2Id();
if (input_map.count(slot_name)) { auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
} else if (input_map.count(slot_name)) {
// parse from input // parse from input
int in_index = input_map.at(slot_name); int in_index = input_map.at(slot_name);
......
...@@ -1316,7 +1316,8 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS}) ...@@ -1316,7 +1316,8 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
FLAGS_new_executor_static_build=true) FLAGS_new_executor_static_build=true)
endforeach() endforeach()
set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2) set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2
test_edit_distance_op)
foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS}) foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS})
py_test_modules( py_test_modules(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册