提交 33437fe4 编写于 作者: H hong19860320

enable mobilenetv1, fix the bugs of conv, pool, relu and split

test=develop
上级 5f833603
......@@ -14,9 +14,9 @@
#include "paddle/fluid/lite/api/cxx_api.h"
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/mir/passes.h"
#endif
// #endif
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -24,6 +24,9 @@ namespace paddle {
namespace lite {
void Run(const char* model_dir) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
#endif
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
......@@ -32,9 +35,9 @@ void Run(const char* model_dir) {
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({3, 224, 224})));
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 3 * 224 * 224; i++) {
for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = i;
}
......@@ -65,7 +68,7 @@ USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(con2d);
USE_LITE_OP(conv2d);
// USE_LITE_OP(batch_norm);
USE_LITE_OP(relu);
USE_LITE_OP(depthwise_conv2d);
......@@ -81,10 +84,10 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(con2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_con2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
......
......@@ -46,24 +46,24 @@ class Optimizer {
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
// "static_kernel_pick_pass", //
// "variable_place_inference_pass", //
// "argument_type_display_pass", //
// "type_target_transform_pass", //
// "argument_type_display_pass", //
// "variable_place_inference_pass", //
// "argument_type_display_pass", //
// "io_copy_kernel_pick_pass", //
// "variable_place_inference_pass", //
"runtime_context_assign_pass", //
}});
} else {
RunPasses(passes);
}
#endif
// #endif
exec_scope_ = program.exec_scope();
}
......
......@@ -102,7 +102,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
......@@ -110,5 +110,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -26,7 +26,7 @@ void ElementwiseAddCompute::Run() {
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int n = param.X->dims().production();
lite::arm::math::elementwise_add(x_data, y_data, out_data, n);
// lite::arm::math::elementwise_add(x_data, y_data, out_data, n);
}
} // namespace arm
......
......@@ -163,7 +163,7 @@ PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(pool, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::PoolCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -45,4 +45,6 @@ class ReluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ReluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -40,11 +40,11 @@ class ConvOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("Input").front();
auto filter = op_desc.Input("Filter").front();
auto out = op_desc.Output("Out").front();
auto output = op_desc.Output("Output").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(output));
param_.output = scope->FindVar(output)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
......@@ -53,21 +53,26 @@ class ConvOpLite : public OpLite {
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() > 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
}
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
auto res_data_arguments = op_desc.Input("ResidualData");
if (res_data_arguments.size() > 0) {
auto residual_data_var = scope->FindVar(res_data_arguments.front());
if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor *>(
&(residual_data_var->Get<lite::Tensor>()));
}
}
}
return true;
}
......
......@@ -53,10 +53,18 @@ class PoolOpLite : public OpLite {
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("exclusive")) {
param_.exclusive = op_desc.GetAttr<bool>("exclusive");
}
if (op_desc.HasAttr("adaptive")) {
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
}
if (op_desc.HasAttr("ceil_mode")) {
param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
}
if (op_desc.HasAttr("use_quantizer")) {
param_.use_quantizer = op_desc.GetAttr<bool>("use_quantizer");
}
// param_.data_format = op_desc.GetAttr<bool>("data_format");
return true;
}
......
......@@ -32,12 +32,11 @@ bool ReluOp::InferShape() const {
bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input);
CHECK(param_.output);
kernel_->SetParam(param_);
return true;
}
......
......@@ -37,7 +37,7 @@ bool SplitOp::InferShape() const {
const auto &sections = param_.sections;
const int outs_number = outs.size();
std::vector<lite::DDimHvy> outs_dims;
std::vector<lite::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册