diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h
index b07e31c8a13eed56ef765ce7c84464b21a6e3eb1..1bb2f4d8131b4af00c27c3e8b74f2b7ea4c8bdc3 100644
--- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h
+++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h
@@ -28,6 +28,9 @@ namespace translator {
 
 using MutableAttributeInfo = std::vector;
 
+static constexpr char kPhiGradSuffix[] = "_grad";
+static constexpr char kFluidVarGradSuffix[] = "@GRAD";
+
 class OpNameNormalizer {
  private:
   OpNameNormalizer();  // Disallow instantiation outside of the class.
@@ -76,11 +79,11 @@ class OpNameNormalizer {
 
   std::string GetLegacyArgName(const std::string& op_type,
                                const std::string& arg_name) {
-    bool is_grad_op = (op_type.find("grad") != std::string::npos);
-    bool is_grad_arg = (arg_name.find("grad") != std::string::npos);
+    bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
+    bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
     if (is_grad_op && is_grad_arg) {
-      std::string target = "_grad";
-      std::string data = "@GRAD";
+      std::string target = kPhiGradSuffix;
+      std::string data = kFluidVarGradSuffix;
 
       size_t first_grad_pos = arg_name.find(target);
       size_t type_pos = op_type.find(target);
@@ -95,9 +98,7 @@ class OpNameNormalizer {
       return legacy_name;
     } else if (is_grad_op && !is_grad_arg) {
       // backwward op using forward args: like trace_grad using forward input
-      std::string target = "_grad";
-
-      size_t type_pos = op_type.find(target);
+      size_t type_pos = op_type.find(kPhiGradSuffix);
       std::string legacy_name =
           this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);
 
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 8e8cca4296e7d05e79f4837886eb63fde45b0e78..109fc6946ce2c61a18311a307ebad0fe08eccccf 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2835,6 +2835,12 @@
   outputs :
     {warpctcgrad : WarpCTCGrad, loss : Loss}
 
+- op : warprnnt
+  inputs :
+    {input : input, label : label, input_lengths : input_lengths, label_lengths : label_lengths}
+  outputs :
+    {loss : loss, warprnntgrad : warprnntgrad}
+
 - op : where
   backward : where_grad
   inputs :
diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list
index 0c61420afc90f1580c8817b5e732a99959d45579..4504eaae470616dd7dcbb7906184d65149cb8e50 100644
--- a/test/white_list/new_ir_op_test_white_list
+++ b/test/white_list/new_ir_op_test_white_list
@@ -57,3 +57,4 @@ test_triangular_solve_op
 test_trunc_op
 test_unfold_op
 test_unpool3d_op
+test_warprnnt_op
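
Note (not part of the patch): the constants introduced above name the two grad-suffix conventions that GetLegacyArgName translates between, i.e. the Phi-style "_grad" suffix on a grad op's argument versus the legacy Fluid "@GRAD" variable suffix. The snippet below is a minimal standalone sketch of that suffix rewrite, assuming a plain find/replace over the argument name; ToLegacyGradName and the sample names are hypothetical and are not the Paddle API, which additionally consults the per-op argument-name mappings.

// Minimal standalone sketch of the "_grad" -> "@GRAD" suffix rewrite.
// This is illustrative only, not the Paddle implementation.
#include <iostream>
#include <string>

static constexpr char kPhiGradSuffix[] = "_grad";
static constexpr char kFluidVarGradSuffix[] = "@GRAD";

// Hypothetical helper: replaces every "_grad" occurrence in a Phi-style
// argument name with the legacy Fluid "@GRAD" suffix.
std::string ToLegacyGradName(const std::string& arg_name) {
  const std::string target = kPhiGradSuffix;
  const std::string data = kFluidVarGradSuffix;
  std::string legacy_name = arg_name;
  for (size_t pos = 0;
       (pos = legacy_name.find(target, pos)) != std::string::npos;
       pos += data.length()) {
    legacy_name.replace(pos, target.length(), data);
  }
  return legacy_name;
}

int main() {
  std::cout << ToLegacyGradName("out_grad") << "\n";  // prints "out@GRAD"
  std::cout << ToLegacyGradName("x_grad") << "\n";    // prints "x@GRAD"
  return 0;
}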