提交 ef5c2be4 编写于 作者: R rensilin

update_test_create_pregrams

Change-Id: Ibd81678790415b2b756a8a1b8cce04369e6e6947
上级 31555637
...@@ -101,11 +101,11 @@ application_args = [ ...@@ -101,11 +101,11 @@ application_args = [
Libs(module='baidu/third-party/openmpi', libs=['libmpi.so', 'libmpi_cxx.so', 'libopen-pal.so', 'libopen-rte.so']), Libs(module='baidu/third-party/openmpi', libs=['libmpi.so', 'libmpi_cxx.so', 'libopen-pal.so', 'libopen-rte.so']),
] ]
StaticLibrary("feed_trainer", Sources(custom_trainer_src), application_args) StaticLibrary("feed_trainer", Sources(custom_trainer_src), *application_args)
Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main.cc'), WholeArchives("$OUT/lib/libfeed_trainer.a"), application_args) Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main.cc'), WholeArchives("$OUT/lib/libfeed_trainer.a"), *application_args)
#feed unit test #feed unit test
# bug: shared librarys can not be found when run on server # bug: shared librarys can not be found when run on server
UTApplication('unit_test', UTOnServer(False), Sources(UT_FILE('main.cc'), GLOB(UT_FILE('test_*.cc'))), WholeArchives("$OUT/lib/libfeed_trainer.a"), application_args) UTApplication('unit_test', UTOnServer(False), Sources(UT_FILE('main.cc'), GLOB(UT_FILE('test_*.cc'))), WholeArchives("$OUT/lib/libfeed_trainer.a"), *application_args)
...@@ -131,7 +131,6 @@ public: ...@@ -131,7 +131,6 @@ public:
return read_all(file_list, data_channel); return read_all(file_list, data_channel);
} }
virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) { virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
DataItem data_item;
const int file_list_size = file_list.size(); const int file_list_size = file_list.size();
std::atomic<bool> is_failed(false); std::atomic<bool> is_failed(false);
......
...@@ -77,21 +77,35 @@ TEST_F(CreateProgramsTest, example_network) { ...@@ -77,21 +77,35 @@ TEST_F(CreateProgramsTest, example_network) {
std::string input_name = "cvm_input"; std::string input_name = "cvm_input";
ASSERT_TRUE(model_desc["inputs"]); ASSERT_TRUE(model_desc["inputs"]);
ASSERT_TRUE(model_desc["inputs"][input_name]); ASSERT_EQ(1, model_desc["inputs"].size());
ASSERT_TRUE(model_desc["loss_name"]); ASSERT_TRUE(model_desc["inputs"][0]["name"]);
ASSERT_TRUE(model_desc["label_name"]); ASSERT_TRUE(model_desc["inputs"][0]["shape"]);
ASSERT_TRUE(model_desc["ctr_output_name"]); ASSERT_EQ(input_name, model_desc["inputs"][0]["name"].as<std::string>());
auto loss_name = model_desc["loss_name"].as<std::string>(); std::vector<int> input_shape = model_desc["inputs"][0]["shape"].as<std::vector<int>>(std::vector<int>());
auto label_name = model_desc["label_name"].as<std::string>();
auto ctr_output_name = model_desc["ctr_output_name"].as<std::string>();
std::vector<int> input_shape = model_desc["inputs"][input_name].as<std::vector<int>>(std::vector<int>());
ASSERT_EQ(2, input_shape.size()); ASSERT_EQ(2, input_shape.size());
ASSERT_EQ(-1, input_shape[0]); ASSERT_EQ(-1, input_shape[0]);
ASSERT_EQ(4488, input_shape[1]); ASSERT_EQ(4488, input_shape[1]);
ASSERT_TRUE(model_desc["loss_all"]);
auto loss_all_name = model_desc["loss_all"].as<std::string>();
ASSERT_TRUE(model_desc["outputs"]);
ASSERT_EQ(1, model_desc["outputs"].size());
ASSERT_TRUE(model_desc["outputs"][0]["name"]);
ASSERT_TRUE(model_desc["outputs"][0]["shape"]);
ASSERT_TRUE(model_desc["outputs"][0]["label_name"]);
ASSERT_TRUE(model_desc["outputs"][0]["loss_name"]);
auto ctr_output_label_name = model_desc["outputs"][0]["label_name"].as<std::string>();
auto ctr_output_loss_name = model_desc["outputs"][0]["loss_name"].as<std::string>();
auto ctr_output_name = model_desc["outputs"][0]["name"].as<std::string>();
std::vector<int> output_shape = model_desc["outputs"][0]["shape"].as<std::vector<int>>(std::vector<int>());
ASSERT_EQ(2, output_shape.size());
ASSERT_EQ(-1, output_shape[0]);
ASSERT_EQ(1, output_shape[1]);
auto input_var = executor->mutable_var<::paddle::framework::LoDTensor>(input_name); auto input_var = executor->mutable_var<::paddle::framework::LoDTensor>(input_name);
auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(label_name); auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(ctr_output_label_name);
ASSERT_NE(nullptr, input_var); ASSERT_NE(nullptr, input_var);
ASSERT_NE(nullptr, label_var); ASSERT_NE(nullptr, label_var);
...@@ -109,14 +123,18 @@ TEST_F(CreateProgramsTest, example_network) { ...@@ -109,14 +123,18 @@ TEST_F(CreateProgramsTest, example_network) {
ASSERT_EQ(0, executor->run()); ASSERT_EQ(0, executor->run());
auto loss_var = executor->var<::paddle::framework::LoDTensor>(loss_name); auto loss_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_loss_name);
auto loss = loss_var.data<float>()[0]; auto loss = loss_var.data<float>()[0];
auto loss_all_var = executor->var<::paddle::framework::LoDTensor>(loss_all_name);
auto loss_all = loss_all_var.data<float>()[0];
auto ctr_output_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_name); auto ctr_output_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_name);
auto ctr_output = ctr_output_var.data<float>()[0]; auto ctr_output = ctr_output_var.data<float>()[0];
std::cout << "loss: " << loss << std::endl; std::cout << "loss: " << loss << std::endl;
std::cout << "ctr_output: " << ctr_output << std::endl; std::cout << "ctr_output: " << ctr_output << std::endl;
ASSERT_NEAR(loss, loss_all, 1e-9);
} }
} // namespace feed } // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册