提交 ef5c2be4 编写于 作者: R rensilin

update_test_create_pregrams

Change-Id: Ibd81678790415b2b756a8a1b8cce04369e6e6947
上级 31555637
......@@ -101,11 +101,11 @@ application_args = [
Libs(module='baidu/third-party/openmpi', libs=['libmpi.so', 'libmpi_cxx.so', 'libopen-pal.so', 'libopen-rte.so']),
]
StaticLibrary("feed_trainer", Sources(custom_trainer_src), application_args)
StaticLibrary("feed_trainer", Sources(custom_trainer_src), *application_args)
Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main.cc'), WholeArchives("$OUT/lib/libfeed_trainer.a"), application_args)
Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main.cc'), WholeArchives("$OUT/lib/libfeed_trainer.a"), *application_args)
#feed unit test
# bug: shared librarys can not be found when run on server
UTApplication('unit_test', UTOnServer(False), Sources(UT_FILE('main.cc'), GLOB(UT_FILE('test_*.cc'))), WholeArchives("$OUT/lib/libfeed_trainer.a"), application_args)
UTApplication('unit_test', UTOnServer(False), Sources(UT_FILE('main.cc'), GLOB(UT_FILE('test_*.cc'))), WholeArchives("$OUT/lib/libfeed_trainer.a"), *application_args)
......@@ -131,7 +131,6 @@ public:
return read_all(file_list, data_channel);
}
virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
DataItem data_item;
const int file_list_size = file_list.size();
std::atomic<bool> is_failed(false);
......
......@@ -77,21 +77,35 @@ TEST_F(CreateProgramsTest, example_network) {
std::string input_name = "cvm_input";
ASSERT_TRUE(model_desc["inputs"]);
ASSERT_TRUE(model_desc["inputs"][input_name]);
ASSERT_TRUE(model_desc["loss_name"]);
ASSERT_TRUE(model_desc["label_name"]);
ASSERT_TRUE(model_desc["ctr_output_name"]);
auto loss_name = model_desc["loss_name"].as<std::string>();
auto label_name = model_desc["label_name"].as<std::string>();
auto ctr_output_name = model_desc["ctr_output_name"].as<std::string>();
std::vector<int> input_shape = model_desc["inputs"][input_name].as<std::vector<int>>(std::vector<int>());
ASSERT_EQ(1, model_desc["inputs"].size());
ASSERT_TRUE(model_desc["inputs"][0]["name"]);
ASSERT_TRUE(model_desc["inputs"][0]["shape"]);
ASSERT_EQ(input_name, model_desc["inputs"][0]["name"].as<std::string>());
std::vector<int> input_shape = model_desc["inputs"][0]["shape"].as<std::vector<int>>(std::vector<int>());
ASSERT_EQ(2, input_shape.size());
ASSERT_EQ(-1, input_shape[0]);
ASSERT_EQ(4488, input_shape[1]);
ASSERT_TRUE(model_desc["loss_all"]);
auto loss_all_name = model_desc["loss_all"].as<std::string>();
ASSERT_TRUE(model_desc["outputs"]);
ASSERT_EQ(1, model_desc["outputs"].size());
ASSERT_TRUE(model_desc["outputs"][0]["name"]);
ASSERT_TRUE(model_desc["outputs"][0]["shape"]);
ASSERT_TRUE(model_desc["outputs"][0]["label_name"]);
ASSERT_TRUE(model_desc["outputs"][0]["loss_name"]);
auto ctr_output_label_name = model_desc["outputs"][0]["label_name"].as<std::string>();
auto ctr_output_loss_name = model_desc["outputs"][0]["loss_name"].as<std::string>();
auto ctr_output_name = model_desc["outputs"][0]["name"].as<std::string>();
std::vector<int> output_shape = model_desc["outputs"][0]["shape"].as<std::vector<int>>(std::vector<int>());
ASSERT_EQ(2, output_shape.size());
ASSERT_EQ(-1, output_shape[0]);
ASSERT_EQ(1, output_shape[1]);
auto input_var = executor->mutable_var<::paddle::framework::LoDTensor>(input_name);
auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(label_name);
auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(ctr_output_label_name);
ASSERT_NE(nullptr, input_var);
ASSERT_NE(nullptr, label_var);
......@@ -109,14 +123,18 @@ TEST_F(CreateProgramsTest, example_network) {
ASSERT_EQ(0, executor->run());
auto loss_var = executor->var<::paddle::framework::LoDTensor>(loss_name);
auto loss_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_loss_name);
auto loss = loss_var.data<float>()[0];
auto loss_all_var = executor->var<::paddle::framework::LoDTensor>(loss_all_name);
auto loss_all = loss_all_var.data<float>()[0];
auto ctr_output_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_name);
auto ctr_output = ctr_output_var.data<float>()[0];
std::cout << "loss: " << loss << std::endl;
std::cout << "ctr_output: " << ctr_output << std::endl;
ASSERT_NEAR(loss, loss_all, 1e-9);
}
} // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册