提交 4f3eba94 编写于 作者: R rensilin

add_executor_ut

Change-Id: Ibc3221bdb90032663e7078e63a839dbde4c119d1
上级 8699c196
此差异已折叠。
...@@ -13,26 +13,90 @@ See the License for the specific language governing permissions and ...@@ -13,26 +13,90 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream> #include <iostream>
#include <fstream>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
<<<<<<< HEAD
TEST(testSimpleExecutor, initialize) { TEST(testSimpleExecutor, initialize) {
SimpleExecutor execute; SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>(); auto context_ptr = std::make_shared<TrainerContext>();
=======
const char test_data_dir[] = "test_data";
const char main_program_path[] = "test_data/main_program";
const char startup_program_path[] = "test_data/startup_program";
class SimpleExecuteTest : public testing::Test
{
public:
static void SetUpTestCase()
{
::paddle::framework::localfs_mkdir(test_data_dir);
{
std::unique_ptr<paddle::framework::ProgramDesc> startup_program(
new paddle::framework::ProgramDesc());
std::ofstream fout(startup_program_path, std::ios::out | std::ios::binary);
ASSERT_TRUE(fout);
fout << startup_program->Proto()->SerializeAsString();
fout.close();
}
{
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
new paddle::framework::ProgramDesc());
auto load_block = main_program->MutableBlock(0);
framework::OpDesc* op = load_block->AppendOp();
op->SetType("mean");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"mean"});
op->CheckAttrs();
std::ofstream fout(main_program_path, std::ios::out | std::ios::binary);
ASSERT_TRUE(fout);
fout << main_program->Proto()->SerializeAsString();
fout.close();
}
}
static void TearDownTestCase()
{
::paddle::framework::localfs_remove(test_data_dir);
}
virtual void SetUp()
{
context_ptr.reset(new TrainerContext());
}
virtual void TearDown()
{
context_ptr = nullptr;
}
std::shared_ptr<TrainerContext> context_ptr;
};
TEST_F(SimpleExecuteTest, initialize) {
SimpleExecute execute;
>>>>>>> add_executor_ut
YAML::Node config = YAML::Load("[1, 2, 3]"); YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, execute.initialize(config, context_ptr)); ASSERT_NE(0, execute.initialize(config, context_ptr));
config = YAML::Load("{startup_program: ./data/startup_program, main_program: ./data/main_program}"); config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, execute.initialize(config, context_ptr));
config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"); config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, execute.initialize(config, context_ptr));
} }
<<<<<<< HEAD
float uniform(float min, float max) { float uniform(float min, float max) {
float result = (float)rand() / RAND_MAX; float result = (float)rand() / RAND_MAX;
return min + result * (max - min); return min + result * (max - min);
...@@ -58,21 +122,33 @@ TEST(testSimpleExecutor, run) { ...@@ -58,21 +122,33 @@ TEST(testSimpleExecutor, run) {
SimpleExecutor execute; SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>(); auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"); auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
=======
TEST_F(SimpleExecuteTest, run) {
SimpleExecute execute;
auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
>>>>>>> add_executor_ut
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, execute.initialize(config, context_ptr));
auto x_var = execute.mutable_var<::paddle::framework::LoDTensor>("x"); auto x_var = execute.mutable_var<::paddle::framework::LoDTensor>("x");
auto y_var = execute.mutable_var<::paddle::framework::LoDTensor>("y"); execute.mutable_var<::paddle::framework::LoDTensor>("mean");
ASSERT_NE(nullptr, x_var); ASSERT_NE(nullptr, x_var);
ASSERT_NE(nullptr, y_var);
next_batch(1024, context_ptr->cpu_place, x_var, y_var); int x_len = 10;
x_var->Resize({1, x_len});
auto x_data = x_var->mutable_data<float>(context_ptr->cpu_place);
std::cout << "x: ";
for (int i = 0; i < x_len; ++i) {
x_data[i] = i;
std::cout << i << " ";
}
std::cout << std::endl;
ASSERT_EQ(0, execute.run()); ASSERT_EQ(0, execute.run());
auto loss_var = execute.var<::paddle::framework::LoDTensor>("loss"); auto mean_var = execute.var<::paddle::framework::LoDTensor>("mean");
auto loss = loss_var.data<float>()[0]; auto mean = mean_var.data<float>()[0];
std::cout << "loss: " << loss << std::endl; std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9);
} }
} // namespace feed } // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册