From 858e90323101236027aaec8e5685e97b0fb4f201 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 28 Dec 2018 16:11:18 +0800 Subject: [PATCH] Add unittest for operator test=develop --- paddle/fluid/framework/operator.h | 2 +- paddle/fluid/framework/operator_test.cc | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 87bb28c0c55..51de4c9dfb5 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -70,7 +70,7 @@ inline std::string GradVarName(const std::string& var_name) { } inline std::string OriginVarName(const std::string& grad_var_name) { - std::size_t pos = grad_var_name.find_last_of(kGradVarSuffix); + std::size_t pos = grad_var_name.rfind(kGradVarSuffix); if (pos == std::string::npos) { return grad_var_name; } else { diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 1623dfca6f2..3bbbda6424c 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -289,11 +289,29 @@ TEST(OpKernel, multi_inputs) { op->Run(scope, cpu_place); } -TEST(Functions, all) { +TEST(VarNameTest, all) { std::string var_name("X"); std::string grad_var_name = paddle::framework::GradVarName(var_name); - ASSERT_EQ(grad_var_name.c_str(), "X@GRAD"); + ASSERT_EQ(grad_var_name, "X@GRAD"); std::string original_var_name = paddle::framework::OriginVarName(grad_var_name); - ASSERT_EQ(original_var_name.c_str(), "X"); + ASSERT_EQ(original_var_name, "X"); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, "X"); + + std::string var_name_2("XYZ"); + grad_var_name = paddle::framework::GradVarName(var_name_2); + ASSERT_EQ(grad_var_name, "XYZ@GRAD"); + original_var_name = paddle::framework::OriginVarName(grad_var_name); + ASSERT_EQ(original_var_name, "XYZ"); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, "XYZ"); + + std::string var_name_3(""); + grad_var_name = paddle::framework::GradVarName(var_name_3); + ASSERT_EQ(grad_var_name, "@GRAD"); + original_var_name = paddle::framework::OriginVarName(grad_var_name); + ASSERT_EQ(original_var_name, ""); + original_var_name = paddle::framework::OriginVarName(original_var_name); + ASSERT_EQ(original_var_name, ""); } -- GitLab