提交 2f60e4a7 编写于 作者: R rensilin

fix

Change-Id: I7f5b40e412b045082b7dad1cbc1e346a9e0ca14d
上级 4fc0ae30
...@@ -25,16 +25,11 @@ namespace paddle { ...@@ -25,16 +25,11 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
<<<<<<< HEAD
TEST(testSimpleExecutor, initialize) {
SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>();
=======
const char test_data_dir[] = "test_data"; const char test_data_dir[] = "test_data";
const char main_program_path[] = "test_data/main_program"; const char main_program_path[] = "test_data/main_program";
const char startup_program_path[] = "test_data/startup_program"; const char startup_program_path[] = "test_data/startup_program";
class SimpleExecuteTest : public testing::Test class SimpleExecutorTest : public testing::Test
{ {
public: public:
static void SetUpTestCase() static void SetUpTestCase()
...@@ -85,52 +80,23 @@ public: ...@@ -85,52 +80,23 @@ public:
std::shared_ptr<TrainerContext> context_ptr; std::shared_ptr<TrainerContext> context_ptr;
}; };
TEST_F(SimpleExecuteTest, initialize) { TEST_F(SimpleExecutorTest, initialize) {
SimpleExecute execute; SimpleExecutor executor;
>>>>>>> add_executor_ut
YAML::Node config = YAML::Load("[1, 2, 3]"); YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, execute.initialize(config, context_ptr)); ASSERT_NE(0, executor.initialize(config, context_ptr));
config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, executor.initialize(config, context_ptr));
config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, executor.initialize(config, context_ptr));
} }
<<<<<<< HEAD TEST_F(SimpleExecutorTest, run) {
float uniform(float min, float max) { SimpleExecutor executor;
float result = (float)rand() / RAND_MAX;
return min + result * (max - min);
}
void next_batch(int batch_size, const paddle::platform::Place& place, paddle::framework::LoDTensor* x_tensor, paddle::framework::LoDTensor* y_tensor) {
x_tensor->Resize({batch_size, 2});
auto x_data = x_tensor->mutable_data<float>(place);
y_tensor->Resize({batch_size, 1});
auto y_data = y_tensor->mutable_data<float>(place);
for (int i = 0; i < batch_size; ++i) {
x_data[i * 2] = uniform(-2, 2);
x_data[i * 2 + 1] = uniform(-2, 2);
float dis = x_data[i * 2] * x_data[i * 2] + x_data[i * 2 + 1] * x_data[i * 2 + 1];
y_data[i] = dis < 1.0 ? 1.0 : 0.0;
}
}
TEST(testSimpleExecutor, run) {
SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
=======
TEST_F(SimpleExecuteTest, run) {
SimpleExecute execute;
auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
>>>>>>> add_executor_ut ASSERT_EQ(0, executor.initialize(config, context_ptr));
ASSERT_EQ(0, execute.initialize(config, context_ptr));
auto x_var = execute.mutable_var<::paddle::framework::LoDTensor>("x"); auto x_var = executor.mutable_var<::paddle::framework::LoDTensor>("x");
execute.mutable_var<::paddle::framework::LoDTensor>("mean"); executor.mutable_var<::paddle::framework::LoDTensor>("mean");
ASSERT_NE(nullptr, x_var); ASSERT_NE(nullptr, x_var);
int x_len = 10; int x_len = 10;
...@@ -143,9 +109,9 @@ TEST_F(SimpleExecuteTest, run) { ...@@ -143,9 +109,9 @@ TEST_F(SimpleExecuteTest, run) {
} }
std::cout << std::endl; std::cout << std::endl;
ASSERT_EQ(0, execute.run()); ASSERT_EQ(0, executor.run());
auto mean_var = execute.var<::paddle::framework::LoDTensor>("mean"); auto mean_var = executor.var<::paddle::framework::LoDTensor>("mean");
auto mean = mean_var.data<float>()[0]; auto mean = mean_var.data<float>()[0];
std::cout << "mean: " << mean << std::endl; std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9); ASSERT_NEAR(4.5, mean, 1e-9);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册