未验证 提交 28959ee2 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix new IR warprnn op bug (#55161)

* add ir output check in OpTest

* add ir grad check in op test

* fix legacy name converter bug

* add more unittest

* fix

* fix warprnn op bug

* add whit list

* polish code

* polish code

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
上级 2484545e
......@@ -28,6 +28,9 @@ namespace translator {
using MutableAttributeInfo = std::vector<std::string>;
static constexpr char kPhiGradSuffix[] = "_grad";
static constexpr char kFluidVarGradSuffix[] = "@GRAD";
class OpNameNormalizer {
private:
OpNameNormalizer(); // Disallow instantiation outside of the class.
......@@ -76,11 +79,11 @@ class OpNameNormalizer {
std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos);
bool is_grad_arg = (arg_name.find("grad") != std::string::npos);
bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
if (is_grad_op && is_grad_arg) {
std::string target = "_grad";
std::string data = "@GRAD";
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;
size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
......@@ -95,9 +98,7 @@ class OpNameNormalizer {
return legacy_name;
} else if (is_grad_op && !is_grad_arg) {
// backwward op using forward args: like trace_grad using forward input
std::string target = "_grad";
size_t type_pos = op_type.find(target);
size_t type_pos = op_type.find(kPhiGradSuffix);
std::string legacy_name =
this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);
......
......@@ -2835,6 +2835,12 @@
outputs :
{warpctcgrad : WarpCTCGrad, loss : Loss}
- op : warprnnt
inputs :
{input : input, label : label, input_lengths : input_lengths, label_lengths : label_lengths}
outputs :
{loss : loss, warprnntgrad : warprnntgrad}
- op : where
backward : where_grad
inputs :
......
......@@ -57,3 +57,4 @@ test_triangular_solve_op
test_trunc_op
test_unfold_op
test_unpool3d_op
test_warprnnt_op
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册