提交 c68dd38e 编写于 作者: E eclipsess

conv add bn relu

上级 8f97ff2f
......@@ -138,9 +138,21 @@ class OpKernelBase {
* @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体,
* 所有结构体存在与: paddle-mobile/src/operators/op_param.h
* */
#ifdef PADDLE_MOBILE_MALI_GPU
OpKernelBase() { acl_op_ = nullptr; }
void *GetAclOp() const { return acl_op_; }
void SetAclOp(void *op, void *ob) const {
reinterpret_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
}
#endif
virtual void Compute(const P &para) const = 0;
virtual bool Init(const P &para) const { return true; };
virtual ~OpKernelBase() = default;
private:
#ifdef PADDLE_MOBILE_MALI_GPU
void *acl_op_;
#endif
};
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/program/program-optimize/node.h"
#include <algorithm>
#include "framework/operator.h"
namespace paddle_mobile {
......@@ -92,7 +93,8 @@ int Node::Depth(int begin) {
Node &Node::Folder(
int size, std::string type,
std::map<std::string, std::pair<std::string, std::string>> change,
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
change,
std::vector<std::shared_ptr<Node>> *removed_nodes) {
std::shared_ptr<framework::OpDesc> op_desc =
std::make_shared<framework::OpDesc>();
......@@ -109,12 +111,15 @@ Node &Node::Folder(
void Node::Folder(
std::shared_ptr<framework::OpDesc> op_desc,
std::vector<std::shared_ptr<Node>> *outputs, int index,
std::map<std::string, std::pair<std::string, std::string>> *change,
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
*change,
Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes) {
if (change->find(this->type_) != change->end()) {
auto change_pair = (*change)[this->type_];
op_desc->GetInputs()[change_pair.second] =
this->op_desc_->GetInputs()[change_pair.first];
auto change_pairs = (*change)[this->type_];
for (const auto &change_pair : change_pairs) {
op_desc->GetInputs()[change_pair.second] =
this->op_desc_->GetInputs()[change_pair.first];
}
}
for (auto &attr_pair : this->op_desc_->attrs_) {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cinttypes>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "common/log.h"
#include "framework/program/op_desc.h"
......@@ -43,7 +44,8 @@ class Node {
int Depth(int begin = 0);
Node &Folder(
int size, std::string type,
std::map<std::string, std::pair<std::string, std::string>> change_map,
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
change,
std::vector<std::shared_ptr<Node>> *removed_nodes);
std::vector<std::shared_ptr<framework::OpDesc>> OpDescs(int size);
std::shared_ptr<framework::OpDesc> OpDescOfNode() { return op_desc_; }
......@@ -56,7 +58,8 @@ class Node {
void Folder(
std::shared_ptr<framework::OpDesc> op_desc,
std::vector<std::shared_ptr<Node>> *outputs, int index,
std::map<std::string, std::pair<std::string, std::string>> *change,
std::map<std::string, std::vector<std::pair<std::string, std::string>>>
*change,
Node *begin_node, std::vector<std::shared_ptr<Node>> *removed_nodes);
std::shared_ptr<framework::OpDesc> op_desc_;
#ifdef PADDLE_MOBILE_DEBUG
......
......@@ -50,6 +50,8 @@ USE_OP_CPU(feed);
REGISTER_OPERATOR_CPU(feed, ops::FeedOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(feed);
REGISTER_OPERATOR_MALI_GPU(feed, ops::FeedOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
......
......@@ -50,6 +50,8 @@ USE_OP_CPU(fetch);
REGISTER_OPERATOR_CPU(fetch, ops::FetchOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(fetch);
REGISTER_OPERATOR_MALI_GPU(fetch, ops::FetchOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
......
......@@ -54,6 +54,8 @@ USE_OP_CPU(conv_add);
REGISTER_OPERATOR_CPU(conv_add, ops::FusionConvAddOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(conv_add);
REGISTER_OPERATOR_MALI_GPU(conv_add, ops::FusionConvAddOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
......
......@@ -40,7 +40,7 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes);
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_CONV_ADD; }
......@@ -68,11 +68,13 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
};
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher());
#define CONV_ADD_REGISTER
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
......
......@@ -36,7 +36,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Y"}}}, removed_nodes);
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; }
};
......@@ -65,11 +65,11 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_RELU_REGISTER
#define CONV_ADD_RELU_REGISTER
//#ifndef CONV_ADD_RELU_REGISTER
//#define CONV_ADD_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_relu_registrar(new
// FusionConvAddReluOpMatcher());
#endif
//#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
......
......@@ -38,7 +38,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}, removed_nodes);
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Z"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FC; }
......@@ -66,17 +66,21 @@ class FusionFcOp
};
#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_CPU_REGISTER
#define CONV_CPU_REGISTER
static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#ifndef CONV_CPU_REGISTER
#define CONV_CPU_REGISTER
static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
#endif
#endif
#ifdef PADDLE_MOBILE_FPGA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册