提交 21a80312 编写于 作者: H hong19860320

Change the op type "pool" to "pool2d" to fix the pooling unit test failure, and refine the related code.

test=develop
上级 b69262cf
......@@ -13,11 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/mir/passes.h"
// #endif
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
......
......@@ -21,6 +21,7 @@ namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
......@@ -28,4 +29,5 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
#endif
USE_MIR_PASS(runtime_context_assign_pass);
......@@ -46,18 +46,19 @@ class Optimizer {
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
// "static_kernel_pick_pass", //
// "variable_place_inference_pass", //
// "argument_type_display_pass", //
// "type_target_transform_pass", //
// "argument_type_display_pass", //
// "variable_place_inference_pass", //
// "argument_type_display_pass", //
// "io_copy_kernel_pick_pass", //
// "variable_place_inference_pass", //
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
#endif
"runtime_context_assign_pass", //
}});
} else {
......
......@@ -82,8 +82,7 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.is_test = true; // TODO(hong19860320) param_.is_test =
// op_desc.GetAttr<int>("is_test");
param_.is_test = op_desc.GetAttr<int>("is_test");
param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
if (!param_.is_test) {
param_.mean_out =
......
......@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) {
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", true);
desc.SetAttr("is_test", static_cast<int>(1));
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
......@@ -67,72 +67,72 @@ TEST(batch_norm_op_lite, test) {
}
}
// TEST(batch_norm_op_lite, test_enable_is_test) {
// // prepare variables
// Scope scope;
// auto* x = scope.Var("x")->GetMutable<Tensor>();
// auto* scale = scope.Var("scale")->GetMutable<Tensor>();
// auto* bias = scope.Var("bias")->GetMutable<Tensor>();
// auto* mean = scope.Var("mean")->GetMutable<Tensor>();
// auto* variance = scope.Var("variance")->GetMutable<Tensor>();
// auto* y = scope.Var("y")->GetMutable<Tensor>();
// auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
// auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
// auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
// auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
// x->Resize({2, 32, 10, 20});
// auto x_dims = x->dims();
// const int64_t channel_size = x_dims[1]; // NCHW
// scale->Resize({channel_size});
// bias->Resize({channel_size});
// mean->Resize({channel_size});
// variance->Resize({channel_size});
// Verifies batch_norm shape inference when is_test is disabled (training mode):
// besides Y, the op must also produce MeanOut/VarianceOut/SavedMean/SavedVariance,
// each shaped [channel] for NCHW input.
TEST(batch_norm_op_lite, test_enable_is_test) {
  // Set up the scope with all input and output tensors the op touches.
  Scope scope;
  auto* x = scope.Var("x")->GetMutable<Tensor>();
  auto* scale = scope.Var("scale")->GetMutable<Tensor>();
  auto* bias = scope.Var("bias")->GetMutable<Tensor>();
  auto* mean = scope.Var("mean")->GetMutable<Tensor>();
  auto* variance = scope.Var("variance")->GetMutable<Tensor>();
  auto* y = scope.Var("y")->GetMutable<Tensor>();
  auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
  auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
  auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
  auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();

  // Input is NCHW; the per-channel params are sized from dim 1.
  x->Resize({2, 32, 10, 20});
  auto x_dims = x->dims();
  const int64_t channel_size = x_dims[1];  // NCHW
  scale->Resize({channel_size});
  bias->Resize({channel_size});
  mean->Resize({channel_size});
  variance->Resize({channel_size});

  // Build the op descriptor with is_test = 0 so training-only outputs
  // (MeanOut/VarianceOut/SavedMean/SavedVariance) are inferred.
  // NOTE: is_test is intentionally set as an int attribute here, matching
  // how AttachImpl reads it via GetAttr<int>("is_test").
  cpp::OpDesc desc;
  desc.SetType("batch_norm");
  desc.SetInput("X", {"x"});
  desc.SetInput("Scale", {"scale"});
  desc.SetInput("Bias", {"bias"});
  desc.SetInput("Mean", {"mean"});
  desc.SetInput("Variance", {"variance"});
  desc.SetOutput("Y", {"y"});
  desc.SetOutput("MeanOut", {"mean_out"});
  desc.SetOutput("VarianceOut", {"variance_out"});
  desc.SetOutput("SavedMean", {"saved_mean"});
  desc.SetOutput("SavedVariance", {"saved_variance"});
  desc.SetAttr("is_test", static_cast<int>(0));
  desc.SetAttr("use_global_stats", false);
  desc.SetAttr("epsilon", 1e-5f);
  desc.SetAttr("momentum", 0.9f);
  desc.SetAttr("data_layout", std::string("NCHW"));

  // Attach the descriptor and run shape checking/inference.
  BatchNormOp batch_norm("batch_norm");
  batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
  batch_norm.Attach(desc, &scope);
  batch_norm.CheckShape();
  batch_norm.InferShape();

  // Y must match X exactly, dimension by dimension.
  auto y_dims = y->dims();
  CHECK_EQ(y_dims.size(), x_dims.size());
  for (size_t i = 0; i < y_dims.size(); i++) {
    CHECK_EQ(y_dims[i], x_dims[i]);
  }

  // Every training-mode output must be a 1-D tensor of length channel_size.
  auto mean_out_dims = mean_out->dims();
  auto variance_out_dims = variance_out->dims();
  auto saved_mean_dims = saved_mean->dims();
  auto saved_variance_dims = saved_variance->dims();
  CHECK_EQ(mean_out_dims.size(), 1UL);
  CHECK_EQ(variance_out_dims.size(), 1UL);
  CHECK_EQ(saved_mean_dims.size(), 1UL);
  CHECK_EQ(saved_variance_dims.size(), 1UL);
  CHECK_EQ(mean_out_dims[0], channel_size);
  CHECK_EQ(variance_out_dims[0], channel_size);
  CHECK_EQ(saved_mean_dims[0], channel_size);
  CHECK_EQ(saved_variance_dims[0], channel_size);
}
} // namespace operators
} // namespace lite
......
......@@ -71,7 +71,7 @@ class PoolOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pool"; }
std::string DebugString() const override { return "pool2d"; }
private:
mutable PoolParam param_;
......
......@@ -38,7 +38,7 @@ TEST(pool_op_lite, test) {
// prepare op desc
cpp::OpDesc desc;
desc.SetType("pool");
desc.SetType("pool2d");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册