From 21a80312e14aad506712deceb579afcc11e07228 Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Sat, 15 Jun 2019 10:26:14 +0000 Subject: [PATCH] change pool to pool2d to fix the bug of pooling unit test, and refine code test=develop --- paddle/fluid/lite/api/cxx_api_bin.cc | 4 - paddle/fluid/lite/core/mir/passes.h | 2 + paddle/fluid/lite/core/optimizer.h | 21 +-- paddle/fluid/lite/operators/batch_norm_op.cc | 3 +- .../lite/operators/batch_norm_op_test.cc | 126 +++++++++--------- paddle/fluid/lite/operators/pool_op.h | 2 +- paddle/fluid/lite/operators/pool_op_test.cc | 2 +- 7 files changed, 79 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index b315030ed7e..96cad7cbe07 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -13,11 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/api/cxx_api.h" - -// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include "paddle/fluid/lite/core/mir/passes.h" -// #endif - #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 60e53257ba0..ac7a19bdfc0 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -21,6 +21,7 @@ namespace mir {} // namespace mir } // namespace lite } // namespace paddle +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK USE_MIR_PASS(demo); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); @@ -28,4 +29,5 @@ USE_MIR_PASS(type_target_transform_pass); USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); +#endif USE_MIR_PASS(runtime_context_assign_pass); diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index 63d0ffbc6b8..a57f3c0b7f9 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -46,18 +46,19 @@ class Optimizer { SpecifyKernelPickTactic(kernel_pick_factor); InitTargetTypeTransformPass(); - // #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK if (passes.empty()) { RunPasses(std::vector{{ - // "static_kernel_pick_pass", // - // "variable_place_inference_pass", // - // "argument_type_display_pass", // - // "type_target_transform_pass", // - // "argument_type_display_pass", // - // "variable_place_inference_pass", // - // "argument_type_display_pass", // - // "io_copy_kernel_pick_pass", // - // "variable_place_inference_pass", // +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + "static_kernel_pick_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "type_target_transform_pass", // + "argument_type_display_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_copy_kernel_pick_pass", // + "variable_place_inference_pass", // +#endif "runtime_context_assign_pass", // }}); } else { diff --git a/paddle/fluid/lite/operators/batch_norm_op.cc b/paddle/fluid/lite/operators/batch_norm_op.cc index 10b1f2d7c3d..b6ef87732de 100644 --- a/paddle/fluid/lite/operators/batch_norm_op.cc +++ b/paddle/fluid/lite/operators/batch_norm_op.cc @@ -82,8 +82,7 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.variance = scope->FindVar(op_desc.Input("Variance").front())->GetMutable(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable(); - param_.is_test = true; // TODO(hong19860320) param_.is_test = - // op_desc.GetAttr("is_test"); + param_.is_test = op_desc.GetAttr("is_test"); param_.use_global_stats = op_desc.GetAttr("use_global_stats"); if (!param_.is_test) { param_.mean_out = diff --git a/paddle/fluid/lite/operators/batch_norm_op_test.cc b/paddle/fluid/lite/operators/batch_norm_op_test.cc index 4072faccd72..9fb02759722 100644 --- a/paddle/fluid/lite/operators/batch_norm_op_test.cc +++ b/paddle/fluid/lite/operators/batch_norm_op_test.cc @@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { desc.SetInput("Mean", {"mean"}); desc.SetInput("Variance", {"variance"}); desc.SetOutput("Y", {"y"}); - desc.SetAttr("is_test", true); + desc.SetAttr("is_test", static_cast(1)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f); @@ -67,72 +67,72 @@ TEST(batch_norm_op_lite, test) { } } -// TEST(batch_norm_op_lite, test_enable_is_test) { -// // prepare variables -// Scope scope; -// auto* x = scope.Var("x")->GetMutable(); -// auto* scale = scope.Var("scale")->GetMutable(); -// auto* bias = scope.Var("bias")->GetMutable(); -// auto* mean = scope.Var("mean")->GetMutable(); -// auto* variance = scope.Var("variance")->GetMutable(); -// auto* y = scope.Var("y")->GetMutable(); -// auto* mean_out = scope.Var("mean_out")->GetMutable(); -// auto* variance_out = scope.Var("variance_out")->GetMutable(); -// auto* saved_mean = scope.Var("saved_mean")->GetMutable(); -// auto* saved_variance = scope.Var("saved_variance")->GetMutable(); -// x->Resize({2, 32, 10, 20}); -// auto x_dims = x->dims(); -// const int64_t channel_size = x_dims[1]; // NCHW -// scale->Resize({channel_size}); -// bias->Resize({channel_size}); -// mean->Resize({channel_size}); -// variance->Resize({channel_size}); +TEST(batch_norm_op_lite, test_enable_is_test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* scale = scope.Var("scale")->GetMutable(); + auto* bias = scope.Var("bias")->GetMutable(); + auto* mean = scope.Var("mean")->GetMutable(); + auto* variance = scope.Var("variance")->GetMutable(); + auto* y = scope.Var("y")->GetMutable(); + auto* mean_out = scope.Var("mean_out")->GetMutable(); + auto* variance_out = scope.Var("variance_out")->GetMutable(); + auto* saved_mean = scope.Var("saved_mean")->GetMutable(); + auto* saved_variance = scope.Var("saved_variance")->GetMutable(); + x->Resize({2, 32, 10, 20}); + auto x_dims = x->dims(); + const int64_t channel_size = x_dims[1]; // NCHW + scale->Resize({channel_size}); + bias->Resize({channel_size}); + mean->Resize({channel_size}); + variance->Resize({channel_size}); -// // prepare op desc -// cpp::OpDesc desc; -// desc.SetType("batch_norm"); -// desc.SetInput("X", {"x"}); -// desc.SetInput("Scale", {"scale"}); -// desc.SetInput("Bias", {"bias"}); -// desc.SetInput("Mean", {"mean"}); -// desc.SetInput("Variance", {"variance"}); -// desc.SetOutput("Y", {"y"}); -// desc.SetOutput("MeanOut", {"mean_out"}); -// desc.SetOutput("VarianceOut", {"variance_out"}); -// desc.SetOutput("SavedMean", {"saved_mean"}); -// desc.SetOutput("SavedVariance", {"saved_variance"}); -// desc.SetAttr("is_test", false); -// desc.SetAttr("use_global_stats", false); -// desc.SetAttr("epsilon", 1e-5f); -// desc.SetAttr("momentum", 0.9f); -// desc.SetAttr("data_layout", std::string("NCHW")); + // prepare op desc + cpp::OpDesc desc; + desc.SetType("batch_norm"); + desc.SetInput("X", {"x"}); + desc.SetInput("Scale", {"scale"}); + desc.SetInput("Bias", {"bias"}); + desc.SetInput("Mean", {"mean"}); + desc.SetInput("Variance", {"variance"}); + desc.SetOutput("Y", {"y"}); + desc.SetOutput("MeanOut", {"mean_out"}); + desc.SetOutput("VarianceOut", {"variance_out"}); + desc.SetOutput("SavedMean", {"saved_mean"}); + desc.SetOutput("SavedVariance", {"saved_variance"}); + desc.SetAttr("is_test", static_cast(0)); + desc.SetAttr("use_global_stats", false); + desc.SetAttr("epsilon", 1e-5f); + desc.SetAttr("momentum", 0.9f); + desc.SetAttr("data_layout", std::string("NCHW")); -// BatchNormOp batch_norm("batch_norm"); + BatchNormOp batch_norm("batch_norm"); -// batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); -// batch_norm.Attach(desc, &scope); -// batch_norm.CheckShape(); -// batch_norm.InferShape(); + batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + batch_norm.Attach(desc, &scope); + batch_norm.CheckShape(); + batch_norm.InferShape(); -// // check output dims -// auto y_dims = y->dims(); -// CHECK_EQ(y_dims.size(), x_dims.size()); -// for (size_t i = 0; i < y_dims.size(); i++) { -// CHECK_EQ(y_dims[i], x_dims[i]); -// } -// auto mean_out_dims = mean_out->dims(); -// auto variance_out_dims = variance_out->dims(); -// auto saved_mean_dims = saved_mean->dims(); -// auto saved_variance_dims = saved_variance->dims(); -// CHECK_EQ(mean_out_dims.size(), 1UL); -// CHECK_EQ(variance_out_dims.size(), 1UL); -// CHECK_EQ(saved_mean_dims.size(), 1UL); -// CHECK_EQ(saved_variance_dims.size(), 1UL); -// CHECK_EQ(mean_out_dims[0], channel_size); -// CHECK_EQ(variance_out_dims[0], channel_size); -// CHECK_EQ(saved_mean_dims[0], channel_size); -// CHECK_EQ(saved_variance_dims[0], channel_size); -// } + // check output dims + auto y_dims = y->dims(); + CHECK_EQ(y_dims.size(), x_dims.size()); + for (size_t i = 0; i < y_dims.size(); i++) { + CHECK_EQ(y_dims[i], x_dims[i]); + } + auto mean_out_dims = mean_out->dims(); + auto variance_out_dims = variance_out->dims(); + auto saved_mean_dims = saved_mean->dims(); + auto saved_variance_dims = saved_variance->dims(); + CHECK_EQ(mean_out_dims.size(), 1UL); + CHECK_EQ(variance_out_dims.size(), 1UL); + CHECK_EQ(saved_mean_dims.size(), 1UL); + CHECK_EQ(saved_variance_dims.size(), 1UL); + CHECK_EQ(mean_out_dims[0], channel_size); + CHECK_EQ(variance_out_dims[0], channel_size); + CHECK_EQ(saved_mean_dims[0], channel_size); + CHECK_EQ(saved_variance_dims[0], channel_size); +} } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/operators/pool_op.h b/paddle/fluid/lite/operators/pool_op.h index bb9963a70e3..29946ed92a4 100644 --- a/paddle/fluid/lite/operators/pool_op.h +++ b/paddle/fluid/lite/operators/pool_op.h @@ -71,7 +71,7 @@ class PoolOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - std::string DebugString() const override { return "pool"; } + std::string DebugString() const override { return "pool2d"; } private: mutable PoolParam param_; diff --git a/paddle/fluid/lite/operators/pool_op_test.cc b/paddle/fluid/lite/operators/pool_op_test.cc index c3df9b47847..e9616ede5a4 100644 --- a/paddle/fluid/lite/operators/pool_op_test.cc +++ b/paddle/fluid/lite/operators/pool_op_test.cc @@ -38,7 +38,7 @@ TEST(pool_op_lite, test) { // prepare op desc cpp::OpDesc desc; - desc.SetType("pool"); + desc.SetType("pool2d"); desc.SetInput("X", {"x"}); desc.SetOutput("Out", {"output"}); -- GitLab