diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 87bb28c0c55efcf709e31ed3d43144bb528ce0f8..51de4c9dfb583b8e8a7fe8d92d5be07d32609c67 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -70,7 +70,7 @@ inline std::string GradVarName(const std::string& var_name) { } inline std::string OriginVarName(const std::string& grad_var_name) { - std::size_t pos = grad_var_name.find_last_of(kGradVarSuffix); + std::size_t pos = grad_var_name.rfind(kGradVarSuffix); if (pos == std::string::npos) { return grad_var_name; } else { diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 1623dfca6f2e3575afc70f4feabb63d514f9c518..3bbbda6424ca110eabe31f232207c7562d4c8fc5 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -289,11 +289,29 @@ TEST(OpKernel, multi_inputs) { op->Run(scope, cpu_place); } -TEST(Functions, all) { +TEST(VarNameTest, all) { std::string var_name("X"); std::string grad_var_name = paddle::framework::GradVarName(var_name); - ASSERT_EQ(grad_var_name.c_str(), "X@GRAD"); + ASSERT_EQ(grad_var_name, "X@GRAD"); std::string original_var_name = paddle::framework::OriginVarName(grad_var_name); - ASSERT_EQ(original_var_name.c_str(), "X"); + ASSERT_EQ(original_var_name, "X"); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, "X"); + + std::string var_name_2("XYZ"); + grad_var_name = paddle::framework::GradVarName(var_name_2); + ASSERT_EQ(grad_var_name, "XYZ@GRAD"); + original_var_name = paddle::framework::OriginVarName(grad_var_name); + ASSERT_EQ(original_var_name, "XYZ"); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, "XYZ"); + + std::string var_name_3(""); + grad_var_name = paddle::framework::GradVarName(var_name_3); + ASSERT_EQ(grad_var_name, "@GRAD"); + original_var_name = paddle::framework::OriginVarName(grad_var_name); + ASSERT_EQ(original_var_name, ""); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, ""); }