Commit f904623c authored by myq406450149

Add GPU kernels mul, pool, relu, scale, softmax, dropout, and bilinear_interp; they can run on the Jetson TX2 (sm_62).

Parent 4ac51a6b
......@@ -6,7 +6,7 @@ set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 62 70 75")
######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
......
......@@ -34,6 +34,14 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)
if (${CUDA_VERSION} GREATER_EQUAL 10.0)
find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH)
set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY})
else()
set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES})
endif()
set(CUDNN_LIB_NAME "libcudnn.so")
if(WIN32)
......
......@@ -146,8 +146,11 @@ set(GPU_COMMON_FLAGS
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
-gencode arch=compute_62,code=sm_62
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
if(NOT LITE_WITH_CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
endif()
endif(NOT WIN32)
if (APPLE)
......
......@@ -507,7 +507,7 @@ function(nv_test TARGET_NAME)
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY})
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} )
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
......
......@@ -30,10 +30,8 @@ namespace cuda {
* Some basic methods.
*/
struct BlasBase {
/*
BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); }
~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); }
*/
void SetStream(cudaStream_t stream) {
CUBLAS_CHECK(cublasSetStream(handle_, stream));
......
......@@ -13,25 +13,59 @@
// limitations under the License.
#include "iostream"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
/*
template <typename T>
__global__ void scale_kernel(int num, const T* in, T* out, const float scale) {
__global__ void scale_kernel(int num, const T* in, T* out, const float scale,
const float bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
#if __CUDA_ARCH__ >= 350
out[tid] = __ldg(in + tid) * scale;
out[tid] = __ldg(in + tid) * scale + bias;
#else
out[tid] = in[tid] * scale;
#endif
}
}
*/
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void scale_kernel(int count,
const T* in_data,
T* out_data,
const T* scale_data,
const T* bias_data,
const int scale_dim,
const int inner_dim) {
CUDA_KERNEL_LOOP(tid, count) {
int scale_id = (tid / inner_dim) % scale_dim;
T scale = scale_data[scale_id];
if (bias_data == nullptr) {
out_data[tid] = scale * in_data[tid];
} else {
out_data[tid] = scale * in_data[tid] + bias_data[scale_id];
}
}
}
template <typename T>
__global__ void scale_kernel(
int count, const T* in_data, T* out_data, const T scale, const T bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// if (tid < count){
// out_data[tid] = scale * in_data[tid] + bias;
//}
CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; }
}
__global__ void fp32_scale_nhwc4_kernel(int num,
const float4* in,
......@@ -114,21 +148,25 @@ void fp32_scale_nhwc(int num,
}
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) {
void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale);
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template <typename T>
void scale(int num, const T* in, T* out, float scale) {
void scale(int num, const T* in, T* out, T scale, T bias) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread>>>(num, in, out, scale);
scale_kernel<<<block, thread>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
}
template void scale(int num, const float*, float*, float, cudaStream_t);
template void scale(int num, const float*, float*, float);
template void scale(int num, const float*, float*, float, cudaStream_t, float);
template void scale(int num, const float*, float*, float, float);
} // namespace math
} // namespace cuda
......
......@@ -32,10 +32,11 @@ void fp32_scale_nhwc(int num,
cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
void scale(
int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0);
template <typename T>
void scale(int num, const T* in, T* out, float scale);
void scale(int num, const T* in, T* out, T scale, T bias = 0);
} // namespace math
} // namespace cuda
......
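The reworked scale kernels above use a grid-stride loop (the CUDA_KERNEL_LOOP macro), so one launch covers any element count even when the grid is smaller than the data, and the same kernel serves both the scalar and the per-channel scale/bias overloads. A minimal standalone sketch of the same scale-with-bias pattern (hypothetical program, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Same element-wise form as lite::cuda::math::scale: y[i] = scale * x[i] + bias.
__global__ void scale_bias_kernel(int n, const float* x, float* y,
                                  float scale, float bias) {
  CUDA_KERNEL_LOOP(i, n) { y[i] = scale * x[i] + bias; }
}

int main() {
  const int n = 1 << 20;
  float *x = nullptr, *y = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;  // the loop also tolerates a smaller grid
  scale_bias_kernel<<<blocks, threads>>>(n, x, y, 0.5f, 1.f);
  cudaDeviceSynchronize();
  printf("y[0] = %f\n", y[0]);  // 1.500000
  cudaFree(x);
  cudaFree(y);
  return 0;
}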
......@@ -54,11 +54,6 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
valid_kernels_ = op_->CreateKernels(valid_places);
}
std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
os << "Statement " << other.op_type() << " " << other.place().DebugString();
return os;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
auto &x = AsArg();
x.name = name;
......
......@@ -74,7 +74,11 @@ class Node {
KernelBase& picked_kernel();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Statement " << other.op_type() << " "
<< other.place().DebugString();
return os;
}
// Description.
std::string desc;
......
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
......@@ -16,15 +17,23 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
......@@ -111,6 +111,28 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ConvCompute,
def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
conv2d,
kCUDA,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/dropout_compute.h"
#include <string>
#include "lite/backends/cuda/math/scale.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void DropoutCompute::Run() {
auto& param = Param<operators::DropoutParam>();
const float* x_data = param.x->data<float>();
float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
int num = param.x->dims().production();
const float prob_data = param.dropout_prob;
float scale = 1.0f;
if (param.dropout_implementation == "downgrade_in_infer") {
scale = 1.0f - prob_data;
}
lite::cuda::math::scale(num, x_data, out_data, scale, 0);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(dropout,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::DropoutCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
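For inference-only dropout no elements are masked; the kernel just rescales the input, which is why it reuses the element-wise scale() above with bias 0. With dropout_implementation == "downgrade_in_infer" the output is out[i] = x[i] * (1 - dropout_prob), e.g. dropout_prob = 0.2 maps x[i] = 2.0 to out[i] = 1.6; for any other implementation string (in Paddle, "upscale_in_train") the scale stays 1.0 and the input is copied through. The reference implementation in the test below applies the same rule.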
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class DropoutCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~DropoutCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/dropout_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename dtype>
void dropout_compute_ref(const operators::DropoutParam& param) {
const float* x_data = param.x->data<float>();
float* output_data = param.output->mutable_data<float>();
int num = param.x->dims().production();
const float prob_data = param.dropout_prob;
if (param.dropout_implementation.compare(
std::string({"downgrade_in_infer"})) == 0) {
float scale = 1.0 - prob_data;
for (int i = 0; i < num; i++) {
output_data[i] = x_data[i] * scale;
}
} else {
for (int i = 0; i < num; i++) {
output_data[i] = x_data[i];
}
}
}
TEST(dropout_cuda, normal) {
DropoutCompute dropout_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::DropoutParam param;
lite::Tensor x;
lite::Tensor x_cpu;
lite::Tensor x_ref;
lite::Tensor output;
lite::Tensor output_cpu;
lite::Tensor output_ref;
for (auto n : {1, 3, 4}) {
for (auto c : {1, 3, 4, 256}) {
for (auto h : {1, 3, 4, 6}) {
for (auto w : {1, 3, 4, 6}) {
for (auto prob : {0.2f, 0.8f})
for (auto impl : {std::string({"downgrade_in_infer"})}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* output_data = output.mutable_data<float>(TARGET(kCUDA));
auto* output_cpu_data = output_cpu.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_cpu_data[i] = i;
x_ref_data[i] = i;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data,
x_cpu.dims());
param.x = &x;
param.output = &output;
param.dropout_prob = prob;
param.dropout_implementation = impl;
dropout_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
dropout_kernel.SetContext(std::move(ctx));
dropout_kernel.Launch();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
param.x = &x_ref;
param.output = &output_ref;
dropout_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -29,12 +29,7 @@ void ElementwiseAddCompute::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
const int n = x->dims()[0];
const int c = x->dims()[1];
const int h = x->dims()[2];
const int w = x->dims()[3];
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
......@@ -57,12 +52,7 @@ void ElementwiseAddComputeNHWC::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
const int n = x->dims()[0];
const int c = x->dims()[1];
const int h = x->dims()[2];
const int w = x->dims()[3];
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
......@@ -85,7 +75,7 @@ void ElementwiseAddComputeInt8::Run() {
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
CHECK(x->dims().production() == y->dims().production());
const int c = x->dims()[3];
......
......@@ -33,19 +33,36 @@ void mul_compute(const lite::cuda::Blas<float>& blas,
int y_h,
int y_w,
T* out) {
float alpha = 1.0;
float beta = 0.0;
/*
blas.sgemm(CUBLAS_OP_N,
CUBLAS_OP_N,
x_h,
y_w,
x_w,
nullptr,
&alpha,
x,
x_w,
y,
y_w,
nullptr,
&beta,
out,
x_h);
*/
blas.sgemm(CUBLAS_OP_N,
CUBLAS_OP_N,
y_w,
x_h,
y_h,
&alpha,
y,
y_w,
x,
x_w,
&beta,
out,
y_w);
}
class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
......@@ -56,23 +73,29 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
CHECK(ctx_) << "running context should be set first";
auto& context = this->ctx_->template As<CUDAContext>();
CHECK(context.cublas_fp32()) << "blas should init first";
/*
auto& blas = *context.cublas_fp32();
CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>();
int x_h = param.x->dims()[0];
int x_w = param.x->dims()[1];
auto* y = param.y->data<float>();
int y_h = param.y->dims()[0];
int y_w = param.y->dims()[1];
*/
auto& param = this->Param<param_t>();
const auto* x_data = param.x->data<float>();
const auto* y_data = param.y->data<float>();
auto* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
const auto& param = Param<operators::MulParam>();
param.output->mutable_data<float>(TARGET(kCUDA));
LOG(INFO) << "mul output memory size " << param.output->data_size();
int x_h = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w = static_cast<int>(
param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
int y_w = static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());
CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
LOG(INFO) << x_h << " " << x_w << " " << y_h << " " << y_w;
// mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
mul_compute<float>(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data);
}
virtual ~MulCompute() = default;
......
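The active sgemm call above computes the row-major product out = x * y with column-major cuBLAS by swapping the operands: in column-major terms it evaluates (y^T)(x^T) = (x*y)^T with m = y_w, n = x_h, k = y_h, which laid out in memory is exactly the row-major x_h x y_w result. A self-contained sketch of the same trick on the 2x3 * 3x4 shapes used by the unit test below (hypothetical standalone program, assuming cuBLAS is linked):

// Row-major C(MxN) = A(MxK) * B(KxN) through column-major cublasSgemm:
// pass B first and swap the dimensions, so cuBLAS computes
// C_col(NxM) = B_col(NxK) * A_col(KxM) = (A*B)^T, i.e. C in row-major order.
#include <cstdio>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
  const int M = 2, K = 3, N = 4;  // same shapes as mul_compute_test.cc
  float hA[M * K], hB[K * N], hC[M * N];
  for (int i = 0; i < M * K; ++i) hA[i] = i + 1.f;
  for (int i = 0; i < K * N; ++i) hB[i] = i + 1.f;

  float *dA, *dB, *dC;
  cudaMalloc(&dA, sizeof(hA));
  cudaMalloc(&dB, sizeof(hB));
  cudaMalloc(&dC, sizeof(hC));
  cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB, sizeof(hB), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.f, beta = 0.f;
  // m = N, n = M, k = K; B before A, leading dimensions N and K.
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
              &alpha, dB, N, dA, K, &beta, dC, N);
  cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost);
  for (int i = 0; i < M * N; ++i) printf("%.0f ", hC[i]);  // 38 44 50 56 83 98 113 128
  printf("\n");
  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}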
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/mul_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(mul_compute, normal) {
MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4;
out.Resize({x_h, y_w});
x_cpu.Resize({x_h, x_w_y_h});
y_cpu.Resize({x_w_y_h, y_w});
out_cpu.Resize({x_h, y_w});
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
float* out_cpu_data = out_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i + 1.0;
}
for (int i = 0; i < y_cpu.numel(); i++) {
y_cpu_data[i] = i + 1.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
operators::MulParam param;
param.x = &x;
param.y = &y;
param.output = &out;
mul_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
mul_kernel.SetContext(std::move(ctx));
mul_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
for (int i = 0; i < out_cpu.numel(); i++) {
LOG(INFO) << out_cpu_data[i];
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/pool_compute.h"
#include "lite/utils/macros.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
#define MAX_VAL(a, b) (((a) > (b)) ? (a) : (b))
#define MIN_VAL(a, b) (((a) < (b)) ? (a) : (b))
__global__ void max_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int w_s = w_id * stride_w - pad_w;
const int iw_s = MAX_VAL(w_s, 0);
const int iw_e = MIN_VAL(w_s + win_w, in_w);
const int w_loop = iw_e - iw_s;
const int h_s = h_id * stride_h - pad_h;
const int ih_s = MAX_VAL(h_s, 0);
const int ih_e = MIN_VAL(h_s + win_h, in_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float max_val = -FLT_MAX;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
max_val = MAX_VAL(max_val, *(in_p + j));
}
in_p += in_w;
}
max_val = max_val == -FLT_MAX ? 0.f : max_val;
output[nc_id * spatial_out + h_id * out_w + w_id] = max_val;
}
}
__global__ void adaptive_max_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int iw_s = floor(static_cast<double>(w_id * in_w) / out_w);
const int iw_e = ceil(static_cast<double>((w_id + 1) * in_w) / out_w);
const int w_loop = iw_e - iw_s;
const int ih_s = floor(static_cast<double>(h_id * in_h) / out_h);
const int ih_e = ceil(static_cast<double>((h_id + 1) * in_h) / out_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float max_val = -FLT_MAX;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
max_val = MAX_VAL(max_val, *(in_p + j));
}
in_p += in_w;
}
output[nc_id * spatial_out + h_id * out_w + w_id] = max_val;
}
}
__global__ void avg_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
bool exclusive,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int w_s = w_id * stride_w - pad_w;
const int iw_s = MAX_VAL(w_s, 0);
const int iw_e = MIN_VAL(w_s + win_w, in_w);
const int w_loop = iw_e - iw_s;
const int h_s = h_id * stride_h - pad_h;
const int ih_s = MAX_VAL(h_s, 0);
const int ih_e = MIN_VAL(h_s + win_h, in_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float sum_val = 0.f;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
sum_val += *(in_p + j);
}
in_p += in_w;
}
int pool_size = exclusive ? h_loop * w_loop : win_w * win_h;
pool_size = pool_size == 0 ? 1 : pool_size;
output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size;
}
}
__global__ void adaptive_avg_pool_kernel(const float* input,
float* output,
const int spatial_in,
const int spatial_out,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int pad_h,
const int pad_w,
const int win_h,
const int win_w,
const int stride_h,
const int stride_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int nc_id = gid / spatial_out;
const int w_id = gid % spatial_out % out_w;
const int h_id = gid % spatial_out / out_w;
const int iw_s = floor(static_cast<double>(w_id * in_w) / out_w);
const int iw_e = ceil(static_cast<double>((w_id + 1) * in_w) / out_w);
const int w_loop = iw_e - iw_s;
const int ih_s = floor(static_cast<double>(h_id * in_h) / out_h);
const int ih_e = ceil(static_cast<double>((h_id + 1) * in_h) / out_h);
const int h_loop = ih_e - ih_s;
const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s;
float sum_val = 0.f;
for (int i = 0; i < h_loop; ++i) {
for (int j = 0; j < w_loop; ++j) {
sum_val += *(in_p + j);
}
in_p += in_w;
}
int pool_size = h_loop * w_loop;
pool_size = pool_size == 0 ? 1 : pool_size;
output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size;
}
}
__global__ void global_max_pool_kernel(const float* input,
float* output,
const int in_h,
const int in_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int spatial_in = in_h * in_w;
const float* in_p = input + gid * spatial_in;
int i = 0;
float max_val = -FLT_MAX;
// unroll 8
for (; i < spatial_in - 7; i += 8) {
max_val = MAX_VAL(max_val, *(in_p + 0));
max_val = MAX_VAL(max_val, *(in_p + 1));
max_val = MAX_VAL(max_val, *(in_p + 2));
max_val = MAX_VAL(max_val, *(in_p + 3));
max_val = MAX_VAL(max_val, *(in_p + 4));
max_val = MAX_VAL(max_val, *(in_p + 5));
max_val = MAX_VAL(max_val, *(in_p + 6));
max_val = MAX_VAL(max_val, *(in_p + 7));
in_p += 8;
}
for (; i < spatial_in; i++) {
max_val = MAX_VAL(max_val, *in_p);
in_p++;
}
output[gid] = max_val;
}
}
__global__ void global_avg_pool_kernel(const float* input,
float* output,
const int in_h,
const int in_w,
const int total_threads) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid < total_threads) {
const int spatial_in = in_h * in_w;
const float* in_p = input + gid * spatial_in;
int i = 0;
float sum_val = 0.f;
// unroll 8
for (; i < spatial_in - 7; i += 8) {
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
sum_val += *in_p++;
}
for (; i < spatial_in; i++) {
sum_val += *in_p++;
}
output[gid] = sum_val / spatial_in;
}
}
void PoolCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
auto x_dims = param.x->dims();
auto out_dims = param.output->dims();
const int in_h = x_dims[2];
const int in_w = x_dims[3];
const int out_h = out_dims[2];
const int out_w = out_dims[3];
const int spatial_in = in_h * in_w;
const int spatial_out = out_h * out_w;
const int win_h = param.ksize[0];
const int win_w = param.ksize[1];
const int stride_h = param.strides[0];
const int stride_w = param.strides[1];
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
const int total_threads = out_dims.production();
const int threads = 512;
const int blocks = (total_threads + threads - 1) / threads;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (param.global_pooling) {
if (param.pooling_type == "max") {
global_max_pool_kernel<<<blocks, threads, 0, stream>>>(
input_data, output_data, in_h, in_w, total_threads);
} else {
global_avg_pool_kernel<<<blocks, threads, 0, stream>>>(
input_data, output_data, in_h, in_w, total_threads);
}
} else {
if (!adaptive) {
if (param.pooling_type == "max") {
max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
} else {
avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
exclusive,
total_threads);
}
} else {
if (param.pooling_type == "max") {
adaptive_max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
} else {
adaptive_avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data,
output_data,
spatial_in,
spatial_out,
in_h,
in_w,
out_h,
out_w,
pad_h,
pad_w,
win_h,
win_w,
stride_h,
stride_w,
total_threads);
}
}
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
pool2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::PoolCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
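The non-adaptive kernels above clamp each pooling window to [id*stride - pad, id*stride - pad + ksize) intersected with [0, in); the adaptive kernels instead split the input evenly, with output bin i covering [floor(i*in/out), ceil((i+1)*in/out)). A quick standalone check of that bin math (hypothetical example, in = 5 rows pooled to out = 3 bins):

#include <cmath>
#include <cstdio>

int main() {
  const int in = 5, out = 3;
  for (int i = 0; i < out; ++i) {
    // Same index math as adaptive_max_pool_kernel / adaptive_avg_pool_kernel.
    const int s = static_cast<int>(std::floor(static_cast<double>(i * in) / out));
    const int e = static_cast<int>(std::ceil(static_cast<double>((i + 1) * in) / out));
    printf("bin %d covers input rows [%d, %d)\n", i, s, e);  // [0,2) [1,4) [3,5)
  }
  return 0;
}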
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class PoolCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::PoolParam;
void Run() override;
virtual ~PoolCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
static int PoolOutputSize(
int input_size, int filter_size, int padding, int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
const auto x_dims = param_->x->dims();
std::vector<int>& ksize = param_->ksize;
if (param_->global_pooling) {
ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
param_->paddings[i] = 0;
ksize[i] = static_cast<int>(x_dims[i + 2]);
}
}
std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
if (param_->adaptive) {
output_shape.insert(
output_shape.end(), param_->ksize.begin(), param_->ksize.end());
} else {
for (size_t i = 0; i < param_->ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(x_dims[i + 2],
param_->ksize[i],
param_->paddings[i],
param_->strides[i],
param_->ceil_mode));
}
}
return output_shape;
}
static void pool_compute_ref(const operators::PoolParam& param) {
auto& in_dims = param.x->dims();
auto& out_dims = param.output->dims();
const float* src_ptr = param.x->data<const float>();
float* dst_ptr = param.output->mutable_data<float>();
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = param.paddings;
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
std::string data_format = param.data_format;
int in_n = in_dims[0];
int in_c = in_dims[1];
int in_h = in_dims[2];
int in_w = in_dims[3];
int size_in_n = in_c * in_h * in_w;
int size_in_c = in_h * in_w;
int out_h = out_dims[2];
int out_w = out_dims[3];
int size_out_n = in_c * out_h * out_w;
int size_out_c = out_h * out_w;
int window_h = ksize[0];
int window_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
if (global_pooling == true) {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * size_in_n + c * size_in_c;
float res = src[0];
if (pooling_type == "max") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res = cur_val > res ? cur_val : res;
}
} else if (pooling_type == "avg") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res += cur_val;
}
res /= size_in_c;
}
dst_ptr[n * size_out_n + c] = res;
}
}
} else {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
for (int h = 0; h < out_h; ++h) {
int sh = h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int w = 0; w < out_w; ++w) {
int sw = w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
int pooling_size = (ew - sw) * (eh - sh);
if (pooling_size == 0) {
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = 0.f;
continue;
}
float res = 0.f;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
res = src_ptr[src_idx];
} else {
if (pooling_type == "max") {
res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx];
}
if (pooling_type == "avg") {
res += src_ptr[src_idx];
}
}
}
}
if (pooling_type == "avg") {
if (exclusive) {
res /= pooling_size;
} else {
res /= window_h * window_w;
}
}
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res;
}
}
}
}
}
}
TEST(pool_cuda, compute) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
PoolCompute pool;
operators::PoolParam param;
pool.SetContext(std::move(ctx));
lite::Tensor x;
lite::Tensor x_cpu;
lite::Tensor output;
lite::Tensor output_cpu;
lite::Tensor output_ref;
for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {true, false}) {
for (auto global_pooling : {true, false}) {
for (auto exclusive : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) {
for (auto n : {1, 2}) {
for (auto c : {1, 3, 256}) {
for (auto h : {2, 3, 4, 6, 13}) {
for (auto w : {2, 3, 4, 6, 13}) {
VLOG(3) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.dims().production(); ++i) {
float sign = i % 3 == 0 ? -0.03 : 0.05f;
x_cpu_data[i] = sign * (i % 128);
}
x.Assign<float, DDim, TARGET(kCUDA)>(x_cpu_data,
x_cpu.dims());
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
if (global_pooling) {
param.ksize = {h, w};
} else {
param.ksize = {ksize, ksize};
}
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
if (output_shape[2] * output_shape[3] == 0) continue;
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
output_cpu.Resize(DDim(output_shape));
auto* output_data =
output.mutable_data<float>(TARGET(kCUDA));
auto* output_ref_data =
output_ref.mutable_data<float>();
auto* output_cpu_data =
output_cpu.mutable_data<float>();
// compute
pool.SetParam(param);
pool.Launch();
// compute ref
param.x = &x_cpu;
param.output = &output_ref;
pool_compute_ref(param);
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
// compare
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(
output_cpu_data[i], output_ref_data[i], 1e-4);
}
VLOG(3) << "compare pass";
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/relu_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void ReluKernel(const int num, const T* input, T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) : 0;
#else
output[index] = input[index] >= 0 ? input[index] : 0;
#endif
}
}
void ReluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
int threads = 1024;
int blocks = (num + threads - 1) / threads;
ReluKernel<<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
relu, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ReluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ReluCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~ReluCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/relu_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(relu, normal) {
ReluCompute relu_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 256, w = 256;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
relu_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
relu_kernel.SetContext(std::move(ctx));
relu_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
// for (int i = 0; i < y.numel(); i++) {
// LOG(INFO) << y_cpu_data[i];
// }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/kernels/cuda/scale_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ScaleCompute::Run() {
auto& param = Param<operators::ScaleParam>();
const float* x_data = param.x->data<float>();
float* output_data = param.output->mutable_data<float>();
DDim x_dims = param.x->dims();
bool bias_after_scale = param.bias_after_scale;
float scale = param.scale;
float bias = param.bias;
if (!bias_after_scale) {
bias *= scale;
}
lite::cuda::math::scale(
x_dims.production(), x_data, output_data, scale, bias);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
scale, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ScaleCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
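Folding the bias keeps a single fused kernel for both orderings: with bias_after_scale the op is out = scale * x + bias, and with the bias applied first it is out = scale * (x + bias) = scale * x + scale * bias, which is why the code multiplies bias by scale before calling lite::cuda::math::scale.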
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/cuda/math/scale.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ScaleCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ScaleCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/softmax_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
extern __shared__ char tile[];
template <typename dtype>
__global__ void sharemem_softmax_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
int inner_num,
int outer_num,
int axis_size) {
dtype* data = reinterpret_cast<dtype*>(tile) + threadIdx.x;
//! compute thread index and real data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
int blocksize = blockDim.x;
int real_index = idx_outer * inner_num + idx_inner;
int loop_idx = real_index;
//! read all data to sharemem in softmax channel
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
data[i * blocksize] = in_data[loop_idx];
loop_idx += inner_num;
}
//! get maximum value in softmax channel
dtype max_data = data[0];
#pragma unroll
for (int i = 1; i < axis_size; ++i) {
dtype dt = data[i * blocksize];
if (max_data < dt) {
max_data = dt;
}
}
//! subtract then summarize
dtype sum = 0;
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
dtype* dt = data + i * blocksize;
*dt = expf(*dt - max_data);
sum += *dt;
}
//! write back result
loop_idx = real_index;
#pragma unroll
for (int i = 0; i < axis_size; ++i) {
out_data[loop_idx] = data[i * blocksize] / sum;
loop_idx += inner_num;
}
}
}
//! general kernel for softmax
template <typename dtype>
__global__ void softmax_max_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
dtype min_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
//! get maximum data across softmax axis
dtype max_data = min_data;
for (int i = 0; i < axis_size; ++i) {
max_data =
in_data[real_index] > max_data ? in_data[real_index] : max_data;
real_index += inner_num;
}
out_data[idx] = max_data;
}
}
template <typename dtype>
__global__ void softmax_sub_exp_sum_kernel(int total_size,
const dtype* in_data,
dtype* out_data,
const dtype* max_data,
dtype* sum_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
dtype max_data_cur = max_data[idx];
dtype sum_data_cur = 0;
int real_index = idx_outer * inner_num + idx_inner;
//! compute exp and summarize across the softmax axis
for (int i = 0; i < axis_size; ++i) {
dtype sub_data = in_data[real_index] - max_data_cur;
sub_data = expf(sub_data);
sum_data_cur += sub_data;
out_data[real_index] = sub_data;
real_index += inner_num;
}
sum_data[idx] = sum_data_cur;
}
}
template <typename dtype>
__global__ void softmax_divid_output_kernel(int total_size,
dtype* io_data,
const dtype* sum_data,
int inner_num,
int outer_num,
int axis_size) {
//! compute data index
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_size) {
int idx_inner = idx % inner_num;
int idx_outer = (idx / inner_num) * axis_size;
dtype sum_data_cur = 1.f / sum_data[idx];
int real_index = idx_outer * inner_num + idx_inner;
//! compute final result
for (int i = 0; i < axis_size; ++i) {
io_data[real_index] = io_data[real_index] * sum_data_cur;
real_index += inner_num;
}
}
}
void SoftmaxCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto x_dims = param.x->dims();
auto x_rank = x_dims.size();
int axis = param.axis;
if (axis < 0) {
axis += x_rank;
}
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int total_threads = inner_num * outer_num;
int axis_size = x_dims[axis];
int device_id;
const int threads = 512;
const int blocks = (total_threads + threads - 1) / threads;
cudaGetDevice(&device_id);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device_id);
size_t sharedmem_size = deviceProp.sharedMemPerBlock;
int max_dimsize = sharedmem_size / sizeof(float) / threads;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
if (axis_size <= max_dimsize) {
int use_sharemem_size = axis_size * threads * sizeof(float);
sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
total_threads,
input_data,
output_data,
inner_num,
outer_num,
axis_size);
} else {
//! re_alloc device memory
Tensor tmax_data;
Tensor tsum_data;
tmax_data.Resize({1, 1, 1, outer_num * inner_num});
tsum_data.Resize({1, 1, 1, outer_num * inner_num});
auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA));
auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA));
//! firstly, get maximum data
float min_data = std::numeric_limits<float>::lowest();
softmax_max_kernel<float><<<blocks, threads, 0, stream>>>(total_threads,
input_data,
max_data,
min_data,
inner_num,
outer_num,
axis_size);
//! then, compute exp and sum data
softmax_sub_exp_sum_kernel<float><<<blocks, threads, 0, stream>>>(
total_threads,
input_data,
output_data,
max_data,
sum_data,
inner_num,
outer_num,
axis_size);
//! last, compute divided output
softmax_divid_output_kernel<float><<<blocks, threads, 0, stream>>>(
total_threads, output_data, sum_data, inner_num, outer_num, axis_size);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(softmax,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SoftmaxCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("axis",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
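SoftmaxCompute picks between two paths: when the softmax axis fits in shared memory per block it runs the single sharemem_softmax_kernel, otherwise it falls back to the three-pass max / subtract-exp-sum / divide kernels. The cutoff is max_dimsize = sharedMemPerBlock / sizeof(float) / threads. A small standalone probe of that cutoff (hypothetical helper, mirrors the computation in Run()):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  const int threads = 512;  // same block size as SoftmaxCompute::Run
  const int max_dimsize =
      static_cast<int>(prop.sharedMemPerBlock / sizeof(float)) / threads;
  printf("sharedMemPerBlock = %zu bytes -> shared-memory softmax for axis_size <= %d\n",
         prop.sharedMemPerBlock, max_dimsize);
  // e.g. 48 KB of shared memory and 512 threads gives axis_size <= 24;
  // longer softmax axes take the 3-kernel fallback path.
  return 0;
}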
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SoftmaxCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
using param_t = operators::SoftmaxParam;
void Run() override;
virtual ~SoftmaxCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle