未验证 提交 91a58fba 编写于 作者: W Wilber 提交者: GitHub

add cuda kernels. test=develop (#3315)

add cuda kernel.

abs, tanh, elementwise_sub
上级 4b0d60e7
......@@ -99,7 +99,8 @@ enum class ActivationType : int {
kTanh = 6,
kSwish = 7,
kExp = 8,
NUM = 9,
kAbs = 9,
NUM = 10,
};
static size_t PrecisionTypeLength(PrecisionType type) {
......
......@@ -29,6 +29,7 @@ enum class BinaryOperation {
kADD = 0,
kMUL = 1,
kDIV = 2,
kSUB = 3,
};
template <typename T>
......@@ -41,6 +42,7 @@ __device__ __forceinline__ float binary_calc(float x,
if (type == BinaryOperation::kADD) return x + y;
if (type == BinaryOperation::kMUL) return x * y;
if (type == BinaryOperation::kDIV) return x / y;
if (type == BinaryOperation::kSUB) return x - y;
}
template <typename T>
......
......@@ -8,6 +8,8 @@ add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_de
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
......@@ -45,6 +47,8 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_
#nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
......@@ -61,7 +65,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda)
#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda)
#nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
if(LITE_BUILD_EXTRA)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/abs_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void AbsKernel(const int num, const T* input, T* output);
template <>
__global__ void AbsKernel<float>(const int num,
const float* input,
float* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = fabsf(input[index]);
}
}
template <>
__global__ void AbsKernel<double>(const int num,
const double* input,
double* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = fabs(input[index]);
}
}
void AbsCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
const int threads = 512;
const int blocks = (num + threads - 1) / threads;
AbsKernel<float><<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
abs, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::AbsCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class AbsCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~AbsCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/abs_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(abs, normal) {
AbsCompute abs_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 3, w = 3;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 1.5;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
abs_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
abs_kernel.SetContext(std::move(ctx));
abs_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], std::fabs(x_cpu_data[i]), 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -152,6 +152,18 @@ void ElementwiseAddComputeNHWC::Run() {
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
......@@ -204,6 +216,17 @@ REGISTER_LITE_KERNEL(elementwise_add,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseSubCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
......@@ -224,6 +247,26 @@ REGISTER_LITE_KERNEL(elementwise_add,
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseSubComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
......
......@@ -38,6 +38,24 @@ class ElementwiseAddComputeNHWC
virtual ~ElementwiseAddComputeNHWC() = default;
};
class ElementwiseSubCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseSubCompute() = default;
};
class ElementwiseSubComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseSubComputeNHWC() = default;
};
class ElementwiseMulCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/tanh_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void TanhKernel(const int num, const T* input, T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = tanh(input[index]);
}
}
void TanhCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
const int threads = 512;
const int blocks = (num + threads - 1) / threads;
TanhKernel<float><<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
tanh, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::TanhCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class TanhCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~TanhCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/tanh_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(tanh, fp32) {
TanhCompute tanh_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 3, w = 3;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 1.5;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
tanh_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
tanh_kernel.SetContext(std::move(ctx));
tanh_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], tanh(x_cpu_data[i]), 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -71,6 +71,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} else if (opdesc.Type() == "exp") {
// exp
param_.active_type = lite_api::ActivationType::kExp;
} else if (opdesc.Type() == "abs") {
// abs
param_.active_type = lite_api::ActivationType::kAbs;
}
VLOG(4) << "opdesc.Type():" << opdesc.Type();
......@@ -92,6 +95,7 @@ REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册