未验证 提交 2194e4c1 编写于 作者: H hong 提交者: GitHub

[NewIR]fix new ir edit distance bug (#55294)

* fix edit distance bug

* add op define kernel data type

* fix bug

* update

* add header

* add op test to cmake
上级 6f7ceca0
......@@ -165,6 +165,7 @@ void HandleForSpecialOp(ir::Operation* op,
auto feed_list = feed_var->Get<paddle::framework::FeedList>();
auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
out_tensor->ShareDataWith(in_tensor);
out_tensor->set_lod(in_tensor.lod());
}
if (op_name == "builtin.combine") {
......
......@@ -32,12 +32,23 @@
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace dialect {
const int init_on_gpu_threashold = 1000;
std::unordered_map<std::string, phi::DataType> Str2PhiDataType = {
{"DataType::FLOAT16", phi::DataType::FLOAT16},
{"DataType::BFLOAT16", phi::DataType::BFLOAT16},
{"DataType::FLOAT32", phi::DataType::FLOAT32},
{"DataType::FLOAT64", phi::DataType::FLOAT64},
{"DataType::INT16", phi::DataType::INT16},
{"DataType::INT32", phi::DataType::INT32},
{"DataType::INT64", phi::DataType::INT64},
{"DataType::INT8", phi::DataType::INT8},
{"DataType::BOOL", phi::DataType::BOOL},
};
phi::KernelKey GetKernelKey(
ir::Operation* op,
const phi::Place& place,
......@@ -67,7 +78,10 @@ phi::KernelKey GetKernelKey(
auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id();
if (input_map.count(slot_name)) {
auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
......
......@@ -1316,7 +1316,8 @@ foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
FLAGS_new_executor_static_build=true)
endforeach()
set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2)
set(NEW_IR_COVERAGE_TESTS test_label_smooth_op test_instance_norm_op_v2
test_edit_distance_op)
foreach(NEW_IR_COVERAGE_TEST ${NEW_IR_COVERAGE_TESTS})
py_test_modules(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册