diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 1e6f34a62129e2ca0a717ceb489d98b56b78d47a..26e8d28fcd430e7642e07c4375f90c67c10cbaba 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -6,7 +6,7 @@ set(paddle_known_gpu_archs "30 35 50 52 60 61 70") set(paddle_known_gpu_archs7 "30 35 50 52") set(paddle_known_gpu_archs8 "30 35 50 52 60 61") set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") -set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") +set(paddle_known_gpu_archs10 "30 35 50 52 60 61 62 70 75") ###################################################################################### # A function for automatic detection of GPUs installed (if autodetection is enabled) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 3775d6cc2bdaa617f225b4cff9a03092bd9a19cc..a3a62c6b2ae2b939aa015a909b44fa492e6e5fb1 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -34,6 +34,14 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 ) + +if (${CUDA_VERSION} GREATER_EQUAL 10.0) + find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH) + set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY}) +else() + set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES}) +endif() + set(CUDNN_LIB_NAME "libcudnn.so") if(WIN32) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 36b533aa4f7815896fb48c33fefad892b8d0d29c..903c70fbbff285bc90697281f9703b544fd00186 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -146,8 +146,11 @@ set(GPU_COMMON_FLAGS -Wno-error=unused-local-typedefs -Wno-error=unused-function # Warnings in Numpy Header. -Wno-error=array-bounds # Warnings in Eigen::array + -gencode arch=compute_62,code=sm_62 ) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +if(NOT LITE_WITH_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +endif() endif(NOT WIN32) if (APPLE) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index f1b73f1d882afc3a2352514906e5958c70eee542..415eb451a986cd7e59829b9a8f2c744ecf464bd6 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -507,7 +507,7 @@ function(nv_test TARGET_NAME) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest -gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY}) +gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} ) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog) common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) diff --git a/lite/backends/cuda/blas.h b/lite/backends/cuda/blas.h index f73bb576b8dd5ecad178ba69a9208b2286c050ab..058b961f3678197a3a6719a3337e0decac78564f 100644 --- a/lite/backends/cuda/blas.h +++ b/lite/backends/cuda/blas.h @@ -30,10 +30,8 @@ namespace cuda { * Some basic methods. */ struct BlasBase { - /* BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); } ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); } - */ void SetStream(cudaStream_t stream) { CUBLAS_CHECK(cublasSetStream(handle_, stream)); diff --git a/lite/backends/cuda/math/scale.cu b/lite/backends/cuda/math/scale.cu index 9ab8f91779ebcc5259a99cf5a415ce13d4cfcebb..0e51fec0f232a6ceae3d4e5a36d9c3088ae29502 100644 --- a/lite/backends/cuda/math/scale.cu +++ b/lite/backends/cuda/math/scale.cu @@ -13,25 +13,59 @@ // limitations under the License. 
#include "iostream" +#include "lite/backends/cuda/cuda_utils.h" #include "lite/backends/cuda/math/scale.h" #include "lite/backends/cuda/math/utils.h" - namespace paddle { namespace lite { namespace cuda { namespace math { - +/* template -__global__ void scale_kernel(int num, const T* in, T* out, const float scale) { +__global__ void scale_kernel(int num, const T* in, T* out, const float scale, +const float bias) { int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < num) { #if __CUDA_ARCH__ >= 350 - out[tid] = __ldg(in + tid) * scale; + out[tid] = __ldg(in + tid) * scale + bias; #else out[tid] = in[tid] * scale; #endif } } +*/ +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void scale_kernel(int count, + const T* in_data, + T* out_data, + const T* scale_data, + const T* bias_data, + const int scale_dim, + const int inner_dim) { + CUDA_KERNEL_LOOP(tid, count) { + int scale_id = (tid / inner_dim) % scale_dim; + T scale = scale_data[scale_id]; + if (bias_data == nullptr) { + out_data[tid] = scale * in_data[tid]; + } else { + out_data[tid] = scale * in_data[tid] + bias_data[scale_id]; + } + } +} + +template +__global__ void scale_kernel( + int count, const T* in_data, T* out_data, const T scale, const T bias) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + // if (tid < count){ + // out_data[tid] = scale * in_data[tid] + bias; + //} + CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; } +} __global__ void fp32_scale_nhwc4_kernel(int num, const float4* in, @@ -114,21 +148,25 @@ void fp32_scale_nhwc(int num, } template -void scale(int num, const T* in, T* out, float scale, cudaStream_t stream) { +void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) { int thread = 256; int block = (num + thread - 1) / thread; - scale_kernel<<>>(num, in, out, scale); + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); } template -void scale(int num, const T* in, T* out, float scale) { +void scale(int num, const T* in, T* out, T scale, T bias) { int thread = 256; int block = (num + thread - 1) / thread; - scale_kernel<<>>(num, in, out, scale); + scale_kernel<<>>(num, in, out, scale, bias); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) std::cout << cudaGetErrorString(error); } -template void scale(int num, const float*, float*, float, cudaStream_t); -template void scale(int num, const float*, float*, float); +template void scale(int num, const float*, float*, float, cudaStream_t, float); +template void scale(int num, const float*, float*, float, float); } // namespace math } // namespace cuda diff --git a/lite/backends/cuda/math/scale.h b/lite/backends/cuda/math/scale.h index 83af600ba8c68a236fdb2a5c9f8521199b46f633..52ed1d38ae79ce11cac50a9abef0f57e6de1352c 100644 --- a/lite/backends/cuda/math/scale.h +++ b/lite/backends/cuda/math/scale.h @@ -32,10 +32,11 @@ void fp32_scale_nhwc(int num, cudaStream_t stream); template -void scale(int num, const T* in, T* out, float scale, cudaStream_t stream); +void scale( + int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0); template -void scale(int num, const T* in, T* out, float scale); +void scale(int num, const T* in, T* out, T scale, T bias = 0); } // namespace math } // namespace cuda diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc index 
61d3d317e7b7bbbfc4064cfbe0f2503f8fbe7a31..4a90e530a46c4d42d2ba032da1828973dfc1bcef 100644 --- a/lite/core/mir/node.cc +++ b/lite/core/mir/node.cc @@ -54,11 +54,6 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, valid_kernels_ = op_->CreateKernels(valid_places); } -std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) { - os << "Statement " << other.op_type() << " " << other.place().DebugString(); - return os; -} - mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) { auto &x = AsArg(); x.name = name; diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h index 9c7d441ca3811d39b8ba9f5b49746c9a31c1d449..60fa1fb1ebe49e1be38a7d84cb82545389ea4aac 100644 --- a/lite/core/mir/node.h +++ b/lite/core/mir/node.h @@ -74,7 +74,11 @@ class Node { KernelBase& picked_kernel(); - friend std::ostream& operator<<(std::ostream& os, const Stmt& other); + friend std::ostream& operator<<(std::ostream& os, const Stmt& other) { + os << "Statement " << other.op_type() << " " + << other.place().DebugString(); + return os; + } // Description. std::string desc; diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 348a55db117245582a8f13c5abf9161a8c880940..d855ee8e36b8babc40e4820ccd2b19d0b1008d34 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels") add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps}) add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose) add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps}) @@ -16,15 +17,23 @@ add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute. 
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps}) add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose) add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale) +add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale) +add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda) nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda) nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda) #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda) +nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda) +nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda ) +nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda ) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) diff --git a/lite/kernels/cuda/conv_compute.cc b/lite/kernels/cuda/conv_compute.cc index cee7d9593c576365d62f271d42c0bc4ad356650f..eea81602ddf94158250aecf01fe5e95193bf58c1 100644 --- a/lite/kernels/cuda/conv_compute.cc +++ b/lite/kernels/cuda/conv_compute.cc @@ -111,6 +111,28 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kNCHW))}) .Finalize(); +REGISTER_LITE_KERNEL(depthwise_conv2d, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ConvCompute, + def) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("Bias", + {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + REGISTER_LITE_KERNEL( conv2d, kCUDA, diff --git a/lite/kernels/cuda/dropout_compute.cc b/lite/kernels/cuda/dropout_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e3a3a62432f3bc5f2e62112b2b220abc17ee2bd --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include <string> +#include "lite/backends/cuda/math/scale.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void DropoutCompute::Run() { + auto& param = Param<operators::DropoutParam>(); + const float* x_data = param.x->data<float>(); + float* out_data = param.output->mutable_data<float>(TARGET(kCUDA)); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + float scale = 1.0f; + if (param.dropout_implementation == "downgrade_in_infer") { + scale = 1.0f - prob_data; + } + lite::cuda::math::scale<float>(num, x_data, out_data, scale, 0); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(dropout, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::DropoutCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/dropout_compute.h b/lite/kernels/cuda/dropout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..aec0d1966bc8b368484b5c810f133a8e9a6fb410 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <memory> +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class DropoutCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { + public: + void Run() override; + + virtual ~DropoutCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/dropout_compute_test.cc b/lite/kernels/cuda/dropout_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e6ed54330c0a109091934ebe48ed341afcae96f9 --- /dev/null +++ b/lite/kernels/cuda/dropout_compute_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/dropout_compute.h" +#include <gtest/gtest.h> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template <typename dtype> +void dropout_compute_ref(const operators::DropoutParam& param) { + const float* x_data = param.x->data<float>(); + float* output_data = param.output->mutable_data<float>(); + int num = param.x->dims().production(); + const float prob_data = param.dropout_prob; + if (param.dropout_implementation.compare( + std::string({"downgrade_in_infer"})) == 0) { + float scale = 1.0 - prob_data; + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i] * scale; + } + } else { + for (int i = 0; i < num; i++) { + output_data[i] = x_data[i]; + } + } +} + +TEST(dropout_cuda, normal) { + DropoutCompute dropout_kernel; + std::unique_ptr<KernelContext> ctx(new KernelContext); + auto& context = ctx->As<CUDAContext>(); + + operators::DropoutParam param; + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor x_ref; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4}) { + for (auto c : {1, 3, 4, 256}) { + for (auto h : {1, 3, 4, 6}) { + for (auto w : {1, 3, 4, 6}) { + for (auto prob : {0.2f, 0.8f}) + for (auto impl : {std::string({"downgrade_in_infer"})}) { + x.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + x_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + output.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + output_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + + auto* x_cpu_data = x_cpu.mutable_data<float>(); + auto* x_ref_data = x_ref.mutable_data<float>(); + auto* output_data = output.mutable_data<float>(TARGET(kCUDA)); + auto* output_cpu_data = output_cpu.mutable_data<float>(); + auto* output_ref_data = output_ref.mutable_data<float>(); + + for (int i = 0; i < x.dims().production(); i++) { + x_cpu_data[i] = i; + x_ref_data[i] = i; + } + + x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, + x_cpu.dims()); + + param.x = &x; + param.output = &output; + param.dropout_prob = prob; + param.dropout_implementation = impl; + dropout_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + dropout_kernel.SetContext(std::move(ctx)); + dropout_kernel.Launch(); + + CopySync<TARGET(kCUDA)>(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + + param.x = &x_ref; + param.output = &output_ref; + dropout_compute_ref<float>(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_add_compute.cu b/lite/kernels/cuda/elementwise_add_compute.cu index 390dacc7bcc96b65f9de5c4bb4c2aaea090e977e..4bacf532a2b67168679449200b1af721b7a282c8 100644 --- a/lite/kernels/cuda/elementwise_add_compute.cu +++ b/lite/kernels/cuda/elementwise_add_compute.cu @@ -29,12 +29,7 @@ void ElementwiseAddCompute::Run() { const lite::Tensor* y = param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); - - const int n = x->dims()[0]; - const int c = x->dims()[1]; - const int h = x->dims()[2]; - const int w = x->dims()[3]; + CHECK(x->dims().production() == y->dims().production()); auto* x_data = x->data<float>(); auto* y_data = y->data<float>(); @@ -57,12 +52,7 @@ void ElementwiseAddComputeNHWC::Run() { const lite::Tensor* y =
param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); - - const int n = x->dims()[0]; - const int c = x->dims()[1]; - const int h = x->dims()[2]; - const int w = x->dims()[3]; + CHECK(x->dims().production() == y->dims().production()); auto* x_data = x->data<float>(); auto* y_data = y->data<float>(); @@ -85,7 +75,7 @@ void ElementwiseAddComputeInt8::Run() { const lite::Tensor* y = param.Y; lite::Tensor* out = param.Out; - CHECK(x->dims() == y->dims()); + CHECK(x->dims().production() == y->dims().production()); const int c = x->dims()[3]; diff --git a/lite/kernels/cuda/mul_compute.h b/lite/kernels/cuda/mul_compute.h index 4a542104d6743b52758cbecfb11c025628e46333..c2fc4364ef77742858b143734d2ecf4d13e201e9 100644 --- a/lite/kernels/cuda/mul_compute.h +++ b/lite/kernels/cuda/mul_compute.h @@ -33,19 +33,36 @@ void mul_compute(const lite::cuda::Blas<float>& blas, int y_h, int y_w, T* out) { + float alpha = 1.0; + float beta = 0.0; + /* blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, x_h, y_w, x_w, - nullptr, + &alpha, x, x_w, y, y_w, - nullptr, + &beta, out, x_h); + */ + // cuBLAS expects column-major data; a row-major matrix is its transpose in + // column-major terms, so compute out^T = y^T * x^T by swapping the operands. + blas.sgemm(CUBLAS_OP_N, + CUBLAS_OP_N, + y_w, + x_h, + y_h, + &alpha, + y, + y_w, + x, + x_w, + &beta, + out, + y_w); } class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { @@ -56,23 +73,29 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { CHECK(ctx_) << "running context should be set first"; auto& context = this->ctx_->template As<CUDAContext>(); CHECK(context.cublas_fp32()) << "blas should init first"; - /* auto& blas = *context.cublas_fp32(); - CHECK(param.x->target() == TARGET(kCUDA)); - auto* x = param.x->data<float>(); - int x_h = param.x->dims()[0]; - int x_w = param.x->dims()[1]; - auto* y = param.y->data<float>(); - int y_h = param.y->dims()[0]; - int y_w = param.y->dims()[1]; - */ + auto& param = this->Param<operators::MulParam>(); + const auto* x_data = param.x->data<float>(); + const auto* y_data = param.y->data<float>(); + auto* out_data = param.output->mutable_data<float>(TARGET(kCUDA)); - const auto& param = Param<operators::MulParam>(); - param.output->mutable_data<float>(TARGET(kCUDA)); - LOG(INFO) << "mul output memory size " << param.output->data_size(); + int x_h = static_cast<int>( + param.x->dims().Slice(0, param.x_num_col_dims).production()); + int x_w = static_cast<int>( + param.x->dims() + .Slice(param.x_num_col_dims, param.x->dims().size()) + .production()); + int y_h = static_cast<int>( + param.y->dims().Slice(0, param.y_num_col_dims).production()); + int y_w = static_cast<int>( + param.y->dims() + .Slice(param.y_num_col_dims, param.y->dims().size()) + .production()); + CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h"; + LOG(INFO) << x_h << " " << x_w << " " << y_h << " " << y_w; - // mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out); + mul_compute<float>(blas, x_data, x_h, x_w, y_data, y_h, y_w, out_data); } virtual ~MulCompute() = default; diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1c1d63e7dcd46f84cd128fc5b855da2098e179d --- /dev/null +++ b/lite/kernels/cuda/mul_compute_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/mul_compute.h" +#include <gtest/gtest.h> +#include <memory> +#include <utility> + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(mul_compute, normal) { + MulCompute mul_kernel; + std::unique_ptr<KernelContext> ctx(new KernelContext); + auto& context = ctx->As<CUDAContext>(); + + Tensor x, y, out, x_cpu, y_cpu, out_cpu; + int x_h = 2, x_w_y_h = 3, y_w = 4; + out.Resize({x_h, y_w}); + x_cpu.Resize({x_h, x_w_y_h}); + y_cpu.Resize({x_w_y_h, y_w}); + out_cpu.Resize({x_h, y_w}); + + auto* out_data = out.mutable_data<float>(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data<float>(); + float* y_cpu_data = y_cpu.mutable_data<float>(); + float* out_cpu_data = out_cpu.mutable_data<float>(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i + 1.0; + } + for (int i = 0; i < y_cpu.numel(); i++) { + y_cpu_data[i] = i + 1.0; + } + + x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims()); + y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims()); + + operators::MulParam param; + param.x = &x; + param.y = &y; + param.output = &out; + mul_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + mul_kernel.SetContext(std::move(ctx)); + mul_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync<TARGET(kCUDA)>( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + for (int i = 0; i < out_cpu.numel(); i++) { + LOG(INFO) << out_cpu_data[i]; + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/pool_compute.cu b/lite/kernels/cuda/pool_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..a2483a2c759e8acc5f5944fd316c83bb49530d36 --- /dev/null +++ b/lite/kernels/cuda/pool_compute.cu @@ -0,0 +1,375 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <cfloat> +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/pool_compute.h" +#include "lite/utils/macros.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +#define MAX_VAL(a, b) (((a) > (b)) ? (a) : (b)) +#define MIN_VAL(a, b) (((a) < (b)) ?
(a) : (b)) + +__global__ void max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + max_val = max_val == -FLT_MAX ? 0.f : max_val; + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void adaptive_max_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast<float>(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast<float>((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast<float>(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast<float>((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float max_val = -FLT_MAX; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + max_val = MAX_VAL(max_val, *(in_p + j)); + } + in_p += in_w; + } + output[nc_id * spatial_out + h_id * out_w + w_id] = max_val; + } +} + +__global__ void avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + bool exclusive, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int w_s = w_id * stride_w - pad_w; + const int iw_s = MAX_VAL(w_s, 0); + const int iw_e = MIN_VAL(w_s + win_w, in_w); + const int w_loop = iw_e - iw_s; + const int h_s = h_id * stride_h - pad_h; + const int ih_s = MAX_VAL(h_s, 0); + const int ih_e = MIN_VAL(h_s + win_h, in_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = exclusive ? h_loop * w_loop : win_w * win_h; + pool_size = pool_size == 0 ? 1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void adaptive_avg_pool_kernel(const float* input, + float* output, + const int spatial_in, + const int spatial_out, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int pad_h, + const int pad_w, + const int win_h, + const int win_w, + const int stride_h, + const int stride_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int nc_id = gid / spatial_out; + const int w_id = gid % spatial_out % out_w; + const int h_id = gid % spatial_out / out_w; + const int iw_s = floor(static_cast<float>(w_id * in_w) / out_w); + const int iw_e = ceil(static_cast<float>((w_id + 1) * in_w) / out_w); + const int w_loop = iw_e - iw_s; + const int ih_s = floor(static_cast<float>(h_id * in_h) / out_h); + const int ih_e = ceil(static_cast<float>((h_id + 1) * in_h) / out_h); + const int h_loop = ih_e - ih_s; + const float* in_p = input + nc_id * spatial_in + ih_s * in_w + iw_s; + float sum_val = 0.f; + for (int i = 0; i < h_loop; ++i) { + for (int j = 0; j < w_loop; ++j) { + sum_val += *(in_p + j); + } + in_p += in_w; + } + int pool_size = h_loop * w_loop; + pool_size = pool_size == 0 ? 1 : pool_size; + output[nc_id * spatial_out + h_id * out_w + w_id] = sum_val / pool_size; + } +} + +__global__ void global_max_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float max_val = -FLT_MAX; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + max_val = MAX_VAL(max_val, *(in_p + 0)); + max_val = MAX_VAL(max_val, *(in_p + 1)); + max_val = MAX_VAL(max_val, *(in_p + 2)); + max_val = MAX_VAL(max_val, *(in_p + 3)); + max_val = MAX_VAL(max_val, *(in_p + 4)); + max_val = MAX_VAL(max_val, *(in_p + 5)); + max_val = MAX_VAL(max_val, *(in_p + 6)); + max_val = MAX_VAL(max_val, *(in_p + 7)); + in_p += 8; + } + for (; i < spatial_in; i++) { + max_val = MAX_VAL(max_val, *in_p); + in_p++; + } + output[gid] = max_val; + } +} + +__global__ void global_avg_pool_kernel(const float* input, + float* output, + const int in_h, + const int in_w, + const int total_threads) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid < total_threads) { + const int spatial_in = in_h * in_w; + const float* in_p = input + gid * spatial_in; + int i = 0; + float sum_val = 0.f; + // unroll 8 + for (; i < spatial_in - 7; i += 8) { + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + sum_val += *in_p++; + } + for (; i < spatial_in; i++) { + sum_val += *in_p++; + } + output[gid] = sum_val / spatial_in; + } +} + +void PoolCompute::Run() { + auto& param = this->Param<param_t>(); + auto& ctx = this->ctx_->template As<CUDAContext>(); + auto stream = ctx.exec_stream(); + + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + auto x_dims = param.x->dims(); + auto out_dims = param.output->dims(); + const int in_h = x_dims[2]; + const int in_w = x_dims[3]; + const int out_h = out_dims[2]; + const int out_w = out_dims[3]; + const int spatial_in = in_h * in_w; + const int spatial_out = out_h * out_w; + const int win_h = param.ksize[0]; + const int win_w = param.ksize[1]; + const int stride_h = param.strides[0]; + const int stride_w = param.strides[1]; + const int pad_h = param.paddings[0]; + const int pad_w = param.paddings[1]; + const int total_threads = out_dims.production(); + const int threads = 512; + const int blocks = (total_threads + threads - 1) / threads; + auto input_data = param.x->data<float>(); + auto output_data = param.output->mutable_data<float>(TARGET(kCUDA)); + if (param.global_pooling) { + if (param.pooling_type == "max") { + global_max_pool_kernel<<<blocks, threads, 0, stream>>>( + input_data, output_data, in_h, in_w, total_threads); + } else { + global_avg_pool_kernel<<<blocks, threads, 0, stream>>>( + input_data, output_data, in_h, in_w, total_threads); + } + } else { + if (!adaptive) { + if (param.pooling_type == "max") { + max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } else { + avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + exclusive, + total_threads); + } + } else { + if (param.pooling_type == "max") { + adaptive_max_pool_kernel<<<blocks, threads, 0, stream>>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } else { + adaptive_avg_pool_kernel<<<blocks, threads, 0, stream>>>(input_data, + output_data, + spatial_in, + spatial_out, + in_h, + in_w, + out_h, + out_w, + pad_h, + pad_w, + win_h, + win_w, + stride_h, + stride_w, + total_threads); + } + } + } + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + pool2d, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::PoolCompute, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/pool_compute.h b/lite/kernels/cuda/pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..55b346bfaf4ac139c8d22bff2ac64f0e78bc6023 --- /dev/null +++ b/lite/kernels/cuda/pool_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class PoolCompute + : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { + public: + using param_t = operators::PoolParam; + + void Run() override; + virtual ~PoolCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/pool_compute_test.cc b/lite/kernels/cuda/pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fafd1ef0c8d449c84c417023fbb81e8d7c3bb43f --- /dev/null +++ b/lite/kernels/cuda/pool_compute_test.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/pool_compute.h" +#include <gtest/gtest.h> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; +using DDim = lite::DDim; + +static int PoolOutputSize( + int input_size, int filter_size, int padding, int stride, bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + return output_size; +} + +static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) { + const auto x_dims = param_->x->dims(); + std::vector<int>& ksize = param_->ksize; + if (param_->global_pooling) { + ksize.resize(static_cast<int>(x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) { + param_->paddings[i] = 0; + ksize[i] = static_cast<int>(x_dims[i + 2]); + } + } + + std::vector<int64_t> output_shape({x_dims[0], x_dims[1]}); + if (param_->adaptive) { + output_shape.insert( + output_shape.end(), param_->ksize.begin(), param_->ksize.end()); + } else { + for (size_t i = 0; i < param_->ksize.size(); ++i) { + output_shape.push_back(PoolOutputSize(x_dims[i + 2], + param_->ksize[i], + param_->paddings[i], + param_->strides[i], + param_->ceil_mode)); + } + } + return output_shape; +} + +static void pool_compute_ref(const operators::PoolParam& param) { + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* src_ptr = param.x->data<float>(); + float* dst_ptr = param.output->mutable_data<float>(); + + std::vector<int> ksize = param.ksize; + std::vector<int> strides = param.strides; + std::vector<int> paddings = param.paddings; + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + std::string data_format = param.data_format; + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h =
strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + const float* src = src_ptr + n * size_in_n + c * size_in_c; + float res = src[0]; + if (pooling_type == "max") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res = cur_val > res ? cur_val : res; + } + } else if (pooling_type == "avg") { + for (int i = 1; i < size_in_c; ++i) { + float cur_val = src[i]; + res += cur_val; + } + res /= size_in_c; + } + dst_ptr[n * size_out_n + c] = res; + } + } + } else { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < out_h; ++h) { + int sh = h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + for (int w = 0; w < out_w; ++w) { + int sw = w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + int pooling_size = (ew - sw) * (eh - sh); + if (pooling_size == 0) { + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = 0.f; + continue; + } + float res = 0.f; + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw; + if (kh == sh && kw == sw) { + res = src_ptr[src_idx]; + } else { + if (pooling_type == "max") { + res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx]; + } + if (pooling_type == "avg") { + res += src_ptr[src_idx]; + } + } + } + } + if (pooling_type == "avg") { + if (exclusive) { + res /= pooling_size; + } else { + res /= window_h * window_w; + } + } + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res; + } + } + } + } + } +} + +TEST(pool_cuda, compute) { + std::unique_ptr<KernelContext> ctx(new KernelContext); + auto& context = ctx->As<CUDAContext>(); + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + PoolCompute pool; + operators::PoolParam param; + pool.SetContext(std::move(ctx)); + + lite::Tensor x; + lite::Tensor x_cpu; + lite::Tensor output; + lite::Tensor output_cpu; + lite::Tensor output_ref; + for (auto pooling_type : {"max", "avg"}) { + for (auto ceil_mode : {true, false}) { + for (auto global_pooling : {true, false}) { + for (auto exclusive : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto pad : {0, 1}) { + for (auto n : {1, 2}) { + for (auto c : {1, 3, 256}) { + for (auto h : {2, 3, 4, 6, 13}) { + for (auto w : {2, 3, 4, 6, 13}) { + VLOG(3) << "n:" << n << " c:" << c << " h:" << h + << " w:" << w << " ksize:" << ksize + << " stride:" << stride << " pad:" << pad + << " exclusive:" << exclusive + << " global_pooling:" << global_pooling + << " ceil_mode: " << ceil_mode + << " pooling_type:" << pooling_type; + + // init x, output + x.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); + auto* x_cpu_data = x_cpu.mutable_data<float>(); + for (int i = 0; i < x_cpu.dims().production(); ++i) { + float sign = i % 3 == 0 ?
-0.03f : 0.05f; + x_cpu_data[i] = sign * (i % 128); + } + x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, + x_cpu.dims()); + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + if (global_pooling) { + param.ksize = {h, w}; + } else { + param.ksize = {ksize, ksize}; + } + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = exclusive; + param.ceil_mode = ceil_mode; + param.adaptive = false; + param.use_quantizer = false; + + const std::vector<int64_t>& output_shape = + compute_output_shape(&param); + if (output_shape[2] * output_shape[3] == 0) continue; + output.Resize(DDim(output_shape)); + output_ref.Resize(DDim(output_shape)); + output_cpu.Resize(DDim(output_shape)); + auto* output_data = + output.mutable_data<float>(TARGET(kCUDA)); + auto* output_ref_data = + output_ref.mutable_data<float>(); + auto* output_cpu_data = + output_cpu.mutable_data<float>(); + + // compute + pool.SetParam(param); + pool.Launch(); + + // compute ref + param.x = &x_cpu; + param.output = &output_ref; + pool_compute_ref(param); + + cudaDeviceSynchronize(); + CopySync<TARGET(kCUDA)>(output_cpu_data, + output_data, + sizeof(float) * output.numel(), + IoDirection::DtoH); + // compare + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR( + output_cpu_data[i], output_ref_data[i], 1e-4); + } + VLOG(3) << "compare pass"; + } + } + } + } + } + } + } + } + } + } + } +} +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/relu_compute.cu b/lite/kernels/cuda/relu_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..7c6623a4fe3bc68408a90c7ed2a2e9e35d7061fb --- /dev/null +++ b/lite/kernels/cuda/relu_compute.cu @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/relu_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template <typename T> +__global__ void ReluKernel(const int num, const T* input, T* output) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < num) { +#if __CUDA_ARCH__ >= 350 + output[index] = __ldg(input + index) >= 0 ? __ldg(input + index) : 0; +#else + output[index] = input[index] >= 0 ?
input[index] : 0; +#endif + } +} + +void ReluCompute::Run() { + auto& param = this->Param<param_t>(); + auto& ctx = this->ctx_->template As<CUDAContext>(); + auto stream = ctx.exec_stream(); + + int num = static_cast<int>(param.X->numel()); + auto input = param.X->data<float>(); + auto output = param.Out->mutable_data<float>(TARGET(kCUDA)); + + int threads = 1024; + int blocks = (num + threads - 1) / threads; + ReluKernel<<<blocks, threads, 0, stream>>>(num, input, output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + relu, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ReluCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/relu_compute.h b/lite/kernels/cuda/relu_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b0fd500ff4369fc4a4ca256153aa5f0d21cf1e8e --- /dev/null +++ b/lite/kernels/cuda/relu_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ReluCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { + public: + using param_t = operators::ActivationParam; + + void Run() override; + virtual ~ReluCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/relu_compute_test.cc b/lite/kernels/cuda/relu_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..39144bfda13a9eac4ac7ad65d3d426d528fc2beb --- /dev/null +++ b/lite/kernels/cuda/relu_compute_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/relu_compute.h" +#include <gtest/gtest.h> +#include <memory> +#include <utility> + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +TEST(relu, normal) { + ReluCompute relu_kernel; + std::unique_ptr<KernelContext> ctx(new KernelContext); + auto& context = ctx->As<CUDAContext>(); + + operators::ActivationParam param; + + Tensor x, y, x_cpu, y_cpu; + int h = 256, w = 256; + y.Resize({h, w}); + x_cpu.Resize({h, w}); + y_cpu.Resize({h, w}); + + auto* y_data = y.mutable_data<float>(TARGET(kCUDA)); + float* x_cpu_data = x_cpu.mutable_data<float>(); + float* y_cpu_data = y_cpu.mutable_data<float>(); + + for (int i = 0; i < x_cpu.numel(); i++) { + x_cpu_data[i] = i - 5.0; + } + + x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + relu_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + relu_kernel.SetContext(std::move(ctx)); + relu_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync<TARGET(kCUDA)>( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + // for (int i = 0; i < y.numel(); i++) { + // LOG(INFO) << y_cpu_data[i]; + // } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/scale_compute.cc b/lite/kernels/cuda/scale_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bf7414d8c85383a834159678cdd5a09e0b434d9 --- /dev/null +++ b/lite/kernels/cuda/scale_compute.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include "lite/kernels/cuda/scale_compute.h" +#include <vector> +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +void ScaleCompute::Run() { + auto& param = Param<operators::ScaleParam>(); + const float* x_data = param.x->data<float>(); + float* output_data = param.output->mutable_data<float>(TARGET(kCUDA)); + DDim x_dims = param.x->dims(); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + lite::cuda::math::scale( + x_dims.production(), x_data, output_data, scale, bias); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + scale, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::ScaleCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/scale_compute.h b/lite/kernels/cuda/scale_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd082122a7e16762c790c8f360e2e0d7939496c --- /dev/null +++ b/lite/kernels/cuda/scale_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/backends/cuda/math/scale.h" +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ScaleCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { + public: + void Run() override; + + virtual ~ScaleCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8d2987524cd2e8f9c38aba4da3ff61a80bf53ce --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.cu @@ -0,0 +1,246 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include <limits> +#include <vector> +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/softmax_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; + +extern __shared__ char tile[]; +template <typename dtype> +__global__ void sharemem_softmax_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + int inner_num, + int outer_num, + int axis_size) { + dtype* data = reinterpret_cast<dtype*>(tile) + threadIdx.x; + //! compute thread index and real data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int blocksize = blockDim.x; + int real_index = idx_outer * inner_num + idx_inner; + int loop_idx = real_index; +//! read all data to sharemem in softmax channel +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + data[i * blocksize] = in_data[loop_idx]; + loop_idx += inner_num; + } + //! get maximum value in softmax channel + dtype max_data = data[0]; +#pragma unroll + for (int i = 1; i < axis_size; ++i) { + dtype dt = data[i * blocksize]; + if (max_data < dt) { + max_data = dt; + } + } + //! subtract then summarize + dtype sum = 0; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + dtype* dt = data + i * blocksize; + *dt = expf(*dt - max_data); + sum += *dt; + } + //! write back result + loop_idx = real_index; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + out_data[loop_idx] = data[i * blocksize] / sum; + loop_idx += inner_num; + } + } +} + +//! general kernel for softmax +template <typename dtype> +__global__ void softmax_max_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + dtype min_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + //! get maximum data across softmax axis + dtype max_data = min_data; + for (int i = 0; i < axis_size; ++i) { + max_data = + in_data[real_index] > max_data ? in_data[real_index] : max_data; + real_index += inner_num; + } + out_data[idx] = max_data; + } +} + +template <typename dtype> +__global__ void softmax_sub_exp_sum_kernel(int total_size, + const dtype* in_data, + dtype* out_data, + const dtype* max_data, + dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + + dtype max_data_cur = max_data[idx]; + dtype sum_data_cur = 0; + int real_index = idx_outer * inner_num + idx_inner; + //! compute exp and summarize across the softmax axis + for (int i = 0; i < axis_size; ++i) { + dtype sub_data = in_data[real_index] - max_data_cur; + sub_data = expf(sub_data); + sum_data_cur += sub_data; + out_data[real_index] = sub_data; + real_index += inner_num; + } + sum_data[idx] = sum_data_cur; + } +} + +template <typename dtype> +__global__ void softmax_divid_output_kernel(int total_size, + dtype* io_data, + const dtype* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //!
compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + dtype sum_data_cur = 1.f / sum_data[idx]; + int real_index = idx_outer * inner_num + idx_inner; + //! compute final result + for (int i = 0; i < axis_size; ++i) { + io_data[real_index] = io_data[real_index] * sum_data_cur; + real_index += inner_num; + } + } +} + +void SoftmaxCompute::Run() { + auto& param = this->Param<param_t>(); + auto& ctx = this->ctx_->template As<CUDAContext>(); + auto stream = ctx.exec_stream(); + + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int total_threads = inner_num * outer_num; + int axis_size = x_dims[axis]; + + int device_id; + const int threads = 512; + const int blocks = (total_threads + threads - 1) / threads; + cudaGetDevice(&device_id); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, device_id); + size_t sharedmem_size = deviceProp.sharedMemPerBlock; + int max_dimsize = sharedmem_size / sizeof(float) / threads; + + auto input_data = param.x->data<float>(); + auto output_data = param.output->mutable_data<float>(TARGET(kCUDA)); + if (axis_size <= max_dimsize) { + int use_sharemem_size = axis_size * threads * sizeof(float); + sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>( + total_threads, + input_data, + output_data, + inner_num, + outer_num, + axis_size); + } else { + //! re_alloc device memory + Tensor tmax_data; + Tensor tsum_data; + tmax_data.Resize({1, 1, 1, outer_num * inner_num}); + tsum_data.Resize({1, 1, 1, outer_num * inner_num}); + auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA)); + auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA)); + //! firstly, get maximum data + float min_data = std::numeric_limits<float>::lowest(); + softmax_max_kernel<<<blocks, threads, 0, stream>>>(total_threads, + input_data, + max_data, + min_data, + inner_num, + outer_num, + axis_size); + //! then, compute exp and sum data + softmax_sub_exp_sum_kernel<<<blocks, threads, 0, stream>>>( + total_threads, + input_data, + output_data, + max_data, + sum_data, + inner_num, + outer_num, + axis_size); + //! last, compute divided output + softmax_divid_output_kernel<<<blocks, threads, 0, stream>>>( + total_threads, output_data, sum_data, inner_num, outer_num, axis_size); + } + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(softmax, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SoftmaxCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindInput("axis", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..4acde4ab072390dd139c3e4e715f9ad288dc4ef8 --- /dev/null +++ b/lite/kernels/cuda/softmax_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SoftmaxCompute + : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { + public: + using param_t = operators::SoftmaxParam; + + void Run() override; + virtual ~SoftmaxCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle
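Note on the CUDA_KERNEL_LOOP macro introduced in scale.cu: it is the standard grid-stride loop pattern. Each thread starts at its global index and strides by the total number of launched threads, so the kernel covers any element count correctly even when fewer threads than elements are launched. A minimal sketch of the pattern with the macro expanded (names are ours, not part of the patch):

// Grid-stride loop: thread t handles t, t + blockDim.x * gridDim.x, ...
// The kernel stays correct for any n and any launch configuration.
__global__ void scale_bias(int n, const float* in, float* out,
                           float scale, float bias) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = scale * in[i] + bias;
  }
}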
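Note on the mul_compute.h change: cuBLAS assumes column-major storage while lite tensors are row-major. A row-major matrix is exactly its transpose when read column-major, so the row-major product out = x * y can be obtained by asking cuBLAS for y^T * x^T = (x * y)^T, which is why the new sgemm call passes the operands in swapped order with swapped dimensions. A minimal standalone sketch of the same trick against the raw cuBLAS API (the helper name is ours, a sketch rather than part of the patch):

// C (m x n) = A (m x k) * B (k x n), all buffers row-major.
// cuBLAS sees each buffer as its transpose, so request B^T * A^T = (A * B)^T,
// which in memory is exactly the row-major C.
#include <cublas_v2.h>

void sgemm_rowmajor(cublasHandle_t handle,
                    const float* A, const float* B, float* C,
                    int m, int k, int n) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
              n, m, k,    // dimensions of the transposed product
              &alpha,
              B, n,       // B^T is n x k, leading dimension n
              A, k,       // A^T is k x m, leading dimension k
              &beta,
              C, n);      // C^T is n x m, leading dimension n
}

This mirrors the argument order (y_w, x_h, y_h, y, y_w, x, x_w, out, y_w) used in the patch.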