Unverified · Commit 2a344823, authored by Wilber, committed by GitHub

add eltwise_activate fuse. test=develop (#3367)

* add eltwise_activate_fuse. test=develop
Parent 06f77998
@@ -37,7 +37,7 @@ USE_MIR_PASS(identity_dropout_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
 USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
-USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
+USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
......
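Note: the hunk above is a one-line rename, but the registered pass name is a string key, so the same rename has to land in the pass registration and in the optimizer's pass list (both appear later in this diff). A minimal sketch of why, using a simplified registry that is not the actual Paddle-Lite implementation:

```cpp
// Sketch only: a string-keyed pass registry with simplified names; the real
// REGISTER_MIR_PASS / USE_MIR_PASS macros differ in detail.
#include <functional>
#include <map>
#include <string>

struct PassRegistry {
  static std::map<std::string, std::function<void()>>& Map() {
    static std::map<std::string, std::function<void()>> m;  // lazy singleton
    return m;
  }
};

// Registration binds the pass under a string key. Every site that names the
// pass (USE_MIR_PASS, the optimizer's pass list) must use the same key, which
// is why renaming lite_elementwise_add_activation_fuse_pass to
// lite_elementwise_activation_fuse_pass touches several files at once.
#define REGISTER_PASS(key, fn)               \
  static bool key##_registered = [] {        \
    PassRegistry::Map()[#key] = (fn);        \
    return true;                             \
  }()
```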
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "lite/backends/cuda/math/elementwise.h"
+#include "lite/utils/cp_logging.h"
 namespace paddle {
 namespace lite {
@@ -62,6 +63,52 @@ __global__ void elementwise_relu_kernel(const size_t total,
   }
 }
+template <typename Dtype>
+__global__ void elementwise_abs_kernel(const size_t total,
+                                       const Dtype* x_data,
+                                       const Dtype* y_data,
+                                       Dtype* out_data,
+                                       int pre,
+                                       int n,
+                                       int post,
+                                       BinaryOperation type) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < total) {
+    int idx = tid / post % n;
+    Dtype temp;
+#if __CUDA_ARCH__ >= 350
+    temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
+#else
+    temp = binary_calc(x_data[tid], y_data[idx], type);
+#endif
+    out_data[tid] = temp > 0 ? temp : -temp;
+  }
+}
+template <typename Dtype>
+__global__ void elementwise_tanh_kernel(const size_t total,
+                                        const Dtype* x_data,
+                                        const Dtype* y_data,
+                                        Dtype* out_data,
+                                        int pre,
+                                        int n,
+                                        int post,
+                                        BinaryOperation type) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < total) {
+    int idx = tid / post % n;
+    Dtype temp;
+#if __CUDA_ARCH__ >= 350
+    temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
+#else
+    temp = binary_calc(x_data[tid], y_data[idx], type);
+#endif
+    out_data[tid] = tanh(temp);
+  }
+}
 template <typename Dtype>
 __global__ void elementwise_add_kernel(const size_t total,
                                        const Dtype* x_data,
@@ -135,19 +182,30 @@ void elementwise(const Dtype* x_data,
 }
 template <typename Dtype>
-void elementwise_relu(const Dtype* x_data,
-                      const Dtype* y_data,
-                      Dtype* out_data,
-                      int pre,
-                      int n,
-                      int post,
-                      BinaryOperation type,
-                      cudaStream_t stream) {
+void elementwise_act(const Dtype* x_data,
+                     const Dtype* y_data,
+                     Dtype* out_data,
+                     int pre,
+                     int n,
+                     int post,
+                     std::string act,
+                     BinaryOperation type,
+                     cudaStream_t stream) {
   int num = pre * n * post;
   int thread = 256;
   int block = (num + thread - 1) / thread;
+  if (act == "relu") {
     elementwise_relu_kernel<<<block, thread, 0, stream>>>(
         num, x_data, y_data, out_data, pre, n, post, type);
+  } else if (act == "tanh") {
+    elementwise_tanh_kernel<<<block, thread, 0, stream>>>(
+        num, x_data, y_data, out_data, pre, n, post, type);
+  } else if (act == "abs") {
+    elementwise_abs_kernel<<<block, thread, 0, stream>>>(
+        num, x_data, y_data, out_data, pre, n, post, type);
+  } else {
+    LOG(FATAL) << "not supported activate type: " << act;
+  }
 }
 template void elementwise(const float*,
@@ -159,14 +217,15 @@ template void elementwise(const float*,
                           BinaryOperation,
                           cudaStream_t);
-template void elementwise_relu(const float*,
-                               const float*,
-                               float*,
-                               int,
-                               int,
-                               int,
-                               BinaryOperation,
-                               cudaStream_t);
+template void elementwise_act(const float* x_data,
+                              const float* y_data,
+                              float* out_data,
+                              int pre,
+                              int n,
+                              int post,
+                              std::string act,
+                              BinaryOperation type,
+                              cudaStream_t stream);
 template <typename Dtype>
 void elementwise_add(int num,
......
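The new `elementwise_act` launcher picks the fused kernel from the activation string, and all three kernels share the `idx = tid / post % n` broadcast indexing. A hypothetical call, with made-up shapes that are not from this commit, shows how `pre`/`n`/`post` encode the broadcast:

```cpp
// Hypothetical usage sketch. Suppose X is NCHW (2, 3, 4, 4) and Y is (3,),
// broadcast along axis 1. Then pre = 2, n = 3, post = 4 * 4 = 16, and the
// launcher covers total = pre * n * post = 96 elements. Each thread owns one
// X element and recovers its Y element via idx = tid / post % n, so the 16
// threads covering one (H, W) plane reuse the same Y value.
lite::cuda::math::elementwise_act(x_data,
                                  y_data,
                                  out_data,
                                  /*pre=*/2,
                                  /*n=*/3,
                                  /*post=*/16,
                                  "relu",
                                  lite::cuda::math::BinaryOperation::kADD,
                                  stream);
```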
@@ -15,6 +15,7 @@
 #pragma once
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <string>
 #include "lite/backends/cuda/math/utils.h"
 namespace paddle {
@@ -33,12 +34,13 @@ void elementwise(const Dtype* x_data,
                  cudaStream_t stream);
 template <typename Dtype>
-void elementwise_relu(const Dtype* x_data,
-                      const Dtype* y_data,
-                      Dtype* out_data,
-                      int pre,
-                      int n,
-                      int post,
-                      BinaryOperation type,
-                      cudaStream_t stream);
+void elementwise_act(const Dtype* x_data,
+                     const Dtype* y_data,
+                     Dtype* out_data,
+                     int pre,
+                     int n,
+                     int post,
+                     std::string act,
+                     BinaryOperation type,
+                     cudaStream_t stream);
......
@@ -22,20 +22,31 @@ namespace paddle {
 namespace lite {
 namespace mir {
-void ElementwiseAddActivationFusePass::Apply(
-    const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ElementwiseAddActivationFuser fuser("relu");
-  fuser(graph.get());
+void ElementwiseActivationFusePass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  // initialize fuser params
+  std::vector<std::string> elt_types{
+      "elementwise_add", "elementwise_sub", "elementwise_mul"};
+  std::vector<std::string> act_types{"relu", "abs", "tanh"};
+  // start fuse using params
+  for (auto elt_type : elt_types) {
+    for (auto act_type : act_types) {
+      fusion::ElementwiseActivationFuser fuser(elt_type, act_type);
+      fuser(graph.get());
+    }
+  }
 }
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle
-REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
-                  paddle::lite::mir::ElementwiseAddActivationFusePass)
+REGISTER_MIR_PASS(lite_elementwise_activation_fuse_pass,
+                  paddle::lite::mir::ElementwiseActivationFusePass)
     .BindTargets({TARGET(kAny)})
    .ExcludeTargets({TARGET(kXPU)})
    .ExcludeTargets({TARGET(kBM)})
    .ExcludeTargets({TARGET(kX86)})
-    .BindKernel("fusion_elementwise_add_activation");
+    .BindKernel("fusion_elementwise_add_activation")
+    .BindKernel("fusion_elementwise_sub_activation");
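The nested loops make this one pass attempt every elementwise/activation pairing, i.e. nine independent pattern matches over the graph. Written out, the loop body is equivalent to this sketch:

```cpp
// Equivalent expansion of the two loops above (sketch): each fuser instance
// scans the graph once for its own (elementwise, activation) combination.
fusion::ElementwiseActivationFuser("elementwise_add", "relu")(graph.get());
fusion::ElementwiseActivationFuser("elementwise_add", "abs")(graph.get());
fusion::ElementwiseActivationFuser("elementwise_add", "tanh")(graph.get());
// ... and likewise for "elementwise_sub" and "elementwise_mul":
// nine fuse attempts in total.
```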
@@ -22,7 +22,7 @@ namespace paddle {
 namespace lite {
 namespace mir {
-class ElementwiseAddActivationFusePass : public ProgramPass {
+class ElementwiseActivationFusePass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
 };
......
@@ -21,21 +21,21 @@ namespace lite {
 namespace mir {
 namespace fusion {
-void ElementwiseAddActivationFuser::BuildPattern() {
+void ElementwiseActivationFuser::BuildPattern() {
   // create input nodes.
-  auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
-  auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+  auto* x = VarNode("x")->assert_is_op_input(eltwise_type_, "X")->AsInput();
+  auto* y = VarNode("y")->assert_is_op_input(eltwise_type_, "Y")->AsInput();
   // create op nodes
-  auto* add = OpNode("add", "elementwise_add")
-                  ->assert_is_op("elementwise_add")
-                  ->AsIntermediate();
+  auto* elt = OpNode("elt", eltwise_type_)
+                  ->assert_is_op(eltwise_type_)
+                  ->AsIntermediate();
   auto* act =
       OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
   // create intermediate nodes
-  auto* add_out = VarNode("add_out")
-                      ->assert_is_op_output("elementwise_add", "Out")
-                      ->assert_is_op_input(act_type_, "X")
-                      ->AsIntermediate();
+  auto* elt_out = VarNode("add_out")
+                      ->assert_is_op_output(eltwise_type_, "Out")
+                      ->assert_is_op_input(act_type_, "X")
+                      ->AsIntermediate();
@@ -44,21 +44,29 @@ void ElementwiseAddActivationFuser::BuildPattern() {
       VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
   // create topology.
-  std::vector<PMNode*> add_inputs{x, y};
-  add_inputs >> *add >> *add_out;
-  *add_out >> *act >> *out;
+  std::vector<PMNode*> elt_inputs{x, y};
+  elt_inputs >> *elt >> *elt_out;
+  *elt_out >> *act >> *out;
 }
-void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
-                                                  const key2nodes_t& matched) {
+void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph,
+                                               const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
-  auto op =
-      LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
-  auto old_op = matched.at("add")->stmt()->op();
+  std::shared_ptr<lite::OpLite> op;
+  if (eltwise_type_ == "elementwise_add") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
+  } else if (eltwise_type_ == "elementwise_sub") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_sub_activation");
+  } else if (eltwise_type_ == "elementwise_mul") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_mul_activation");
+  } else {
+    LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
+  }
+  auto old_op = matched.at("elt")->stmt()->op();
   auto* scope = old_op->scope();
   auto& valid_places = old_op->valid_places();
   op->Attach(op_desc, scope);
   auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
   IR_NODE_LINK_TO(matched.at("x"), new_op_node);
@@ -66,12 +74,20 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }
-cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
-    const key2nodes_t& matched) {
-  auto* desc = matched.at("add")->stmt()->op_info();
+cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
+  auto* desc = matched.at("elt")->stmt()->op_info();
   cpp::OpDesc op_desc;
-  op_desc.SetType("fusion_elementwise_add_activation");
+  if (eltwise_type_ == "elementwise_add") {
+    op_desc.SetType("fusion_elementwise_add_activation");
+  } else if (eltwise_type_ == "elementwise_sub") {
+    op_desc.SetType("fusion_elementwise_sub_activation");
+  } else if (eltwise_type_ == "elementwise_mul") {
+    op_desc.SetType("fusion_elementwise_mul_activation");
+  } else {
+    LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
+  }
   op_desc.SetInput("X", {matched.at("x")->arg()->name});
   op_desc.SetInput("Y", {matched.at("y")->arg()->name});
   op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
......
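Taken together, BuildPattern, InsertNewNode, and GenOpDesc rewrite a matched two-op subgraph into one fused op. Sketched for elementwise_sub followed by relu, and assuming the elided tail of GenOpDesc carries the activation over as the act_type attribute (the attribute the op's AttachImpl, at the end of this diff, reads back):

```cpp
// Before:  x, y --> elementwise_sub --> elt_out --> relu --> output
// After:   x, y --> fusion_elementwise_sub_activation(act_type="relu") --> output
//
// BuildPattern() marks the elementwise op, the activation op, and the
// intermediate tensor AsIntermediate(), so the matcher erases all three;
// InsertNewNode() then links x, y, and output to the single fused node.
```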
@@ -23,15 +23,23 @@ namespace lite {
 namespace mir {
 namespace fusion {
-class ElementwiseAddActivationFuser : public FuseBase {
+// Detect an elementwise op followed by an activation op, and merge them
+// into a fusion_elementwise_act op.
+// Example: elementwise_add + relu fuse:
+//   fusion::ElementwiseActivationFuser fuser("elementwise_add", "relu");
+//   fuser(graph.get());
+class ElementwiseActivationFuser : public FuseBase {
  public:
-  explicit ElementwiseAddActivationFuser(const std::string& act_type)
-      : act_type_(act_type) {}
+  explicit ElementwiseActivationFuser(const std::string& eltwise_type,
+                                      const std::string& act_type)
+      : eltwise_type_(eltwise_type), act_type_(act_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string eltwise_type_;
   std::string act_type_;
 };
......
@@ -74,7 +74,7 @@ class Optimizer {
          "lite_scale_activation_fuse_pass",  //
 #if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
     (defined LITE_WITH_ARM)
-         "lite_elementwise_add_activation_fuse_pass",  //
+         "lite_elementwise_activation_fuse_pass",  //
 #endif
          "__xpu__resnet_fuse_pass",
          "__xpu__multi_encoder_fuse_pass",
......
@@ -70,7 +70,7 @@ inline bool is_broadcast(const DDim& x_dims,
   return true;
 }
-#define ELEMENTWISE_COMPUTE(OP, WITH_RELU)                             \
+#define ELEMENTWISE_COMPUTE(OP)                                        \
   auto& param = this->Param<param_t>();                                \
   auto& ctx = this->ctx_->template As<CUDAContext>();                  \
   auto stream = ctx.exec_stream();                                     \
@@ -85,25 +85,66 @@ inline bool is_broadcast(const DDim& x_dims,
   int pre = 1;                                                         \
   int n = pixel_num;                                                   \
   int post = 1;                                                        \
-  if (WITH_RELU) {                                                     \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
-      lite::cuda::math::elementwise_relu(                              \
-          x_data, y_data, out_data, pre, n, post, OP, stream);         \
-    } else {                                                           \
-      lite::cuda::math::elementwise_relu(                              \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
-    }                                                                  \
-  } else {                                                             \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
-      lite::cuda::math::elementwise(                                   \
-          x_data, y_data, out_data, pre, n, post, OP, stream);         \
-    } else {                                                           \
-      lite::cuda::math::elementwise(                                   \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
-    }                                                                  \
-  }
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {     \
+    lite::cuda::math::elementwise(                                     \
+        x_data, y_data, out_data, pre, n, post, OP, stream);           \
+  } else {                                                             \
+    lite::cuda::math::elementwise(                                     \
+        x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);        \
+  }
+#define ELEMENTWISE_COMPUTE_ACT(OP)                                    \
+  auto& param = this->Param<param_t>();                                \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                  \
+  auto stream = ctx.exec_stream();                                     \
+  const lite::Tensor* x = param.X;                                     \
+  const lite::Tensor* y = param.Y;                                     \
+  lite::Tensor* out = param.Out;                                       \
+  int axis = param.axis;                                               \
+  auto* x_data = x->data<float>();                                     \
+  auto* y_data = y->data<float>();                                     \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));             \
+  int pixel_num = x->numel();                                          \
+  int pre = 1;                                                         \
+  int n = pixel_num;                                                   \
+  int post = 1;                                                        \
+  auto act = param.act_type;                                           \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {     \
+    lite::cuda::math::elementwise_act(                                 \
+        x_data, y_data, out_data, pre, n, post, act, OP, stream);      \
+  } else {                                                             \
+    lite::cuda::math::elementwise_act(                                 \
+        x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream);   \
+  }
+#define ELEMENTWISE_COMPUTE_NHWC(OP)                                   \
+  std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};       \
+  auto& param = this->Param<param_t>();                                \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                  \
+  auto stream = ctx.exec_stream();                                     \
+  const lite::Tensor* x = param.X;                                     \
+  const lite::Tensor* y = param.Y;                                     \
+  lite::Tensor* out = param.Out;                                       \
+  int axis = param.axis;                                               \
+  if (axis < 0) axis = x->dims().size() - y->dims().size();            \
+  CHECK(axis >= 0) << "invalid axis of elementwise op";                \
+  axis = pos_map[axis];                                                \
+  auto* x_data = x->data<float>();                                     \
+  auto* y_data = y->data<float>();                                     \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));             \
+  int pixel_num = x->numel();                                          \
+  int pre = 1;                                                         \
+  int n = pixel_num;                                                   \
+  int post = 1;                                                        \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {     \
+    lite::cuda::math::elementwise(                                     \
+        x_data, y_data, out_data, pre, n, post, OP, stream);           \
+  } else {                                                             \
+    lite::cuda::math::elementwise(                                     \
+        x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);        \
+  }
-#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU)                        \
+#define ELEMENTWISE_COMPUTE_ACT_NHWC(OP)                               \
   std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};       \
   auto& param = this->Param<param_t>();                                \
   auto& ctx = this->ctx_->template As<CUDAContext>();                  \
@@ -122,80 +163,83 @@ inline bool is_broadcast(const DDim& x_dims,
   int pre = 1;                                                         \
   int n = pixel_num;                                                   \
   int post = 1;                                                        \
-  if (WITH_RELU) {                                                     \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
-      lite::cuda::math::elementwise_relu(                              \
-          x_data, y_data, out_data, pre, n, post, OP, stream);         \
-    } else {                                                           \
-      lite::cuda::math::elementwise_relu(                              \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
-    }                                                                  \
-  } else {                                                             \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
-      lite::cuda::math::elementwise(                                   \
-          x_data, y_data, out_data, pre, n, post, OP, stream);         \
-    } else {                                                           \
-      lite::cuda::math::elementwise(                                   \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
-    }                                                                  \
-  }
+  auto act = param.act_type;                                           \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {     \
+    lite::cuda::math::elementwise_act(                                 \
+        x_data, y_data, out_data, pre, n, post, act, OP, stream);      \
+  } else {                                                             \
+    lite::cuda::math::elementwise_act(                                 \
+        x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream);   \
+  }
 void ElementwiseAddCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 void ElementwiseAddComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 void ElementwiseSubCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 void ElementwiseSubComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 void ElementwiseMulCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 void ElementwiseMulComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
-void ElementwiseAddReluCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
+void ElementwiseAddActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
-void ElementwiseAddReluComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
+void ElementwiseAddActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
-void ElementwiseMulReluCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
+void ElementwiseSubActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kSUB)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
-void ElementwiseMulReluComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
+void ElementwiseSubActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kSUB)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+void ElementwiseMulActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kMUL)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+void ElementwiseMulActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
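Because the compute bodies are macros, it may help to see one expanded by hand. This is approximately what ELEMENTWISE_COMPUTE_ACT produces inside ElementwiseAddActivationCompute::Run() (a sketch, not compiler output):

```cpp
// Hand-expanded sketch of ELEMENTWISE_COMPUTE_ACT(kADD) plus the trailing
// error check from Run(); formatting and comments added for readability.
void ElementwiseAddActivationCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();
  const lite::Tensor* x = param.X;
  const lite::Tensor* y = param.Y;
  lite::Tensor* out = param.Out;
  int axis = param.axis;
  auto* x_data = x->data<float>();
  auto* y_data = y->data<float>();
  auto out_data = out->mutable_data<float>(TARGET(kCUDA));
  int pixel_num = x->numel();
  int pre = 1;
  int n = pixel_num;
  int post = 1;
  auto act = param.act_type;  // "relu", "tanh", or "abs"
  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {
    lite::cuda::math::elementwise_act(
        x_data, y_data, out_data, pre, n, post, act,
        lite::cuda::math::BinaryOperation::kADD, stream);
  } else {
    lite::cuda::math::elementwise_act(
        x_data, y_data, out_data, 1, pixel_num, 1, act,
        lite::cuda::math::BinaryOperation::kADD, stream);
  }
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
```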
@@ -298,22 +342,57 @@ REGISTER_LITE_KERNEL(elementwise_mul,
                                        DATALAYOUT(kNHWC))})
     .Finalize();
-REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
-                     def)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_add_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseAddActivationCompute,
+    def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_add_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseAddActivationComputeNHWC,
+    nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_sub_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseSubActivationCompute,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
-REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
-                     kCUDA,
-                     kFloat,
-                     kNHWC,
-                     paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
-                     nhwc_format)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_sub_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseSubActivationComputeNHWC,
+    nhwc_format)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kCUDA),
@@ -329,22 +408,24 @@ REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
                                        DATALAYOUT(kNHWC))})
     .Finalize();
-REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
-                     def)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_mul_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseMulActivationCompute,
+    def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
-REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
-                     kCUDA,
-                     kFloat,
-                     kNHWC,
-                     paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
-                     nhwc_format)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_mul_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseMulActivationComputeNHWC,
+    nhwc_format)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kCUDA),
......
@@ -74,40 +74,58 @@ class ElementwiseMulComputeNHWC
   virtual ~ElementwiseMulComputeNHWC() = default;
 };
-class ElementwiseAddReluCompute
+class ElementwiseAddActivationCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
   void Run() override;
-  virtual ~ElementwiseAddReluCompute() = default;
+  virtual ~ElementwiseAddActivationCompute() = default;
 };
-class ElementwiseAddReluComputeNHWC
+class ElementwiseAddActivationComputeNHWC
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
   void Run() override;
-  virtual ~ElementwiseAddReluComputeNHWC() = default;
+  virtual ~ElementwiseAddActivationComputeNHWC() = default;
 };
-class ElementwiseMulReluCompute
+class ElementwiseSubActivationCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
   void Run() override;
-  virtual ~ElementwiseMulReluCompute() = default;
+  virtual ~ElementwiseSubActivationCompute() = default;
 };
-class ElementwiseMulReluComputeNHWC
+class ElementwiseSubActivationComputeNHWC
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
   void Run() override;
-  virtual ~ElementwiseMulReluComputeNHWC() = default;
+  virtual ~ElementwiseSubActivationComputeNHWC() = default;
 };
+class ElementwiseMulActivationCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+  void Run() override;
+  virtual ~ElementwiseMulActivationCompute() = default;
+};
+class ElementwiseMulActivationComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+  void Run() override;
+  virtual ~ElementwiseMulActivationComputeNHWC() = default;
+};
 }  // namespace cuda
......
@@ -44,8 +44,6 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
   param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
   param_.axis = opdesc.GetAttr<int>("axis");
   param_.act_type = opdesc.GetAttr<std::string>("act_type");
-  // TODO(sangoly): support more activation types.
-  CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
   return true;
 }
......
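With abs and tanh kernels in place, the relu-only CHECK is removed; an unsupported act_type now surfaces at kernel launch via the LOG(FATAL) in elementwise_act. A stricter, purely hypothetical alternative (not part of this commit) would keep a fail-fast check against the widened set:

```cpp
// Hypothetical variant, not in this commit: validate act_type at attach
// time rather than at kernel launch. Assumes <set> is included.
static const std::set<std::string> kSupportedActs{"relu", "abs", "tanh"};
CHECK(kSupportedActs.count(param_.act_type) > 0)
    << "unsupported act_type: " << param_.act_type;
```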