未验证 提交 2a344823 编写于 作者: W Wilber 提交者: GitHub

add eltwise_activate fuse. test=develop (#3367)

* add eltwise_activate_fuse. test=develop
上级 06f77998
......@@ -37,7 +37,7 @@ USE_MIR_PASS(identity_dropout_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -62,6 +63,52 @@ __global__ void elementwise_relu_kernel(const size_t total,
}
}
template <typename Dtype>
__global__ void elementwise_abs_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
Dtype temp;
#if __CUDA_ARCH__ >= 350
temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
temp = binary_calc(x_data[tid], y_data[idx], type);
#endif
out_data[tid] = temp > 0 ? temp : -temp;
}
}
template <typename Dtype>
__global__ void elementwise_tanh_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
Dtype temp;
#if __CUDA_ARCH__ >= 350
temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
temp = binary_calc(x_data[tid], y_data[idx], type);
#endif
out_data[tid] = tanh(temp);
}
}
template <typename Dtype>
__global__ void elementwise_add_kernel(const size_t total,
const Dtype* x_data,
......@@ -135,19 +182,30 @@ void elementwise(const Dtype* x_data,
}
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
void elementwise_act(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
std::string act,
BinaryOperation type,
cudaStream_t stream) {
int num = pre * n * post;
int thread = 256;
int block = (num + thread - 1) / thread;
if (act == "relu") {
elementwise_relu_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
} else if (act == "tanh") {
elementwise_tanh_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
} else if (act == "abs") {
elementwise_abs_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
} else {
LOG(FATAL) << "not supported activate type: " << act;
}
}
template void elementwise(const float*,
......@@ -159,14 +217,15 @@ template void elementwise(const float*,
BinaryOperation,
cudaStream_t);
template void elementwise_relu(const float*,
const float*,
float*,
int,
int,
int,
BinaryOperation,
cudaStream_t);
template void elementwise_act(const float* x_data,
const float* y_data,
float* out_data,
int pre,
int n,
int post,
std::string act,
BinaryOperation type,
cudaStream_t stream);
template <typename Dtype>
void elementwise_add(int num,
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
......@@ -33,12 +34,13 @@ void elementwise(const Dtype* x_data,
cudaStream_t stream);
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
void elementwise_act(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
std::string act,
BinaryOperation type,
cudaStream_t stream);
......
......@@ -22,20 +22,31 @@ namespace paddle {
namespace lite {
namespace mir {
void ElementwiseAddActivationFusePass::Apply(
void ElementwiseActivationFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ElementwiseAddActivationFuser fuser("relu");
// initialze fuser params
std::vector<std::string> elt_types{
"elementwise_add", "elementwise_sub", "elementwise_mul"};
std::vector<std::string> act_types{"relu", "abs", "tanh"};
// start fuse using params
for (auto elt_type : elt_types) {
for (auto act_type : act_types) {
fusion::ElementwiseActivationFuser fuser(elt_type, act_type);
fuser(graph.get());
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass)
REGISTER_MIR_PASS(lite_elementwise_activation_fuse_pass,
paddle::lite::mir::ElementwiseActivationFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
.ExcludeTargets({TARGET(kBM)})
.ExcludeTargets({TARGET(kX86)})
.BindKernel("fusion_elementwise_add_activation");
.BindKernel("fusion_elementwise_add_activation")
.BindKernel("fusion_elementwise_sub_activation");
......@@ -22,7 +22,7 @@ namespace paddle {
namespace lite {
namespace mir {
class ElementwiseAddActivationFusePass : public ProgramPass {
class ElementwiseActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
......
......@@ -21,21 +21,21 @@ namespace lite {
namespace mir {
namespace fusion {
void ElementwiseAddActivationFuser::BuildPattern() {
void ElementwiseActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
auto* x = VarNode("x")->assert_is_op_input(eltwise_type_, "X")->AsInput();
auto* y = VarNode("y")->assert_is_op_input(eltwise_type_, "Y")->AsInput();
// create op nodes
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
auto* elt = OpNode("elt", eltwise_type_)
->assert_is_op(eltwise_type_)
->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
auto* elt_out = VarNode("add_out")
->assert_is_op_output(eltwise_type_, "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
......@@ -44,21 +44,29 @@ void ElementwiseAddActivationFuser::BuildPattern() {
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> add_inputs{x, y};
add_inputs >> *add >> *add_out;
*add_out >> *act >> *out;
std::vector<PMNode*> elt_inputs{x, y};
elt_inputs >> *elt >> *elt_out;
*elt_out >> *act >> *out;
}
void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op();
std::shared_ptr<lite::OpLite> op;
if (eltwise_type_ == "elementwise_add") {
op = LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
} else if (eltwise_type_ == "elementwise_sub") {
op = LiteOpRegistry::Global().Create("fusion_elementwise_sub_activation");
} else if (eltwise_type_ == "elementwise_mul") {
op = LiteOpRegistry::Global().Create("fusion_elementwise_mul_activation");
} else {
LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
}
auto old_op = matched.at("elt")->stmt()->op();
auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
......@@ -66,12 +74,20 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
const key2nodes_t& matched) {
auto* desc = matched.at("add")->stmt()->op_info();
cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("elt")->stmt()->op_info();
cpp::OpDesc op_desc;
if (eltwise_type_ == "elementwise_add") {
op_desc.SetType("fusion_elementwise_add_activation");
} else if (eltwise_type_ == "elementwise_sub") {
op_desc.SetType("fusion_elementwise_sub_activation");
} else if (eltwise_type_ == "elementwise_mul") {
op_desc.SetType("fusion_elementwise_mul_activation");
} else {
LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
}
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetInput("Y", {matched.at("y")->arg()->name});
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
......
......@@ -23,15 +23,23 @@ namespace lite {
namespace mir {
namespace fusion {
class ElementwiseAddActivationFuser : public FuseBase {
// Detect elementwise and activation ops, and then merge into
// fusion_eltsiwise_act op.
// Example:
// elementwise_add + relu fuse.
// fusion::ElementwiseActivationFuser fuser("elementwise_add", "relu");
// fuser(graph.get());
class ElementwiseActivationFuser : public FuseBase {
public:
explicit ElementwiseAddActivationFuser(const std::string& act_type)
: act_type_(act_type) {}
explicit ElementwiseActivationFuser(const std::string& eltwise_type,
const std::string& act_type)
: eltwise_type_(eltwise_type), act_type_(act_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string eltwise_type_;
std::string act_type_;
};
......
......@@ -74,7 +74,7 @@ class Optimizer {
"lite_scale_activation_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
"lite_elementwise_activation_fuse_pass", //
#endif
"__xpu__resnet_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
......
......@@ -70,7 +70,7 @@ inline bool is_broadcast(const DDim& x_dims,
return true;
}
#define ELEMENTWISE_COMPUTE(OP, WITH_RELU) \
#define ELEMENTWISE_COMPUTE(OP) \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
......@@ -85,25 +85,66 @@ inline bool is_broadcast(const DDim& x_dims,
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
#define ELEMENTWISE_COMPUTE_ACT(OP) \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
auto act = param.act_type; \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_act( \
x_data, y_data, out_data, pre, n, post, act, OP, stream); \
} else { \
lite::cuda::math::elementwise_act( \
x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream); \
}
#define ELEMENTWISE_COMPUTE_NHWC(OP) \
std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
if (axis < 0) axis = x->dims().size() - y->dims().size(); \
CHECK(axis >= 0) << "invalid axis of elementwise op"; \
axis = pos_map[axis]; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU) \
#define ELEMENTWISE_COMPUTE_ACT_NHWC(OP) \
std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
......@@ -122,80 +163,83 @@ inline bool is_broadcast(const DDim& x_dims,
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
auto act = param.act_type; \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
lite::cuda::math::elementwise_act( \
x_data, y_data, out_data, pre, n, post, act, OP, stream); \
} else { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
lite::cuda::math::elementwise_act( \
x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream); \
}
void ElementwiseAddCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false)
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false)
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
void ElementwiseAddActivationCompute::Run() {
ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kADD)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
void ElementwiseAddActivationComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kADD)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
void ElementwiseSubActivationCompute::Run() {
ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kSUB)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
void ElementwiseSubActivationComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kSUB)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulActivationCompute::Run() {
ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kMUL)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulActivationComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kMUL)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
......@@ -298,22 +342,57 @@ REGISTER_LITE_KERNEL(elementwise_mul,
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
REGISTER_LITE_KERNEL(
fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddActivationCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddActivationComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_sub_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
paddle::lite::kernels::cuda::ElementwiseSubActivationCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
REGISTER_LITE_KERNEL(
fusion_elementwise_sub_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
paddle::lite::kernels::cuda::ElementwiseSubActivationComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
......@@ -329,22 +408,24 @@ REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
REGISTER_LITE_KERNEL(
fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
paddle::lite::kernels::cuda::ElementwiseMulActivationCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
REGISTER_LITE_KERNEL(
fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
paddle::lite::kernels::cuda::ElementwiseMulActivationComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
......
......@@ -74,40 +74,58 @@ class ElementwiseMulComputeNHWC
virtual ~ElementwiseMulComputeNHWC() = default;
};
class ElementwiseAddReluCompute
class ElementwiseAddActivationCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluCompute() = default;
virtual ~ElementwiseAddActivationCompute() = default;
};
class ElementwiseAddReluComputeNHWC
class ElementwiseAddActivationComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluComputeNHWC() = default;
virtual ~ElementwiseAddActivationComputeNHWC() = default;
};
class ElementwiseMulReluCompute
class ElementwiseSubActivationCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluCompute() = default;
virtual ~ElementwiseSubActivationCompute() = default;
};
class ElementwiseMulReluComputeNHWC
class ElementwiseSubActivationComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluComputeNHWC() = default;
virtual ~ElementwiseSubActivationComputeNHWC() = default;
};
class ElementwiseMulActivationCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulActivationCompute() = default;
};
class ElementwiseMulActivationComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulActivationComputeNHWC() = default;
};
} // namespace cuda
......
......@@ -44,8 +44,6 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.axis = opdesc.GetAttr<int>("axis");
param_.act_type = opdesc.GetAttr<std::string>("act_type");
// TODO(sangoly): support more activation types.
CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册