未验证 提交 28959ee2 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix new IR warprnn op bug (#55161)

* add ir output check in OpTest

* add ir grad check in op test

* fix legacy name converter bug

* add more unittest

* fix

* fix warprnn op bug

* add whit list

* polish code

* polish code

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
上级 2484545e
...@@ -28,6 +28,9 @@ namespace translator { ...@@ -28,6 +28,9 @@ namespace translator {
using MutableAttributeInfo = std::vector<std::string>; using MutableAttributeInfo = std::vector<std::string>;
static constexpr char kPhiGradSuffix[] = "_grad";
static constexpr char kFluidVarGradSuffix[] = "@GRAD";
class OpNameNormalizer { class OpNameNormalizer {
private: private:
OpNameNormalizer(); // Disallow instantiation outside of the class. OpNameNormalizer(); // Disallow instantiation outside of the class.
...@@ -76,11 +79,11 @@ class OpNameNormalizer { ...@@ -76,11 +79,11 @@ class OpNameNormalizer {
std::string GetLegacyArgName(const std::string& op_type, std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) { const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos); bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
bool is_grad_arg = (arg_name.find("grad") != std::string::npos); bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
if (is_grad_op && is_grad_arg) { if (is_grad_op && is_grad_arg) {
std::string target = "_grad"; std::string target = kPhiGradSuffix;
std::string data = "@GRAD"; std::string data = kFluidVarGradSuffix;
size_t first_grad_pos = arg_name.find(target); size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target); size_t type_pos = op_type.find(target);
...@@ -95,9 +98,7 @@ class OpNameNormalizer { ...@@ -95,9 +98,7 @@ class OpNameNormalizer {
return legacy_name; return legacy_name;
} else if (is_grad_op && !is_grad_arg) { } else if (is_grad_op && !is_grad_arg) {
// backwward op using forward args: like trace_grad using forward input // backwward op using forward args: like trace_grad using forward input
std::string target = "_grad"; size_t type_pos = op_type.find(kPhiGradSuffix);
size_t type_pos = op_type.find(target);
std::string legacy_name = std::string legacy_name =
this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name); this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);
......
...@@ -2835,6 +2835,12 @@ ...@@ -2835,6 +2835,12 @@
outputs : outputs :
{warpctcgrad : WarpCTCGrad, loss : Loss} {warpctcgrad : WarpCTCGrad, loss : Loss}
- op : warprnnt
inputs :
{input : input, label : label, input_lengths : input_lengths, label_lengths : label_lengths}
outputs :
{loss : loss, warprnntgrad : warprnntgrad}
- op : where - op : where
backward : where_grad backward : where_grad
inputs : inputs :
......
...@@ -57,3 +57,4 @@ test_triangular_solve_op ...@@ -57,3 +57,4 @@ test_triangular_solve_op
test_trunc_op test_trunc_op
test_unfold_op test_unfold_op
test_unpool3d_op test_unpool3d_op
test_warprnnt_op
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册