diff --git a/src/common/types.h b/src/common/types.h index ae76c953aa573a0bd59df6dae74219f1b9ad5873..227151adbbd5054c5beca30245a97c415f3d7984 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -14,6 +14,10 @@ limitations under the License. */ #pragma once; +#include +#include +#include + namespace paddle_mobile { enum class Precision : int { FP32 = 0 }; @@ -67,4 +71,41 @@ enum PMStatus { PMUnImplError = 0x07, /*!< Unimplement error. */ PMWrongDevice = 0x08 /*!< un-correct device. */ }; + +static const std::string G_OP_TYPE_CONV = "conv2d"; +static const std::string G_OP_TYPE_BATCHNORM = "batch_norm"; +static const std::string G_OP_TYPE_BOX_CODER = "box_coder"; +static const std::string G_OP_TYPE_CONCAT = "concat"; +static const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add"; +static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "FusionConvAddRelu"; +static const std::string G_OP_TYPE_FC = "fc"; +static const std::string G_OP_TYPE_LRN = "lrn"; +static const std::string G_OP_TYPE_MUL = "mul"; +static const std::string G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms"; +static const std::string G_OP_TYPE_POOL2D = "pool2d"; +static const std::string G_OP_TYPE_PRIOR_BOX = "prior_box"; +static const std::string G_OP_TYPE_RELU = "relu"; +static const std::string G_OP_TYPE_RESHAPE = "reshape"; +static const std::string G_OP_TYPE_SIGMOID = "sigmoid"; +static const std::string G_OP_TYPE_SOFTMAX = "softmax"; +static const std::string G_OP_TYPE_TRANSPOSE = "transpose"; +static const std::string G_OP_TYPE_SPLIT = "split"; +static const std::string G_OP_TYPE_FEED = "feed"; +static const std::string G_OP_TYPE_FETCH = "fetch"; + +static std::unordered_map< + std::string, std::pair, std::vector>> + op_input_output_key = {{G_OP_TYPE_CONV, {{"Input"}, {"Output"}}}, + {G_OP_TYPE_RELU, {{"X"}, {"Out"}}}, + {G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}}, + {G_OP_TYPE_MUL, {{"X"}, {"Out"}}}, + {G_OP_TYPE_ELEMENTWISE_ADD, {{"X", "Y"}, {"Out"}}}, + {G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}}, + {G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}}, + {G_OP_TYPE_LRN, {{"X"}, {"Out"}}}, + {G_OP_TYPE_CONCAT, {{"X"}, {"Out"}}}, + {G_OP_TYPE_SPLIT, {{"X"}, {"Out"}}}, + {G_OP_TYPE_FEED, {{"X"}, {"Out"}}}, + {G_OP_TYPE_FETCH, {{"X"}, {"Out"}}}, + {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}}; } // namespace paddle_mobile diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index dfdf0af79ac98d0bb79c7da3fdcc872341417b87..808002d4c8f3193744ef68c1db881a787d19b133 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -23,6 +23,7 @@ vector OperatorBase::GetOutKeys() const { auto it = op_input_output_key.find(type_); if (it == op_input_output_key.end()) { DLOG << type_ << " has no outputs"; + return {}; } return it->second.second; } diff --git a/src/framework/operator.h b/src/framework/operator.h index 0d61761775394b3e23825db2883c2d4a2c071f33..6194e5dcfff2c0320318d3e30e5c8204cd71a749 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -38,42 +38,46 @@ namespace paddle_mobile { namespace framework { using std::string; using std::vector; -static std::unordered_map< - std::string, std::pair, std::vector>> - op_input_output_key = {{"conv2d", {{"Input"}, {"Output"}}}, - {"relu", {{"X"}, {"Out"}}}, - {"softmax", {{"X"}, {"Out"}}}, - {"mul", {{"X"}, {"Out"}}}, - {"elementwise_add", {{"X", "Y"}, {"Out"}}}, - {"pool2d", {{"X"}, {"Out"}}}, - {"batch_norm", {{"X"}, {"Y"}}}, - {"lrn", {{"X"}, {"Out"}}}, - {"concat", {{"X"}, {"Out"}}}, - {"feed", {{"X"}, {"Out"}}}, - {"fetch", {{"X"}, {"Out"}}}, - {"reshape", {{"X"}, {"Out"}}}}; template class OperatorBase : PaddleMobileObject { public: + /* + * @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor + * */ OperatorBase(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, std::shared_ptr scope); virtual ~OperatorBase() {} void Run() const; - vector GetOutKeys() const; + std::vector GetOutKeys() const; virtual void RunImpl() const = 0; - virtual void InferShape() const = 0; + /* + * @b op 运算所需的输入, 如上一层的输出结果、卷积核 + * */ const VariableNameMap &Inputs() const { return inputs_; } + /* + * @b op 的输出, 内存会提前被分配好, 运算结果会被存到分配好的内存内 + * */ const VariableNameMap &Outputs() const { return outputs_; } + /* + * @b op 类型 + * */ const std::string &Type() const { return type_; } + /* + * @b op 运算所需要用到的参数: 如 conv 运算所需要用到的 stride + * */ const AttributeMap &Attrs() const { return attrs_; } void ClearVariables(const std::vector &var_names) const { if (this->scope_) { this->scope_->EraseVars(var_names); } } + /* + * @b 根据输入形状和参数计算出输出形状 + * */ + virtual void InferShape() const = 0; protected: std::shared_ptr scope_; @@ -86,6 +90,9 @@ class OperatorBase : PaddleMobileObject { void CheckAllInputOutputSet() const; }; +/* + * @b 这个类为所有带有运算的 op 的父类, 这个 op 继承与 OperatorBase + * */ template class OperatorWithKernel : public OperatorBase { public: @@ -98,11 +105,18 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape() const = 0; }; +/* + * @b 所有kernel的父类 + * */ template class OpKernelBase : PaddleMobileObject { public: + /* + * @b 所有kernel 需实现 Compute 方法 + * @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体, + * 所有结构体存在与: paddle-mobile/src/operators/op_param.h + * */ virtual void Compute(const P ¶) const = 0; - virtual ~OpKernelBase() = default; }; @@ -119,8 +133,8 @@ class FusionOpMatcher : PaddleMobileObject { virtual std::string Type() = 0; - virtual void FolderNodes(Node &node) { - node.Folder(node_.Depth(), Type(), {}); + virtual void FolderNodes(Node *node) { + node->Folder(node_.Depth(), Type(), {}); } virtual Node &BeginNode() { return node_; } diff --git a/src/framework/program/program-optimize/node.cpp b/src/framework/program/program-optimize/node.cpp index ac7137a47f3c339946eb0558b224ddab1caf8007..820fa6a443c62c4cfdb38f4d42e6d7805371c2d3 100644 --- a/src/framework/program/program-optimize/node.cpp +++ b/src/framework/program/program-optimize/node.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include +#include "framework/operator.h" #include "framework/program/program-optimize/node.h" namespace paddle_mobile { @@ -73,24 +74,86 @@ void Node::OpDescs(uint index, } void Node::OpDescs(std::vector> *op_desc, - Node *node) { - auto iter = std::find(op_desc->begin(), op_desc->end(), this->op_desc_); + Node *node, bool adding_thread, int thread_num) { + bool can_add_split = false; + if (outputs_.size() > 1) { + can_add_split = true; + if (op_input_output_key[op_desc_->type_].second.size() != 1) { + DLOG << "当前 op desc 输出数不为 1 "; + can_add_split = false; + } + for (const auto &output : outputs_) { + if (op_input_output_key.find(output->op_desc_->type_) != + op_input_output_key.end()) { + auto inputs_and_outputs = op_input_output_key[output->op_desc_->type_]; + auto outputs_of_output = + output->op_desc_->Output(inputs_and_outputs.second[0]); + auto inputs_of_output = + output->op_desc_->Input(inputs_and_outputs.first[0]); + for (int i = 0; i < inputs_of_output.size(); ++i) { + std::string input_of_output = inputs_of_output[i]; + for (int j = 0; j < outputs_of_output.size(); ++j) { + std::string output_of_output = outputs_of_output[j]; + if (input_of_output == output_of_output) { + DLOG << "output的 output 包含 input" << input_of_output; + can_add_split = false; + break; + } + } + } + } else { + DLOG << "找不到 这个 op 类型: " << output->op_desc_->type_; + can_add_split = false; + } + } + } + if (inputs_.size() > 1 && node != inputs_.back()) { return; } else if (inputs_.size() > 1 && node == inputs_.back()) { + adding_thread = false; op_desc->push_back(this->op_desc_); } else { op_desc->push_back(this->op_desc_); } + if (adding_thread) { + Attribute attr; + attr.Set(thread_num); + this->op_desc_->attrs_["thread"] = attr; + } - for (auto &output : outputs_) { - output->OpDescs(op_desc, this); + if (can_add_split) { + adding_thread = true; + std::shared_ptr split_op_desc = + std::make_shared(); + split_op_desc->type_ = G_OP_TYPE_SPLIT; + auto outputs = this->op_desc_->Output( + op_input_output_key[this->op_desc_->Type()].second[0]); + + split_op_desc->inputs_ = { + {op_input_output_key[G_OP_TYPE_SPLIT].first[0], outputs}}; + auto &split_outputs = + split_op_desc->outputs_[op_input_output_key[G_OP_TYPE_SPLIT].second[0]]; + for (const auto &output : outputs_) { + split_outputs.push_back(outputs[0]); + } + DLOG << "add split"; + op_desc->push_back(split_op_desc); + } + + for (int i = 0; i < outputs_.size(); ++i) { + auto &output = outputs_[i]; + if (can_add_split) { + output->OpDescs(op_desc, this, adding_thread, i); + } else { + output->OpDescs(op_desc, this, adding_thread, thread_num); + } } } std::vector> Node::OpDescs() { std::vector> op_descs; - OpDescs(&op_descs, this); + OpDescs(&op_descs, this, false, 0); return op_descs; } diff --git a/src/framework/program/program-optimize/node.h b/src/framework/program/program-optimize/node.h index da9a7ef56941dde301c10d1552d52d6c600b2bfe..5dd1a3acbf5e662901bf7591de5f12cc7f47ef76 100644 --- a/src/framework/program/program-optimize/node.h +++ b/src/framework/program/program-optimize/node.h @@ -42,13 +42,13 @@ class Node : PaddleMobileObject { std::map> change_map); std::vector> OpDescs(uint size); std::vector> OpDescs(); - void OpDescs(std::vector> *op_desc, - Node *node); std::shared_ptr OpDesc() { return op_desc_; } std::string BeginType() { return type_; } void Description(); private: + void OpDescs(std::vector> *op_desc, + Node *node, bool adding_thread, int thread_num); void OpDescs(uint size, std::vector> *op_desc); void To(int index, std::shared_ptr); diff --git a/src/framework/program/program-optimize/program_optimize.cpp b/src/framework/program/program-optimize/program_optimize.cpp index fd7edeed1b60285850879666a1a34fbf4004472f..737fed9bd56bdec92774ba364e035ba581258e57 100644 --- a/src/framework/program/program-optimize/program_optimize.cpp +++ b/src/framework/program/program-optimize/program_optimize.cpp @@ -19,7 +19,7 @@ namespace paddle_mobile { namespace framework { -std::shared_ptr ProgramOptimize::Optimize() {} +// std::shared_ptr ProgramOptimize::Optimize() {} std::shared_ptr ProgramOptimize::FushionOptimize( std::shared_ptr ori_des) { @@ -86,7 +86,7 @@ std::shared_ptr ProgramOptimize::FushionOptimize( // DLOG << " match success " << " fusion node: \n" << // matcher->BeginNode() << "\nsub node: \n" << *sub_node; // DLOG << "match node\n"<< *match_node; - matcher->FolderNodes(*match_node); + matcher->FolderNodes(match_node.get()); // DLOG << " after match node\n"<< *match_node; // match_node->Description(); diff --git a/src/framework/program/program-optimize/program_optimize.h b/src/framework/program/program-optimize/program_optimize.h index 9dc4b19eba3476254e69ecf547472691c908452a..3839fa1e36ba0bbe580dac05af2c7ba6185f9b6c 100644 --- a/src/framework/program/program-optimize/program_optimize.h +++ b/src/framework/program/program-optimize/program_optimize.h @@ -27,7 +27,6 @@ namespace framework { class ProgramOptimize { public: ProgramOptimize() {} - std::shared_ptr Optimize(); std::shared_ptr FushionOptimize( std::shared_ptr ori_des); diff --git a/src/common/io.cpp b/src/io.cpp similarity index 91% rename from src/common/io.cpp rename to src/io.cpp index fc1466237e938a9ded5862d6cc3c7597766197e4..002e73b79648320c229786f8492f4c0e8b299d83 100644 --- a/src/common/io.cpp +++ b/src/io.cpp @@ -15,11 +15,13 @@ limitations under the License. */ #include "io.h" #include #include -#include "common/enforce.h" #include "common/log.h" + +#include "common/enforce.h" #include "framework/framework.pb-c.h" #include "framework/lod_tensor.h" #include "framework/operator.h" +#include "framework/program/program-optimize/program_optimize.h" #include "framework/program/program_desc.h" #include "framework/program/var_desc.h" #include "framework/scope.h" @@ -166,7 +168,7 @@ void Loader::LoadVar(framework::Variable *variable, template const framework::Program Loader::Load( - const std::string &dirname) { + const std::string &dirname, bool optimize) { std::string model_filename = dirname + "/__model__"; PaddleMobile__Framework__Proto__ProgramDesc *c_program; uint8_t *buf = NULL; @@ -203,7 +205,6 @@ const framework::Program Loader::Load( if (var_desc->Persistable() && var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH && var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) { - // DLOG << "to load var "; auto dim = var_desc->Tensor_desc().Dims(); auto tensor = var->GetMutable(); tensor->Resize(framework::make_ddim(dim)); @@ -219,8 +220,13 @@ const framework::Program Loader::Load( } } } + // originProgramDesc->Description("program: "); - originProgramDesc->Description("program: "); + if (optimize) { + framework::ProgramOptimize program_optimize; + program.optimizeProgram = + program_optimize.FushionOptimize(originProgramDesc); + } paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL); return program; @@ -231,33 +237,9 @@ template class Loader; #pragma mark - executor template -Executor::Executor(const framework::Program p) : program_(p) { - if (use_optimize_) { - to_predict_program_ = program_.optimizeProgram; - } else { - to_predict_program_ = program_.originProgram; - } - - const std::vector> blocks = - to_predict_program_->Blocks(); - for (int i = 0; i < blocks.size(); ++i) { - std::shared_ptr block_desc = blocks[i]; - std::vector> ops = block_desc->Ops(); - for (int j = 0; j < ops.size(); ++j) { - std::shared_ptr op = ops[j]; - auto op_base = framework::OpRegistry::CreateOp( - op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), - program_.scope); - op_base->InferShape(); - ops_of_block_[*block_desc.get()].push_back(op_base); - } - } - InitMemory(); -} - -template -Executor::Executor(const framework::Program p, int batch_size) - : program_(p), batch_size_(batch_size) { +Executor::Executor(const framework::Program p, int batch_size, + bool use_optimize) + : program_(p), batch_size_(batch_size), use_optimize_(use_optimize) { if (use_optimize_) { to_predict_program_ = program_.optimizeProgram; } else { @@ -389,7 +371,7 @@ void Executor::InitMemory() { } template -void Executor::predict(const framework::Tensor &t, int block_id) { +void Executor::Predict(const framework::Tensor &t, int block_id) { framework::Variable *g_feed_value = program_.scope->Var("feed"); framework::Tensor *feed_tensor = g_feed_value->GetMutable(); @@ -404,11 +386,11 @@ void Executor::predict(const framework::Tensor &t, int block_id) { } template -std::vector::Ptype> Executor::predict( +std::vector::Ptype> Executor::Predict( const std::vector &input, const std::vector &dims) { framework::Tensor tensor(input, framework::make_ddim(dims)); - predict(tensor, 0); + Predict(tensor, 0); framework::Variable *g_feed_value = program_.scope->Var("col"); auto feed_tensor = g_feed_value->GetMutable(); diff --git a/src/common/io.h b/src/io.h similarity index 84% rename from src/common/io.h rename to src/io.h index 678441a9e05dacf4e1f6a41705c1499c3ea99238..de2d359bf58d1ad328defd2f51e87e2d6bfe6295 100644 --- a/src/common/io.h +++ b/src/io.h @@ -30,7 +30,8 @@ namespace paddle_mobile { template class Loader : PaddleMobileObject { public: - const framework::Program Load(const std::string &dirname); + const framework::Program Load(const std::string &dirname, + bool optimize = true); private: void LoadVar(framework::Variable *variable, @@ -45,13 +46,12 @@ class Executor { Executor() = default; - Executor(const framework::Program p); + Executor(const framework::Program p, int batch_size = 1, + bool use_optimize = true); - Executor(const framework::Program p, int batch_size); + // std::shared_ptr Predict(framework::Tensor &t); - std::shared_ptr predict(framework::Tensor &t); - - std::vector predict(const std::vector &input, + std::vector Predict(const std::vector &input, const std::vector &dims); protected: @@ -61,7 +61,7 @@ class Executor { framework::Program program_; int batch_size_ = 1; std::shared_ptr to_predict_program_; - void predict(const framework::Tensor &t, int block_id); + void Predict(const framework::Tensor &t, int block_id); std::map>>> ops_of_block_; diff --git a/src/operators/conv_op.cpp b/src/operators/conv_op.cpp index 148b0f69f9633f1d82979ab324c5997fb6fcb1c1..bfddcf14acbba016c4e4333e05fcc7dd6eebc509 100644 --- a/src/operators/conv_op.cpp +++ b/src/operators/conv_op.cpp @@ -21,13 +21,6 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -int ConvOutputSize(int input_size, int filter_size, int dilation, int padding, - int stride) { - const int dkernel = dilation * (filter_size - 1) + 1; - int output_size = (input_size + 2 * padding - dkernel) / stride + 1; - return output_size; -} - template void ConvOp::InferShape() const { // std::cout << " begin get dims: " << std::endl; diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 1557f2f06eed8237f7b7e9ff44adc233129a49a3..f15f286b606db1403b0e0e609bfc38caac2c5105 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -44,5 +44,12 @@ class ConvOp : public framework::OperatorWithKernel { ConvParam param_; }; +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + return output_size; +} + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/depthwise_conv_op.cpp b/src/operators/depthwise_conv_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2538298175c5ea40d7e44338caee853a73c089c4 --- /dev/null +++ b/src/operators/depthwise_conv_op.cpp @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/depthwise_conv_op.h" +#include +#include "framework/data_type.h" +#include "framework/op_proto_maker.h" +#include "framework/op_registry.h" +#include "operators/conv_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void DepthwiseConvOp::InferShape() const { + auto in_dims = param_.Input()->dims(); + auto filter_dims = param_.Filter()->dims(); + const std::vector &strides = param_.Strides(); + std::vector paddings = param_.Paddings(); + int groups = param_.Groups(); + std::vector dilations = param_.Dilations(); + + PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && + dilations.size() == paddings.size() && + paddings.size() == strides.size()), + "ConvParam is not suitable"); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); + } + + framework::DDim ddim = framework::make_ddim(output_shape); + param_.Output()->Resize(ddim); +} + +template class DepthwiseConvOp; + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +USE_OP(depthwise_conv2d); +REGISTER_OPERATOR(depthwise_conv2d, ops::DepthwiseConvOp); diff --git a/src/operators/depthwise_conv_op.h b/src/operators/depthwise_conv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c47fa0ffcacd54a5ddf7280419ca1170173bde1b --- /dev/null +++ b/src/operators/depthwise_conv_op.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/depthwise_conv_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template +class DepthwiseConvOp : public framework::OperatorWithKernel { + public: + DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape() const override; + + void RunImpl() const { + operators::DepthwiseConvKernel kernel; + kernel.Compute(param_); + this->ClearVariables({"Filter", "Input"}); + } + + private: + ConvParam param_; +}; + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h index 39f11dd708c56c550a41545f5e4bf93b78b7fa51..1fa3399cf22df76b429d89fa89b0cb620257271f 100644 --- a/src/operators/fusion_conv_add_relu_op.h +++ b/src/operators/fusion_conv_add_relu_op.h @@ -23,18 +23,18 @@ namespace operators { class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher { public: FushionConvAddReluOpMatcher() { - node_ = framework::Node("conv2d"); - node_ > std::make_shared("elementwise_add") > - std::make_shared("relu"); + node_ = framework::Node(G_OP_TYPE_CONV); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_RELU); } void FolderNodes(framework::Node &node) { std::vector> origin_descs = node.OpDescs(node_.Depth()); - node.Folder(node_.Depth(), Type(), {{"elementwise_add", {"Y", "Z"}}}); + node.Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}); } - - std::string Type() { return "FusionConvAddRelu"; } + std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; } }; class FusionFcOp { diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h index 0ed5a2b4d5e6cfef4a152ba14596c1f591c378b3..fb49fa61b202401871b8c6c18e51b15ab42dc1e4 100644 --- a/src/operators/fusion_fc_op.h +++ b/src/operators/fusion_fc_op.h @@ -28,17 +28,18 @@ using std::vector; class FusionFcMatcher : public framework::FusionOpMatcher { public: FusionFcMatcher() { - node_ = framework::Node("mul"); - node_ > std::make_shared("elementwise_add"); + node_ = framework::Node(G_OP_TYPE_MUL); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD); } void FolderNodes(framework::Node &node) { vector> origin_descs = node.OpDescs(node_.Depth()); - node.Folder(node_.Depth(), Type(), {{"elementwise_add", {"Y", "Z"}}}); + node.Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}); } - std::string Type() { return "fc"; } + std::string Type() { return G_OP_TYPE_FC; } }; template diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 1e2572b984734dcd88be7c1c750fc0f07448e66d..f04b8156c9d3c88520b1c74b60a20f41e7fedc98 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -17,19 +17,6 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -bool IsExpand(const std::vector &filter_dim, - const std::vector &strides, const std::vector &paddings, - const std::vector &dilations) { - bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; - for (size_t j = 0; j < strides.size(); ++j) { - filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); - strides_1 = strides_1 && (strides[j] == 1); - padding_0 = padding_0 && (paddings[j] == 0); - dilation_1 = dilation_1 && (dilations[j] == 1); - } - return !(filter_1 && strides_1 && padding_0 && dilation_1); -} - template <> void ConvKernel::Compute(const ConvParam ¶m) const { LOG(kLOG_DEBUG) << param; diff --git a/src/operators/kernel/arm/depthwise_conv_kernel.cpp b/src/operators/kernel/arm/depthwise_conv_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73aa9953cfcbc8efe0ed9d3bf094455cfbb4fe6c --- /dev/null +++ b/src/operators/kernel/arm/depthwise_conv_kernel.cpp @@ -0,0 +1,126 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/kernel/depthwise_conv_kernel.h" +#include "operators/kernel/conv_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +void DepthwiseConvKernel::Compute(const ConvParam ¶m) const { + LOG(kLOG_DEBUG) << param; + + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor *output = param.Output(); + output->mutable_data(); + + int groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + DLOG << " compute end get Attrs " << strides[0]; + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + DLOG << " col_shape = " << col_shape; + DLOG << " col_matrix_shape = " << col_matrix_shape; + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + DLOG << " input_shape = " << input_shape; + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + DLOG << " filter.dims() = " << filter.dims(); + + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + DLOG << " in_batch.dims() = " << in_batch.dims(); + DLOG << " out_batch.dims() = " << out_batch.dims(); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + DLOG << " out_slice " << out_slice.dims(); + DLOG << " filter_slice " << filter_slice.dims(); + DLOG << " col_matrix " << col_matrix.dims(); + math::matmul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0)); + auto filter_ptr = filter_slice.data(); + } + } +} + +template class DepthwiseConvKernel; + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index e0badea51e7da4f3119c9303b259259ba8b48e80..586d981175184e2da03f2949390932b888d67f4a 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -25,6 +25,9 @@ struct ReluFunctor { inline T operator()(T in) const { return in > 0 ? in : 0; } }; +/* + * @b 特化到具体平台的实现, param 从 op 层传入 + * */ template <> void ReluKernel::Compute(const ReluParam ¶m) const { const auto *input_x = param.InputX(); diff --git a/src/operators/kernel/conv_kernel.h b/src/operators/kernel/conv_kernel.h index a756e2d2417cc147cb0559f946a6a70085860ecb..d43a174ffdbf0ca6dbb39e463b8e97652c7b0daf 100644 --- a/src/operators/kernel/conv_kernel.h +++ b/src/operators/kernel/conv_kernel.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "framework/operator.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" @@ -23,12 +24,28 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -using namespace framework; +using framework::OpKernelBase; template -class ConvKernel : public framework::OpKernelBase { +class ConvKernel : public OpKernelBase { public: void Compute(const ConvParam ¶m) const; }; + +inline bool IsExpand(const std::vector &filter_dim, + const std::vector &strides, + const std::vector &paddings, + const std::vector &dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/depthwise_conv_kernel.h b/src/operators/kernel/depthwise_conv_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..43ddfb25cd859a7e937577221215d8352b846bff --- /dev/null +++ b/src/operators/kernel/depthwise_conv_kernel.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "framework/operator.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +#pragma once; + +namespace paddle_mobile { +namespace operators { + +using framework::OpKernelBase; + +template +class DepthwiseConvKernel : public OpKernelBase { + public: + void Compute(const ConvParam ¶m) const; +}; +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 02bda7147aa77648cf6a159bdb11d2f3e42ee304..0ce187c084975c53e433b9428ad14bf11212a5a1 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -696,6 +696,9 @@ class ReshapeParam : public OpParam { bool inplace_; }; +/* + * @b op 层实例化好这个 param 传递给 kernel 层使用 + * */ class ReluParam : public OpParam { public: ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs, @@ -725,7 +728,6 @@ class FushionFcParam : public OpParam { y_num_col_dims_ = GetAttr("y_num_col_dims", attrs); axis_ = GetAttr("axis", attrs); } - const Tensor *InputX() const { return input_x_; } const Tensor *InputY() const { return input_y_; } diff --git a/src/operators/relu_op.cpp b/src/operators/relu_op.cpp index 5f861579ab47f09b55f8d255103558a5209fedb9..21bcc605282ffc590025e87b609cccc855a631d1 100644 --- a/src/operators/relu_op.cpp +++ b/src/operators/relu_op.cpp @@ -25,6 +25,11 @@ template class ReluOp; } // namespace operators } // namespace paddle_mobile +/* + * @b 每一个 op 都需要注册一下的, + * USE_OP的参数 和 REGISTER_OPERATOR的第一个参数 + * 都是需要和model中类型对应起来的 + * */ namespace ops = paddle_mobile::operators; USE_OP(relu); REGISTER_OPERATOR(relu, ops::ReluOp); diff --git a/src/operators/relu_op.h b/src/operators/relu_op.h index 6c3a614a1a0316e6b487532739f01bf7027557bc..7be8cd249cb22255dff237da6c8653e6237bbc3f 100644 --- a/src/operators/relu_op.h +++ b/src/operators/relu_op.h @@ -28,6 +28,9 @@ using paddle_mobile::framework::Tensor; template class ReluOp : public framework::OperatorWithKernel { public: + /* + * @b op 的实例化方法, 需要调用父类的实例化方法, 以及实例化自己的参数结构体 + * */ ReluOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) @@ -35,6 +38,9 @@ class ReluOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} + /* + * @b op 进行运算, 调用相应的 kernel 进行运算 + * */ void RunImpl() const { operators::ReluKernel kernel; kernel.Compute(param_); @@ -44,6 +50,10 @@ class ReluOp : public framework::OperatorWithKernel { void InferShape() const override; protected: + /* + * @b Relu kernel 进行运算时所需要用到参数的结构体, + * 结构体定义在: paddle-mobile/src/operators/op_param.h + * */ ReluParam param_; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c80d34c22e566df0397105bba75022218e5f85f9..37c0de1496bbc272a56abebe43516c2da4250fbf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -99,3 +99,7 @@ target_link_libraries(test-mobilenet paddle-mobile) # gen test ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h) target_link_libraries(test-sigmoid paddle-mobile) + +# gen test +ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h) +target_link_libraries(test-depthwise-conv-op paddle-mobile) diff --git a/test/executor_for_test.h b/test/executor_for_test.h index c69eba222fbe39d6627a0f03bf1621e7db4d491e..ce3c84e986eb7ef5e9602209cedb3dbabbf06e85 100644 --- a/test/executor_for_test.h +++ b/test/executor_for_test.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include -#include "common/io.h" #include "common/log.h" #include "framework/op_registry.h" +#include "io.h" #include "operators/conv_op.h" #include "operators/elementwise_add_op.h" #include "operators/pool_op.h" @@ -73,10 +73,11 @@ class Executor4Test : public Executor { } } } + this->InitMemory(); } template - vector> predict(const vector &ts, + vector> Predict(const vector &ts, const vector &input_names, const vector &output_names, const vector &ddims) { @@ -115,7 +116,7 @@ class Executor4Test : public Executor { return output_tensor_sptrs; } - std::shared_ptr predict(const Tensor &t, string input, string output, + std::shared_ptr Predict(const Tensor &t, string input, string output, const DDim &dDim) { auto scope = this->program_.scope; Variable *g_feed_value = scope->Var(input); diff --git a/test/framework/test_load.cpp b/test/framework/test_load.cpp index fe403b55a180d3094446d7254c56d35e1b60edeb..95357547e1b93d3060481b55eaf46c919496785d 100644 --- a/test/framework/test_load.cpp +++ b/test/framework/test_load.cpp @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "common/io.h" +#include "../test_helper.h" +#include "io.h" int main() { paddle_mobile::Loader loader; - // ../../../test/models/googlenet // ../../../test/models/mobilenet - auto program = loader.Load(std::string("../models/googlenet")); + auto program = loader.Load(g_googlenet); + program.optimizeProgram->Description("program desc: "); return 0; } diff --git a/test/framework/test_optimize.cpp b/test/framework/test_optimize.cpp index 4c4dc6eb3ee8babc812584e766ab2f4eb3580160..f0392cfec02c8ea764cd3d6dc9f50b2415c39e2c 100644 --- a/test/framework/test_optimize.cpp +++ b/test/framework/test_optimize.cpp @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "common/io.h" +#include "../test_helper.h" #include "framework/program/program-optimize/node.h" #include "framework/program/program-optimize/program_optimize.h" +#include "io.h" int main() { paddle_mobile::Loader loader; // "../../../test/models/googlenet" - auto program = loader.Load("../models/googlenet"); + auto program = loader.Load(g_googlenet); paddle_mobile::framework::ProgramOptimize optimize; // program.originProgram->Description("origin"); auto optimize_program = optimize.FushionOptimize(program.originProgram); diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index d52f080277aceb009c887be7e149df0aafb93b7c..139579e9116651c15764997d962b7d2622532146 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -21,16 +21,16 @@ int main() { // ../../../test/models/googlenet // ../../../test/models/mobilenet auto time1 = time(); - auto program = loader.Load(std::string("../models/googlenet")); + auto program = loader.Load(g_googlenet, false); auto time2 = time(); DLOG << "load cost :" << time_diff(time1, time1) << "ms"; - paddle_mobile::Executor executor(program, 1); + paddle_mobile::Executor executor(program, 1, false); std::vector input; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); auto time3 = time(); - executor.predict(input, dims); + executor.Predict(input, dims); auto time4 = time(); DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; return 0; diff --git a/test/net/test_mobilenet.cpp b/test/net/test_mobilenet.cpp index e686ad85be767c0000a74d7bc11add89d299226e..b5d925227e4ccf8925440ee36c3f9a6e02567f91 100644 --- a/test/net/test_mobilenet.cpp +++ b/test/net/test_mobilenet.cpp @@ -19,10 +19,10 @@ limitations under the License. */ int main() { paddle_mobile::Loader loader; auto time1 = time(); - auto program = loader.Load(g_mobilenet); + auto program = loader.Load(g_mobilenet, false); auto time2 = time(); DLOG << "load cost :" << time_diff(time1, time1) << "ms"; - paddle_mobile::Executor executor(program, 1); + paddle_mobile::Executor executor(program, 1, false); std::vector dims{1, 3, 224, 224}; Tensor input_tensor; @@ -32,7 +32,7 @@ int main() { std::vector input(input_tensor.data(), input_tensor.data() + input_tensor.numel()); auto time3 = time(); - executor.predict(input, dims); + executor.Predict(input, dims); auto time4 = time(); DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; return 0; diff --git a/test/net/test_yolo.cpp b/test/net/test_yolo.cpp index ab61fb250e3083a106bc6967a29e145e89606c74..c82443e23953def917826fe4ec3b2c484b588f59 100644 --- a/test/net/test_yolo.cpp +++ b/test/net/test_yolo.cpp @@ -21,10 +21,10 @@ int main() { // ../../../test/models/googlenet // ../../../test/models/mobilenet auto time1 = time(); - auto program = loader.Load(g_yolo); + auto program = loader.Load(g_yolo, false); auto time2 = time(); DLOG << "load cost :" << time_diff(time1, time1) << "ms"; - paddle_mobile::Executor executor(program, 1); + paddle_mobile::Executor executor(program, 1, false); std::vector dims{1, 3, 227, 227}; Tensor input_tensor; @@ -34,7 +34,7 @@ int main() { std::vector input(input_tensor.data(), input_tensor.data() + input_tensor.numel()); auto time3 = time(); - executor.predict(input, dims); + executor.Predict(input, dims); auto time4 = time(); DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; return 0; diff --git a/test/operators/test_batchnorm_op.cpp b/test/operators/test_batchnorm_op.cpp index ba2e06b80b418b62d2dc445fe87119ed84bfe4f6..38d9f624909fd645c78ae56a5d9efff9fa961795 100644 --- a/test/operators/test_batchnorm_op.cpp +++ b/test/operators/test_batchnorm_op.cpp @@ -128,8 +128,7 @@ int main() { DLOG << "----------**********----------"; DLOG << "begin to run BatchNormOp Test"; paddle_mobile::Loader loader; - auto program = loader.Load(std::string( - "../../test/models/image_classification_resnet.inference.model")); + auto program = loader.Load(std::string(g_resnet)); /// input x (4,10,2,2) paddle_mobile::framework::Tensor inputx1; diff --git a/test/operators/test_box_coder_op.cpp b/test/operators/test_box_coder_op.cpp index b7695c91dfb394645adfddcf1e11b96dd45a3c94..dac0d0b8051ec1790d6982a13ea31ef3f4a64242 100644 --- a/test/operators/test_box_coder_op.cpp +++ b/test/operators/test_box_coder_op.cpp @@ -116,7 +116,7 @@ int main() { DLOG << "----------**********----------"; DLOG << "begin to run BoxCoderOp Test"; paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + auto program = loader.Load(std::string(g_mobilenet_ssd)); paddle_mobile::framework::Tensor priorbox; SetupTensor(&priorbox, {1917, 4}, static_cast(0), diff --git a/test/operators/test_concat_op.cpp b/test/operators/test_concat_op.cpp index a9bb072f1e941d15a058825b14fb007507f4d610..7a106b03c44c57fa7ef0f9282434717efd602b5c 100644 --- a/test/operators/test_concat_op.cpp +++ b/test/operators/test_concat_op.cpp @@ -57,7 +57,7 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({3, 100, 2, 2}); out_ddims.push_back(out_ddim); - auto output = executor.predict(input_tensors, input_names, + auto output = executor.Predict(input_tensors, input_names, output_names, out_ddims); auto output0_data = output[0]->data(); diff --git a/test/operators/test_cov_op.cpp b/test/operators/test_cov_op.cpp index 2fe7f3577bef42d26c349e9a24313518c05b9d2b..ba6a9b4800f8b2acb3a5c3b0992128bd4ea0e619 100644 --- a/test/operators/test_cov_op.cpp +++ b/test/operators/test_cov_op.cpp @@ -34,7 +34,7 @@ int main() { // static_cast(1)); auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112}); - auto output = executor.predict(input, "data", "conv2d_0.tmp_0", out_ddim); + auto output = executor.Predict(input, "data", "conv2d_0.tmp_0", out_ddim); auto output_ptr = output->data(); for (int j = 0; j < output->numel(); ++j) { diff --git a/test/operators/test_depthwise_conv_op.cpp b/test/operators/test_depthwise_conv_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..648b4c5db9970804a2ca140eef13e2560e36f935 --- /dev/null +++ b/test/operators/test_depthwise_conv_op.cpp @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../executor_for_test.h" +#include "../test_include.h" +#include "operators/depthwise_conv_op.h" + +int main() { + paddle_mobile::Loader loader; + // ../models/image_classification_resnet.inference.model + auto program = loader.Load(g_mobilenet_ssd); + + PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, + "program file read fail"); + + Executor4Test> + executor(program, "depthwise_conv2d"); + + paddle_mobile::framework::LoDTensor input; + // GetInput(g_test_image_1x3x224x224, &input, {1, 3, 224, 224}); + // use SetupTensor if not has local input image . + SetupTensor(&input, {1, 32, 150, 150}, static_cast(0), + static_cast(1)); + auto input_ptr = input.data(); + auto out_ddim = paddle_mobile::framework::make_ddim({1, 32, 150, 150}); + auto output = executor.Predict(input, "batch_norm_0.tmp_3", + "depthwise_conv2d_0.tmp_0", out_ddim); + + auto output_ptr = output->data(); + for (int j = 0; j < output->numel(); ++j) { + DLOG << " value of output: " << output_ptr[j]; + } + return 0; +} diff --git a/test/operators/test_elementwise_add_op.cpp b/test/operators/test_elementwise_add_op.cpp index 1b4bf457a2ca7d4207ce3f9f0b20d68ee3f463e0..c4997f2eb37730e1af38fbe8aac927e7ee2b6ee0 100644 --- a/test/operators/test_elementwise_add_op.cpp +++ b/test/operators/test_elementwise_add_op.cpp @@ -50,7 +50,7 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({1, 3, 224, 224}); out_ddims.push_back(out_ddim); - auto output = executor.predict(input_tensors, input_names, + auto output = executor.Predict(input_tensors, input_names, output_names, out_ddims); auto output0_data = output[0]->data(); diff --git a/test/operators/test_fushion_fc_op.cpp b/test/operators/test_fushion_fc_op.cpp index 6063772d85a32af7cac166c9682a5c1e2d8ad1de..8dc1b02bec403d13b0b18f3fad58d7686ce403d0 100644 --- a/test/operators/test_fushion_fc_op.cpp +++ b/test/operators/test_fushion_fc_op.cpp @@ -116,7 +116,7 @@ int main() { DLOG << "begin to run Fc Test"; paddle_mobile::Loader loader; // "../../../test/models/googlenet" - auto program = loader.Load("../models/googlenet"); + auto program = loader.Load(g_googlenet); paddle_mobile::framework::ProgramOptimize optimize; // program.originProgram->Description("origin"); auto optimize_program = optimize.FushionOptimize(program.originProgram); diff --git a/test/operators/test_lrn_op.cpp b/test/operators/test_lrn_op.cpp index ba35639fb71668eef8d6b7bae454af5a9120a015..cf5fd4bdf2d45abcf63eb865f1cf333eeb14eafc 100644 --- a/test/operators/test_lrn_op.cpp +++ b/test/operators/test_lrn_op.cpp @@ -46,7 +46,7 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({3, 4, 2, 2}); out_ddims.push_back(out_ddim); - auto output = executor.predict(input_tensors, input_names, + auto output = executor.Predict(input_tensors, input_names, output_names, out_ddims); auto output0_data = output[0]->data(); diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 8acd4a99470b494df3a8931cfb3d140fdc39c4f0..5412e6905b7c12782555c7271c5da17713561469 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -50,7 +50,7 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({3, 3}); out_ddims.push_back(out_ddim); - auto output = executor.predict(input_tensors, input_names, + auto output = executor.Predict(input_tensors, input_names, output_names, out_ddims); auto output0_data = output[0]->data(); diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 8a1c0a7ccecb7fe392428df7dbe5fb979a6cd510..62dfc20dc12006f86b16997cb6de96123e10ee9c 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -14,11 +14,11 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "common/io.h" +#include "io.h" int main() { paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../models/googlenet")); + auto program = loader.Load(std::string(g_googlenet)); if (program.originProgram == nullptr) { DLOG << "program read file"; } @@ -32,7 +32,7 @@ int main() { static_cast(1)); auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 56, 56}); auto output = - executor.predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim); + executor.Predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim); float *output_ptr = output->data(); for (int j = 0; j < output->numel(); ++j) { diff --git a/test/operators/test_prior_box_op.cpp b/test/operators/test_prior_box_op.cpp index 80ede944936cb5ae31e2ed7e1e70c1257746149a..8c697a9a7982f05b71caa5bb5f4d12e50dc9d418 100644 --- a/test/operators/test_prior_box_op.cpp +++ b/test/operators/test_prior_box_op.cpp @@ -127,7 +127,7 @@ int main() { DLOG << "----------**********----------"; DLOG << "begin to run PriorBoxOp Test"; paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + auto program = loader.Load(std::string(g_mobilenet_ssd)); /// input x (1,3,300,300) paddle_mobile::framework::Tensor input_image; diff --git a/test/operators/test_relu_op.cpp b/test/operators/test_relu_op.cpp index fb68b9211136e4272f6774a423f93f8f1087b6e7..50f3b6a20b6244fcb39975c80cc6a6e14dc88d1c 100644 --- a/test/operators/test_relu_op.cpp +++ b/test/operators/test_relu_op.cpp @@ -46,7 +46,7 @@ int main() { auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4}); out_ddims.push_back(out_ddim); - auto output = executor.predict(input_tensors, input_names, + auto output = executor.Predict(input_tensors, input_names, output_names, out_ddims); auto output0_data = output[0]->data(); diff --git a/test/operators/test_reshape_op.cpp b/test/operators/test_reshape_op.cpp index b0251e693a736a934b771ee7d381c9b834e58528..5448aac87c23ea90f5b8beec24aee9cc6f437330 100644 --- a/test/operators/test_reshape_op.cpp +++ b/test/operators/test_reshape_op.cpp @@ -14,11 +14,11 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "common/io.h" +#include "io.h" int main() { paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + auto program = loader.Load(std::string(g_mobilenet_ssd)); if (program.originProgram == nullptr) { DLOG << "program read file"; } @@ -31,7 +31,7 @@ int main() { auto input_ptr = input.data(); auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2}); auto output = - executor.predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim); + executor.Predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim); auto *output_ptr = output->data(); DLOG << "input : "; diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp index dcd35cd8e468612b71947c283dd37156f33570fa..289eac149fa2d3e05f65624f8a9e5f93e85c6fff 100644 --- a/test/operators/test_sigmoid_op.cpp +++ b/test/operators/test_sigmoid_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "../../src/operators/kernel/sigmoid_kernel.h" #include "../test_helper.h" -#include "common/io.h" +#include "io.h" int main() { paddle_mobile::framework::Tensor input; diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp index 094c48adbb691a9f0a2f030f6a224fe2b452372a..58de5300cca0bf367652066851bc4e7e9f75389c 100644 --- a/test/operators/test_softmax_op.cpp +++ b/test/operators/test_softmax_op.cpp @@ -14,11 +14,11 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "common/io.h" +#include "io.h" int main() { paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../models/mobilenet")); + auto program = loader.Load(std::string(g_mobilenet)); if (program.originProgram == nullptr) { DLOG << "program read file"; } @@ -30,7 +30,7 @@ int main() { static_cast(1)); auto out_ddim = paddle_mobile::framework::make_ddim({1, 1000}); auto output = - executor.predict(input, "reshape_0.tmp_0", "softmax_0.tmp_0", out_ddim); + executor.Predict(input, "reshape_0.tmp_0", "softmax_0.tmp_0", out_ddim); auto *output_ptr = output->data(); for (int j = 0; j < output->numel(); ++j) { DLOG << " value of output: " << output_ptr[j]; diff --git a/test/operators/test_transpose_op.cpp b/test/operators/test_transpose_op.cpp index 23e3bc3ec475655cc41b0faff119d21c5c904900..4c88df2d83dcfbc44915ced815b50f90ddb33b38 100644 --- a/test/operators/test_transpose_op.cpp +++ b/test/operators/test_transpose_op.cpp @@ -14,11 +14,11 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "common/io.h" +#include "io.h" int main() { paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + auto program = loader.Load(std::string(g_mobilenet_ssd)); if (program.originProgram == nullptr) { DLOG << "program read file"; } @@ -31,7 +31,7 @@ int main() { auto input_ptr = input.data(); auto out_ddim = paddle_mobile::framework::make_ddim({1, 3, 4, 2}); auto output = - executor.predict(input, "conv2d_22.tmp_1", "transpose_0.tmp_0", out_ddim); + executor.Predict(input, "conv2d_22.tmp_1", "transpose_0.tmp_0", out_ddim); auto *output_ptr = output->data(); DLOG << "input : "; diff --git a/test/test_include.h b/test/test_include.h index 19a9bff8846423ade8d3c8869d4014876d2f11ce..25efbb9f4c00921495a5ab054acdde329c4ef58a 100644 --- a/test/test_include.h +++ b/test/test_include.h @@ -20,7 +20,6 @@ limitations under the License. */ #include "./test_helper.h" #include "common/enforce.h" -#include "common/io.h" #include "common/log.h" #include "framework/lod_tensor.h" #include "framework/operator.h" @@ -30,3 +29,4 @@ limitations under the License. */ #include "framework/scope.h" #include "framework/tensor.h" #include "framework/variable.h" +#include "io.h"