提交 de2a72a4 编写于 作者: S sangoly

add elementwise_add_activation fuse pass & op

上级 09b35192
......@@ -7,7 +7,8 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
cc_library(mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
......@@ -82,7 +83,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
lite_cc_test(test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
......@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
void ConvElementwiseAddActivationFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fusion::ConvElementwiseAddActivationFuser fuser("conv2d", "relu");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
fusion::ConvElementwiseAddActivationFuser depthwise_fuser("depthwise_conv2d",
"relu");
depthwise_fuser(graph.get());
}
......@@ -35,5 +36,5 @@ void ConvElementwiseAddReLUFusePass::Apply(
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
REGISTER_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ConvElementwiseAddActivationFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......@@ -135,11 +135,11 @@ TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
auto* fuser = new ConvElementwiseAddActivationFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
ASSERT_EQ(graph->nodes().size(),
num_nodes - 5UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/);
}
} // namespace fusion
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ElementwiseAddActivationFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ElementwiseAddActivationFuser fuser("relu");
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass);
......@@ -22,7 +22,7 @@ namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
class ElementwiseAddActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* add_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("x_1");
main_block->Var("y_1");
main_block->Var("add_out_1");
main_block->Var("relu_out_1");
main_block->Var("y_2");
main_block->Var("add_out_2");
main_block->Var("out");
scope->Var("x_1")->GetMutable<lite::Tensor>();
scope->Var("y_1")->GetMutable<lite::Tensor>();
scope->Var("add_out_1")->GetMutable<lite::Tensor>();
scope->Var("relu_out_1")->GetMutable<lite::Tensor>();
scope->Var("y_2")->GetMutable<lite::Tensor>();
scope->Var("add_out_2")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"x_1"});
add_1->SetInput("Y", {"y_1"});
add_1->SetOutput("Out", {"add_out_1"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_out_1"});
relu_1->SetOutput("Out", {"relu_out_1"});
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"relu_out_1"});
add_2->SetInput("Y", {"y_2"});
add_2->SetOutput("Out", {"add_out_2"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_out_2"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(elementwise_add_activation_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(),
7UL /*vars*/ + 4UL /*ops*/ + 1UL /* SSAGraph tmp node*/);
}
TEST(elementwise_add_activation_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ElementwiseAddActivationFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(fusion_elementwise_add_activation);
USE_LITE_OP(relu);
cc_library(fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
cc_library(fuse_conv_elementwise_add_activation
SRCS conv_elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
fuse_conv_elementwise_add_relu
fuse_conv_elementwise_add_activation
fuse_conv_bn
fuse_elementwise_add_activation
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
#include <memory>
#include <vector>
......@@ -21,7 +21,7 @@ namespace lite {
namespace mir {
namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() {
void ConvElementwiseAddActivationFuser::BuildPattern() {
// create input nodes.
auto* input =
VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
......@@ -36,7 +36,8 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out")
......@@ -45,22 +46,23 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
std::vector<PMNode*> add_inputs{conv2d_out, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out;
*add_out >> *act >> *out;
}
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
......@@ -76,7 +78,8 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc;
......@@ -98,6 +101,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
// TODO(sangoly): support other activation types
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddActivationFuser : public FuseBase {
public:
explicit ConvElementwiseAddActivationFuser(const std::string& conv_type,
const std::string& act_type) {
CHECK(act_type == "relu") << "Only relu activation be supported now";
conv_type_ = conv_type;
act_type_ = act_type;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
std::string act_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ElementwiseAddActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> add_inputs{x, y};
add_inputs >> *add >> *add_out;
*add_out >> *act >> *out;
}
void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op;
auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("y"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
const key2nodes_t& matched) {
auto* desc = matched.at("add")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType("fusion_elementwise_add_activation");
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetInput("Y", {matched.at("y")->arg()->name});
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
op_desc.SetAttr("axis", desc->GetAttr<int>("axis"));
op_desc.SetAttr("act_type", act_type_);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -23,16 +23,16 @@ namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase {
class ElementwiseAddActivationFuser : public FuseBase {
public:
explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
: conv_type_(conv_type) {}
explicit ElementwiseAddActivationFuser(const std::string& act_type)
: act_type_(act_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
std::string act_type_;
};
} // namespace fusion
......
......@@ -34,4 +34,5 @@ USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
......@@ -49,8 +49,11 @@ class Optimizer {
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_conv_elementwise_add_activation_fuse_pass", //
"lite_fc_fuse_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
#endif
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
......
......@@ -14,6 +14,7 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS})
cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
cc_library(fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite ${op_DEPS})
cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS})
cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
......@@ -36,6 +37,7 @@ set(ops_lite
fetch_op_lite
io_copy_op_lite
elementwise_ops_lite
fusion_elementwise_activation_ops_lite
mean_op_lite
fill_constant_op_lite
activation_ops_lite
......@@ -56,3 +58,6 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
lite_cc_test(test_fusion_elementwise_activation_ops_lite
SRCS fusion_elementwise_activation_ops_test.cc
DEPS fusion_elementwise_activation_ops_lite memory_lite)
......@@ -12,31 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/operators/elementwise_ops.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class ElementwiseOp : public OpLite {
public:
explicit ElementwiseOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out);
return true;
}
}
bool InferShape() const override {
bool ElementwiseOp::InferShape() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims());
return true;
}
}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto X_name = opdesc.Input("X").front();
auto Y_name = opdesc.Input("Y").front();
auto Out_name = opdesc.Output("Out").front();
......@@ -46,36 +42,25 @@ class ElementwiseOp : public OpLite {
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.axis = opdesc.GetAttr<int>("axis");
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "elementwise_op"; }
private:
mutable operators::ElementwiseParam param_;
};
}
#ifdef LITE_WITH_X86
class ElementwiseGradExplicitOp : public OpLite {
public:
explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
bool ElementwiseGradExplicitOp::CheckShape() const {
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.X_grad);
CHECK_OR_FALSE(param_.Y_grad);
CHECK_OR_FALSE(param_.Out_grad);
return true;
}
}
bool InferShape() const override {
bool ElementwiseGradExplicitOp::InferShape() const {
param_.X_grad->Resize(param_.Out_grad->dims());
param_.Y_grad->Resize(param_.Y->dims());
return true;
}
}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_name = opdesc.Output(framework::GradVarName("X")).front();
......@@ -87,17 +72,7 @@ class ElementwiseGradExplicitOp : public OpLite {
param_.axis = opdesc.GetAttr<int>("axis");
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "elementwise_grad_explicit_op";
}
private:
mutable operators::ElementwiseGradParam param_;
};
}
#endif
} // namespace operators
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class ElementwiseOp : public OpLite {
public:
explicit ElementwiseOp(const std::string& op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "elementwise_op"; }
private:
mutable operators::ElementwiseParam param_;
};
#ifdef LITE_WITH_X86
class ElementwiseGradExplicitOp : public OpLite {
public:
explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "elementwise_grad_explicit_op";
}
private:
mutable operators::ElementwiseGradParam param_;
};
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <string>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
ElementwiseOp::AttachImpl(opdesc, scope);
param_.act_type = opdesc.GetAttr<std::string>("act_type");
// TODO(sangoly): support more activation types.
CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
return true;
}
#ifdef LITE_WITH_X86
bool FusionElementwiseActivationGradExplicitOp::AttachImpl(
const cpp::OpDesc& opdesc, lite::Scope* scope) {
ElementwiseGradExplicitOp::AttachImpl(opdesc, scope);
param_.act_type = opdesc.GetAttr<std::string>("act_type");
// TODO(sangoly): support more activation types.
CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
return true;
}
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fusion_elementwise_sub_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
#ifdef LITE_WITH_X86
REGISTER_LITE_OP(
fusion_elementwise_sub_activation_grad,
paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
#endif
REGISTER_LITE_OP(fusion_elementwise_add_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/operators/elementwise_ops.h"
namespace paddle {
namespace lite {
namespace operators {
class FusionElementwiseActivationOp : public ElementwiseOp {
public:
explicit FusionElementwiseActivationOp(const std::string& type)
: ElementwiseOp(type) {}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
std::string DebugString() const override {
return "fusion_elementwise_activation_op";
}
private:
mutable operators::FusionElementwiseActivationParam param_;
};
#ifdef LITE_WITH_X86
class FusionElementwiseActivationGradExplicitOp
: public ElementwiseGradExplicitOp {
public:
explicit FusionElementwiseActivationGradExplicitOp(const std::string& type)
: ElementwiseGradExplicitOp(type) {}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
std::string DebugString() const override {
return "fusion_elementwise_activation_grad_explicit_op";
}
private:
mutable operators::FusionElementwiseActivationGradParam param_;
};
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(fusion_elementwise_activation_op_lite, test) {
// prepare variables
lite::Scope scope;
auto* x = scope.Var("x")->GetMutable<lite::Tensor>();
auto* y = scope.Var("y")->GetMutable<lite::Tensor>();
auto* out = scope.Var("out")->GetMutable<lite::Tensor>();
x->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
y->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
out->Resize(lite::DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
y->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
out->mutable_data<float>()[i] = 0.;
}
// prepare op desc
cpp::OpDesc desc;
desc.SetType("fusion_elementwise_add_activation");
desc.SetInput("X", {"x"});
desc.SetInput("Y", {"y"});
desc.SetOutput("Out", {"out"});
desc.SetAttr("axis", static_cast<int>(1));
desc.SetAttr("act_type", std::string("relu"));
FusionElementwiseActivationOp fuse_op("fusion_elementwise_add_activation");
fuse_op.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
fuse_op.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -219,6 +219,14 @@ struct ElementwiseGradParam {
int axis{-1}; // for broadcasting.
};
struct FusionElementwiseActivationParam : public ElementwiseParam {
std::string act_type;
};
struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
std::string act_type;
};
/// ----------------------- activation operators ----------------------
struct ActivationParam {
const lite::Tensor* X{};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册