提交 21a80312 编写于 作者: H hong19860320

change pool to pool2d to fix the bug of pooling unit test, and refine code

test=develop
上级 b69262cf
...@@ -13,11 +13,7 @@ ...@@ -13,11 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/mir/passes.h"
// #endif
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
namespace paddle { namespace paddle {
......
...@@ -21,6 +21,7 @@ namespace mir {} // namespace mir ...@@ -21,6 +21,7 @@ namespace mir {} // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
USE_MIR_PASS(demo); USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
...@@ -28,4 +29,5 @@ USE_MIR_PASS(type_target_transform_pass); ...@@ -28,4 +29,5 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
#endif
USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(runtime_context_assign_pass);
...@@ -46,18 +46,19 @@ class Optimizer { ...@@ -46,18 +46,19 @@ class Optimizer {
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass(); InitTargetTypeTransformPass();
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{{
// "static_kernel_pick_pass", // #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// "variable_place_inference_pass", // "static_kernel_pick_pass", //
// "argument_type_display_pass", // "variable_place_inference_pass", //
// "type_target_transform_pass", // "argument_type_display_pass", //
// "argument_type_display_pass", // "type_target_transform_pass", //
// "variable_place_inference_pass", // "argument_type_display_pass", //
// "argument_type_display_pass", // "variable_place_inference_pass", //
// "io_copy_kernel_pick_pass", // "argument_type_display_pass", //
// "variable_place_inference_pass", // "io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
#endif
"runtime_context_assign_pass", // "runtime_context_assign_pass", //
}}); }});
} else { } else {
......
...@@ -82,8 +82,7 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -82,8 +82,7 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.variance = param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>(); scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.is_test = true; // TODO(hong19860320) param_.is_test = param_.is_test = op_desc.GetAttr<int>("is_test");
// op_desc.GetAttr<int>("is_test");
param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats"); param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
if (!param_.is_test) { if (!param_.is_test) {
param_.mean_out = param_.mean_out =
......
...@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { ...@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) {
desc.SetInput("Mean", {"mean"}); desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"}); desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"}); desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", true); desc.SetAttr("is_test", static_cast<int>(1));
desc.SetAttr("use_global_stats", false); desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f); desc.SetAttr("momentum", 0.9f);
...@@ -67,72 +67,72 @@ TEST(batch_norm_op_lite, test) { ...@@ -67,72 +67,72 @@ TEST(batch_norm_op_lite, test) {
} }
} }
// TEST(batch_norm_op_lite, test_enable_is_test) { TEST(batch_norm_op_lite, test_enable_is_test) {
// // prepare variables // prepare variables
// Scope scope; Scope scope;
// auto* x = scope.Var("x")->GetMutable<Tensor>(); auto* x = scope.Var("x")->GetMutable<Tensor>();
// auto* scale = scope.Var("scale")->GetMutable<Tensor>(); auto* scale = scope.Var("scale")->GetMutable<Tensor>();
// auto* bias = scope.Var("bias")->GetMutable<Tensor>(); auto* bias = scope.Var("bias")->GetMutable<Tensor>();
// auto* mean = scope.Var("mean")->GetMutable<Tensor>(); auto* mean = scope.Var("mean")->GetMutable<Tensor>();
// auto* variance = scope.Var("variance")->GetMutable<Tensor>(); auto* variance = scope.Var("variance")->GetMutable<Tensor>();
// auto* y = scope.Var("y")->GetMutable<Tensor>(); auto* y = scope.Var("y")->GetMutable<Tensor>();
// auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>(); auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
// auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>(); auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
// auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>(); auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
// auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>(); auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
// x->Resize({2, 32, 10, 20}); x->Resize({2, 32, 10, 20});
// auto x_dims = x->dims(); auto x_dims = x->dims();
// const int64_t channel_size = x_dims[1]; // NCHW const int64_t channel_size = x_dims[1]; // NCHW
// scale->Resize({channel_size}); scale->Resize({channel_size});
// bias->Resize({channel_size}); bias->Resize({channel_size});
// mean->Resize({channel_size}); mean->Resize({channel_size});
// variance->Resize({channel_size}); variance->Resize({channel_size});
// // prepare op desc // prepare op desc
// cpp::OpDesc desc; cpp::OpDesc desc;
// desc.SetType("batch_norm"); desc.SetType("batch_norm");
// desc.SetInput("X", {"x"}); desc.SetInput("X", {"x"});
// desc.SetInput("Scale", {"scale"}); desc.SetInput("Scale", {"scale"});
// desc.SetInput("Bias", {"bias"}); desc.SetInput("Bias", {"bias"});
// desc.SetInput("Mean", {"mean"}); desc.SetInput("Mean", {"mean"});
// desc.SetInput("Variance", {"variance"}); desc.SetInput("Variance", {"variance"});
// desc.SetOutput("Y", {"y"}); desc.SetOutput("Y", {"y"});
// desc.SetOutput("MeanOut", {"mean_out"}); desc.SetOutput("MeanOut", {"mean_out"});
// desc.SetOutput("VarianceOut", {"variance_out"}); desc.SetOutput("VarianceOut", {"variance_out"});
// desc.SetOutput("SavedMean", {"saved_mean"}); desc.SetOutput("SavedMean", {"saved_mean"});
// desc.SetOutput("SavedVariance", {"saved_variance"}); desc.SetOutput("SavedVariance", {"saved_variance"});
// desc.SetAttr("is_test", false); desc.SetAttr("is_test", static_cast<int>(0));
// desc.SetAttr("use_global_stats", false); desc.SetAttr("use_global_stats", false);
// desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("epsilon", 1e-5f);
// desc.SetAttr("momentum", 0.9f); desc.SetAttr("momentum", 0.9f);
// desc.SetAttr("data_layout", std::string("NCHW")); desc.SetAttr("data_layout", std::string("NCHW"));
// BatchNormOp batch_norm("batch_norm"); BatchNormOp batch_norm("batch_norm");
// batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
// batch_norm.Attach(desc, &scope); batch_norm.Attach(desc, &scope);
// batch_norm.CheckShape(); batch_norm.CheckShape();
// batch_norm.InferShape(); batch_norm.InferShape();
// // check output dims // check output dims
// auto y_dims = y->dims(); auto y_dims = y->dims();
// CHECK_EQ(y_dims.size(), x_dims.size()); CHECK_EQ(y_dims.size(), x_dims.size());
// for (size_t i = 0; i < y_dims.size(); i++) { for (size_t i = 0; i < y_dims.size(); i++) {
// CHECK_EQ(y_dims[i], x_dims[i]); CHECK_EQ(y_dims[i], x_dims[i]);
// } }
// auto mean_out_dims = mean_out->dims(); auto mean_out_dims = mean_out->dims();
// auto variance_out_dims = variance_out->dims(); auto variance_out_dims = variance_out->dims();
// auto saved_mean_dims = saved_mean->dims(); auto saved_mean_dims = saved_mean->dims();
// auto saved_variance_dims = saved_variance->dims(); auto saved_variance_dims = saved_variance->dims();
// CHECK_EQ(mean_out_dims.size(), 1UL); CHECK_EQ(mean_out_dims.size(), 1UL);
// CHECK_EQ(variance_out_dims.size(), 1UL); CHECK_EQ(variance_out_dims.size(), 1UL);
// CHECK_EQ(saved_mean_dims.size(), 1UL); CHECK_EQ(saved_mean_dims.size(), 1UL);
// CHECK_EQ(saved_variance_dims.size(), 1UL); CHECK_EQ(saved_variance_dims.size(), 1UL);
// CHECK_EQ(mean_out_dims[0], channel_size); CHECK_EQ(mean_out_dims[0], channel_size);
// CHECK_EQ(variance_out_dims[0], channel_size); CHECK_EQ(variance_out_dims[0], channel_size);
// CHECK_EQ(saved_mean_dims[0], channel_size); CHECK_EQ(saved_mean_dims[0], channel_size);
// CHECK_EQ(saved_variance_dims[0], channel_size); CHECK_EQ(saved_variance_dims[0], channel_size);
// } }
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
......
...@@ -71,7 +71,7 @@ class PoolOpLite : public OpLite { ...@@ -71,7 +71,7 @@ class PoolOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pool"; } std::string DebugString() const override { return "pool2d"; }
private: private:
mutable PoolParam param_; mutable PoolParam param_;
......
...@@ -38,7 +38,7 @@ TEST(pool_op_lite, test) { ...@@ -38,7 +38,7 @@ TEST(pool_op_lite, test) {
// prepare op desc // prepare op desc
cpp::OpDesc desc; cpp::OpDesc desc;
desc.SetType("pool"); desc.SetType("pool2d");
desc.SetInput("X", {"x"}); desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"}); desc.SetOutput("Out", {"output"});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册