提交 a1f1d11e 编写于 作者: R rensilin

exec_in_scope

Change-Id: I4f4e030e86ec28ca6977f311e7998e0350e1ea06
上级 8ab19129
...@@ -15,7 +15,9 @@ def print_help(this_name): ...@@ -15,7 +15,9 @@ def print_help(this_name):
print("Usage: {} <network building filename> [model_dir]\n".format(this_name)) print("Usage: {} <network building filename> [model_dir]\n".format(this_name))
print(" example: {} {}".format(this_name, os.path.join(dirname, 'example.py'))) print(" example: {} {}".format(this_name, os.path.join(dirname, 'example.py')))
def inference(filename):
def inference_warpper(filename):
"""Build inference network(without loss and optimizer) """Build inference network(without loss and optimizer)
Args: Args:
filename: path of file which defined real inference function filename: path of file which defined real inference function
...@@ -24,11 +26,14 @@ def inference(filename): ...@@ -24,11 +26,14 @@ def inference(filename):
and and
Variable: ctr_output Variable: ctr_output
""" """
with open(filename, 'r') as f: with open(filename, 'r') as f:
code = f.read() code = f.read()
compiled = compile(code, filename, 'exec') compiled = compile(code, filename, 'exec')
exec(compiled)
return inference() scope = dict()
exec(compiled, scope)
return scope['inference']()
def main(argv): def main(argv):
"""Create programs """Create programs
...@@ -40,7 +45,7 @@ def main(argv): ...@@ -40,7 +45,7 @@ def main(argv):
exit(1) exit(1)
network_build_file = argv[1] network_build_file = argv[1]
if len(argv) >= 2: if len(argv) > 2:
model_dir = argv[2] model_dir = argv[2]
else: else:
model_dir = './model' model_dir = './model'
...@@ -48,7 +53,7 @@ def main(argv): ...@@ -48,7 +53,7 @@ def main(argv):
main_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
inputs, ctr_output = inference(network_build_file) inputs, ctr_output = inference_warpper(network_build_file)
test_program = main_program.clone(for_test=True) test_program = main_program.clone(for_test=True)
......
...@@ -84,29 +84,29 @@ TEST_F(CreateProgramsTest, example_network) { ...@@ -84,29 +84,29 @@ TEST_F(CreateProgramsTest, example_network) {
ASSERT_EQ(-1, input_shape[0]); ASSERT_EQ(-1, input_shape[0]);
ASSERT_EQ(4488, input_shape[1]); ASSERT_EQ(4488, input_shape[1]);
auto input_var = executor->mutable_var<::paddle::framework::LoDTensor>(input_name); auto input_var = executor->mutable_var<::paddle::framework::LoDTensor>(input_name);
auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(label_name); auto label_var = executor->mutable_var<::paddle::framework::LoDTensor>(label_name);
ASSERT_NE(nullptr, input_var); ASSERT_NE(nullptr, input_var);
ASSERT_NE(nullptr, label_var); ASSERT_NE(nullptr, label_var);
input_var->Resize({1, input_shape[1]}); input_var->Resize({1, input_shape[1]});
auto input_data = input_var->mutable_data<float>(context_ptr->cpu_place); auto input_data = input_var->mutable_data<float>(context_ptr->cpu_place);
ASSERT_NE(nullptr, input_data); ASSERT_NE(nullptr, input_data);
for (int i = 0; i < input_shape[1]; ++i) { for (int i = 0; i < input_shape[1]; ++i) {
input_data[i] = 0.1; input_data[i] = 0.1;
} }
label_var->Resize({1, 1}); label_var->Resize({1, 1});
auto label_data = label_var->mutable_data<float>(context_ptr->cpu_place); auto label_data = label_var->mutable_data<float>(context_ptr->cpu_place);
ASSERT_NE(nullptr, label_data); ASSERT_NE(nullptr, label_data);
label_data[0] = 0.5; label_data[0] = 0.5;
ASSERT_EQ(0, executor->run()); ASSERT_EQ(0, executor->run());
auto loss_var = executor->var<::paddle::framework::LoDTensor>(loss_name); auto loss_var = executor->var<::paddle::framework::LoDTensor>(loss_name);
auto loss = loss_var.data<float>()[0]; auto loss = loss_var.data<float>()[0];
auto ctr_output_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_name); auto ctr_output_var = executor->var<::paddle::framework::LoDTensor>(ctr_output_name);
auto ctr_output = ctr_output_var.data<float>()[0]; auto ctr_output = ctr_output_var.data<float>()[0];
std::cout << "loss: " << loss << std::endl; std::cout << "loss: " << loss << std::endl;
......
...@@ -106,12 +106,12 @@ TEST_F(SimpleExecutorTest, run) { ...@@ -106,12 +106,12 @@ TEST_F(SimpleExecutorTest, run) {
auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
ASSERT_EQ(0, executor->initialize(config, context_ptr)); ASSERT_EQ(0, executor->initialize(config, context_ptr));
auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x"); auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x");
ASSERT_NE(nullptr, x_var); ASSERT_NE(nullptr, x_var);
int x_len = 10; int x_len = 10;
x_var->Resize({1, x_len}); x_var->Resize({1, x_len});
auto x_data = x_var->mutable_data<float>(context_ptr->cpu_place); auto x_data = x_var->mutable_data<float>(context_ptr->cpu_place);
ASSERT_NE(nullptr, x_data); ASSERT_NE(nullptr, x_data);
std::cout << "x: "; std::cout << "x: ";
for (int i = 0; i < x_len; ++i) { for (int i = 0; i < x_len; ++i) {
...@@ -122,7 +122,7 @@ TEST_F(SimpleExecutorTest, run) { ...@@ -122,7 +122,7 @@ TEST_F(SimpleExecutorTest, run) {
ASSERT_EQ(0, executor->run()); ASSERT_EQ(0, executor->run());
auto mean_var = executor->var<::paddle::framework::LoDTensor>("mean"); auto mean_var = executor->var<::paddle::framework::LoDTensor>("mean");
auto mean = mean_var.data<float>()[0]; auto mean = mean_var.data<float>()[0];
std::cout << "mean: " << mean << std::endl; std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9); ASSERT_NEAR(4.5, mean, 1e-9);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册