diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt
index 82acd2d0eab44cf6bad8e5b6a92803ae4afe60b3..fafd74ae7a43d1a769456edfe408c71593d21201 100644
--- a/lite/backends/cuda/math/CMakeLists.txt
+++ b/lite/backends/cuda/math/CMakeLists.txt
@@ -11,6 +11,7 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
 nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps})
 nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
+nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
 nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
@@ -22,6 +23,7 @@ set (
     cuda_type_trans
     cuda_transpose
     cuda_elementwise
+    cudnn_pool
     cuda_gemm
     cuda_batched_gemm
 )
diff --git a/lite/backends/cuda/math/cudnn_pool.cc b/lite/backends/cuda/math/cudnn_pool.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f970fc326b29c4c226e7dc9643e416a3cf24f0eb
--- /dev/null
+++ b/lite/backends/cuda/math/cudnn_pool.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/cuda/math/cudnn_pool.h"
+#include "lite/backends/cuda/math/activation.h"
+#include "lite/backends/cuda/math/scale.h"
+#include "lite/backends/cuda/math/type_trans.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+inline void UpdatePadding(std::vector<int>* paddings,
+                          const bool global_pooling,
+                          const bool adaptive,
+                          const std::vector<int>& data_dims,
+                          const std::vector<int>& strides,
+                          const std::vector<int>& ksize) {
+  // One pad value per spatial dim means symmetric padding: duplicate it.
+  if (paddings->size() == data_dims.size()) {
+    for (size_t i = 0; i < data_dims.size(); ++i) {
+      int copy_pad = *(paddings->begin() + 2 * i);
+      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
+    }
+  } else {
+    CHECK(data_dims.size() * 2 == paddings->size())
+        << "Paddings size should be the same or twice as the pooling size.";
+  }
+  if (global_pooling || adaptive) {
+    for (auto it = paddings->begin(); it != paddings->end(); it++) {
+      *it = 0;
+    }
+  }
+}
+
+inline void UpdateKsize(std::vector<int>* ksize,
+                        const std::vector<int>& data_dims) {
+  ksize->resize(static_cast<size_t>(data_dims.size()));
+  for (size_t i = 0; i < ksize->size(); ++i) {
+    *(ksize->begin() + i) = static_cast<int>(data_dims[i]);
+  }
+}
+
+template <>
+bool CudnnPool2DNHWC<PRECISION(kFloat)>::create(
+    const operators::PoolParam& param, Context<TARGET(kCUDA)>* ctx) {
+  return true;
+}
+
+template <>
+bool CudnnPool2DNHWC<PRECISION(kFloat)>::init(
+    const operators::PoolParam& param, Context<TARGET(kCUDA)>* ctx) {
+  this->stream_ = ctx->exec_stream();
+  CUDNN_CHECK(cudnnCreate(&this->handle_));
+  CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
+
+  cudnnCreateTensorDescriptor(&this->input_desc_);
+  cudnnCreateTensorDescriptor(&this->output_desc_);
+  cudnnCreatePoolingDescriptor(&this->pooling_desc_);
+
+  return create(param, ctx);
+}
+
+template <>
+bool CudnnPool2DNHWC<PRECISION(kFloat)>::run(
+    const operators::PoolParam& param) {
+  auto x_dims = param.x->dims();
+  auto o_dims = param.output->dims();
+  int batch = x_dims[0];
+  const float* in_data = param.x->data<float>();
+  float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
+
+  int ih = x_dims[1];
+  int iw = x_dims[2];  // nhwc layout
+  int ic = x_dims[3];
+
+  int oh = o_dims[1];
+  int ow = o_dims[2];
+  int oc = o_dims[3];
+
+  std::vector<int> ksize = param.ksize;
+  std::vector<int> strides = param.strides;
+  std::vector<int> paddings = *(param.paddings.get());
+
+  std::string pooling_type = param.pooling_type;
+  bool global_pooling = param.global_pooling;
+  bool exclusive = param.exclusive;
+  bool adaptive = param.adaptive;
+
+  std::vector<int> data_dims = {ih, iw};
+  UpdatePadding(&paddings, global_pooling, adaptive, data_dims, strides, ksize);
+
+  // cuDNN takes one pad value per spatial dim, so drop the duplicates again.
+  if (data_dims.size() * 2 == paddings.size()) {
+    for (size_t i = 0; i < data_dims.size(); ++i) {
+      paddings.erase(paddings.begin() + i + 1);
+    }
+  }
+
+  if (global_pooling) {
+    UpdateKsize(&ksize, data_dims);
+  }
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
+                                         CUDNN_TENSOR_NHWC,
+                                         CUDNN_DATA_FLOAT,
+                                         batch,
+                                         ic,
+                                         ih,
+                                         iw));
+
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
+                                         CUDNN_TENSOR_NHWC,
+                                         CUDNN_DATA_FLOAT,
+                                         batch,
+                                         oc,
+                                         oh,
+                                         ow));
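+  // Pick the cuDNN pooling mode: "max" maps to CUDNN_POOLING_MAX, and for
+  // average pooling the `exclusive` flag chooses whether padded elements
+  // are counted in the averaging divisor.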
+  cudnnPoolingMode_t mode;
+  if (pooling_type == "max") {
+    mode = CUDNN_POOLING_MAX;
+  } else {
+    mode = exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
+                     : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+  }
+  CUDNN_CHECK(cudnnSetPoolingNdDescriptor(this->pooling_desc_,
+                                          mode,
+                                          CUDNN_NOT_PROPAGATE_NAN,
+                                          ksize.size(),
+                                          ksize.data(),
+                                          paddings.data(),
+                                          strides.data()));
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  CUDNN_CHECK(cudnnPoolingForward(this->handle_,
+                                  this->pooling_desc_,
+                                  &alpha,
+                                  this->input_desc_,
+                                  in_data,
+                                  &beta,
+                                  this->output_desc_,
+                                  out_data));
+
+  return true;
+}
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/cudnn_pool.h b/lite/backends/cuda/math/cudnn_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..acdc695b500ab41d615cb98c9501efd729c2fe6a
--- /dev/null
+++ b/lite/backends/cuda/math/cudnn_pool.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cudnn.h>
+#include <vector>
+#include "lite/api/paddle_place.h"
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/core/context.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <PrecisionType Ptype>
+class CudnnPool2DBase {
+ public:
+  CudnnPool2DBase()
+      : handle_(NULL),
+        input_desc_(NULL),
+        output_desc_(NULL),
+        pooling_desc_(NULL) {}
+
+  ~CudnnPool2DBase() {
+    if (handle_ != NULL) {
+      CUDNN_CHECK(cudnnDestroy(handle_));
+    }
+    if (input_desc_) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
+    }
+    if (output_desc_) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
+    }
+    if (pooling_desc_) {
+      cudnnDestroyPoolingDescriptor(pooling_desc_);
+    }
+  }
+
+ protected:
+  cudaStream_t stream_;
+  cudnnHandle_t handle_;
+  cudnnTensorDescriptor_t input_desc_;
+  cudnnTensorDescriptor_t output_desc_;
+  cudnnPoolingDescriptor_t pooling_desc_;
+};
+
+template <PrecisionType Ptype>
+class CudnnPool2DNHWC : public CudnnPool2DBase<Ptype> {
+ public:
+  CudnnPool2DNHWC() : CudnnPool2DBase<Ptype>() {}
+  virtual ~CudnnPool2DNHWC() = default;
+  virtual bool init(const operators::PoolParam& param,
+                    Context<TARGET(kCUDA)>* ctx);
+
+  virtual bool create(const operators::PoolParam& param,
+                      Context<TARGET(kCUDA)>* ctx);
+
+  virtual bool run(const operators::PoolParam& param);
+};
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/elementwise.cu b/lite/backends/cuda/math/elementwise.cu
index 57c9ec022a6e49551fd2d56a9b2036de13bf5a2c..8f0ebd1f97a03f03b568de694b986e9540f07c55 100644
--- a/lite/backends/cuda/math/elementwise.cu
+++ b/lite/backends/cuda/math/elementwise.cu
@@ -13,13 +13,55 @@
 // limitations under the License.
 
 #include "lite/backends/cuda/math/elementwise.h"
-#include "lite/backends/cuda/math/utils.h"
 
 namespace paddle {
 namespace lite {
 namespace cuda {
 namespace math {
 
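+// Broadcast elementwise kernel: y is logically tiled across x, so output
+// element `tid` pairs with y element (tid / post) % n, where pre/n/post
+// count the elements before, inside, and after the broadcast window.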
#include "lite/backends/cuda/math/elementwise.h" -#include "lite/backends/cuda/math/utils.h" namespace paddle { namespace lite { namespace cuda { namespace math { +template +__global__ void elementwise_kernel(const size_t total, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + int idx = tid / post % n; +#if __CUDA_ARCH__ >= 350 + out_data[tid] = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type); +#else + out_data[tid] = binary_calc(x_data[tid], y_data[idx], type); +#endif + } +} + +template +__global__ void elementwise_relu_kernel(const size_t total, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + int idx = tid / post % n; + Dtype temp; +#if __CUDA_ARCH__ >= 350 + temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type); + +#else + temp = binary_calc(x_data[tid], y_data[idx], type); +#endif + out_data[tid] = temp > 0 ? temp : 0; + } +} + template __global__ void elementwise_add_kernel(const size_t total, const Dtype* x_data, @@ -76,6 +118,56 @@ __global__ void elementwise_add_nhwc4_int8_kernel(const size_t total, } } +template +void elementwise(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream) { + int num = pre * n * post; + int thread = 256; + int block = (num + thread - 1) / thread; + elementwise_kernel<<>>( + num, x_data, y_data, out_data, pre, n, post, type); +} + +template +void elementwise_relu(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream) { + int num = pre * n * post; + int thread = 256; + int block = (num + thread - 1) / thread; + elementwise_relu_kernel<<>>( + num, x_data, y_data, out_data, pre, n, post, type); +} + +template void elementwise(const float*, + const float*, + float*, + int, + int, + int, + BinaryOperation, + cudaStream_t); + +template void elementwise_relu(const float*, + const float*, + float*, + int, + int, + int, + BinaryOperation, + cudaStream_t); + template void elementwise_add(int num, const Dtype* x_data, diff --git a/lite/backends/cuda/math/elementwise.h b/lite/backends/cuda/math/elementwise.h index 7fcdf95021ff21379bf94298ed06328dd6d2db09..ce45d0544e5a55a9cdc34bdfacc2b48157f5a198 100644 --- a/lite/backends/cuda/math/elementwise.h +++ b/lite/backends/cuda/math/elementwise.h @@ -15,12 +15,33 @@ #pragma once #include #include +#include "lite/backends/cuda/math/utils.h" namespace paddle { namespace lite { namespace cuda { namespace math { +template +void elementwise(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream); + +template +void elementwise_relu(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream); + template void elementwise_add(int num, const Dtype* x_data, diff --git a/lite/backends/cuda/math/utils.h b/lite/backends/cuda/math/utils.h index b4cd82fd8df6df063d92df709311f3c90e7cf4b6..b6aa9c7d160ad6c8b60b132e4a2bbd7ae1e0b9ff 100644 --- a/lite/backends/cuda/math/utils.h +++ b/lite/backends/cuda/math/utils.h @@ -25,6 +25,24 @@ namespace lite { namespace cuda { namespace math { +enum 
+enum class BinaryOperation {
+  kADD = 0,
+  kMUL = 1,
+  kDIV = 2,
+};
+
+template <typename T>
+__device__ T binary_calc(T x, T y, BinaryOperation type);
+
+template <>
+__device__ __forceinline__ float binary_calc(float x,
+                                             float y,
+                                             BinaryOperation type) {
+  if (type == BinaryOperation::kADD) return x + y;
+  if (type == BinaryOperation::kMUL) return x * y;
+  if (type == BinaryOperation::kDIV) return x / y;
+  return 0.f;  // unreachable for the supported op tags
+}
+
 template <typename T>
 __device__ T from_float(float x);
 
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index a787e0065ce8d91162c673ece846252809cacfa0..38c9d0e29d5766dec21de76b740c1032ad44da7e 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -73,7 +73,7 @@ class Optimizer {
          "lite_transpose_softmax_transpose_fuse_pass",  //
          "lite_interpolate_fuse_pass",                  //
          "identity_scale_eliminate_pass",               //
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
          "lite_elementwise_add_activation_fuse_pass",   //
 #endif
          "static_kernel_pick_pass",  // pick original kernel from graph
diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index e34d06603abcbc0c9a7205e28427458118c9b386..4bf1cbf5210214befb3620f8b7d70923f41f98f2 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -15,14 +15,15 @@ add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
 add_kernel(concat_compute_cuda CUDA basic SRCS concat_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
+add_kernel(elementwise_compute_cuda CUDA basic SRCS elementwise_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
 add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
 add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
 add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
 add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
-add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps} cudnn_pool)
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
@@ -47,12 +48,13 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
 nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
 nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
 nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
-nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
+nv_test(elementwise_compute_cuda_test SRCS elementwise_compute_test.cc DEPS elementwise_compute_cuda)
 nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
+nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
 nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
 nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
 nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
diff --git a/lite/kernels/cuda/calib_compute_cuda_test.cc b/lite/kernels/cuda/calib_compute_cuda_test.cc
index 8703d8730a1880b5b93502e5095b1a17d03bee6c..fdb47f7dd3c2e6d8f82e0281b81b24ebe444909a 100644
--- a/lite/kernels/cuda/calib_compute_cuda_test.cc
+++ b/lite/kernels/cuda/calib_compute_cuda_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "lite/kernels/cuda/calib_compute.h"
 #include <gtest/gtest.h>
 #include <memory>
 #include <utility>
@@ -58,12 +59,7 @@ void calib_ref(const operators::CalibParam& param, bool to_float = true) {
 }
 
 TEST(calib_cuda, int8_to_fp32) {
-  LOG(INFO) << "to get kernel ...";
-  auto kernels = KernelRegistry::Global().Create(
-      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
-  ASSERT_FALSE(kernels.empty());
-  auto calib = std::move(*std::next(kernels.begin(), 1));
-  LOG(INFO) << "get kernel: " << calib->doc();
+  CalibComputeInt8ToFp32 calib;  // instantiate the kernel under test directly
   const int n = 64, c = 32, h = 18, w = 18;
   Tensor x;
   Tensor x_cpu;
@@ -87,14 +83,14 @@ TEST(calib_cuda, int8_to_fp32) {
   cudaStream_t stream;
   cudaStreamCreate(&stream);
   context.SetExecStream(stream);
-  calib->SetContext(std::move(ctx));
+  calib.SetContext(std::move(ctx));
 
   operators::CalibParam param;
   param.scale = 0.013f;
   param.input = &x;
   param.output = &output;
-  calib->SetParam(param);
-  calib->Launch();
+  calib.SetParam(param);
+  calib.Launch();
   cudaDeviceSynchronize();
   // invoking ref implementation and compare results
   param.input = &x_cpu;
@@ -113,12 +109,7 @@ TEST(calib_cuda, int8_to_fp32) {
 }
 
 TEST(calib_cuda, fp32_to_int8) {
-  LOG(INFO) << "to get kernel ...";
-  auto kernels = KernelRegistry::Global().Create(
-      "calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
-  ASSERT_FALSE(kernels.empty());
-  auto calib = std::move(kernels.front());
-  LOG(INFO) << "get kernel: " << calib->doc();
+  CalibComputeFp32ToInt8 calib;
   const int n = 64, c = 32, h = 18, w = 18;
   Tensor x;
   Tensor x_cpu;
@@ -142,14 +133,14 @@ TEST(calib_cuda, fp32_to_int8) {
   cudaStream_t stream;
   cudaStreamCreate(&stream);
   context.SetExecStream(stream);
-  calib->SetContext(std::move(ctx));
+  calib.SetContext(std::move(ctx));
 
   operators::CalibParam param;
   param.scale = 0.013f;
   param.input = &x;
   param.output = &output;
-  calib->SetParam(param);
-  calib->Launch();
+  calib.SetParam(param);
+  calib.Launch();
   cudaDeviceSynchronize();
   // invoking ref implementation and compare results
   param.input = &x_cpu;
diff --git a/lite/kernels/cuda/conv_compute_test.cc b/lite/kernels/cuda/conv_compute_test.cc
index 1216c99051250449a55e12e259bbb5932e1c771c..2ebd7e33baf8e12cfce24661f186382152b6bb89 100644
--- a/lite/kernels/cuda/conv_compute_test.cc
+++ b/lite/kernels/cuda/conv_compute_test.cc
@@ -42,7 +42,9 @@ TEST(conv_compute, fp32) {
   operators::ConvParam param;
   param.activation_param = act_param;
   std::vector<int> pads = {1, 1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1, 1};
   param.paddings = std::make_shared<std::vector<int>>(pads);
+  param.dilations = std::make_shared<std::vector<int>>(dilations);
   param.groups = 1;
 
   Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
@@ -149,6 +151,10 @@ TEST(conv_compute, int8) {
   bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
                                                 filter_cpu.dims());
 
+  std::vector<int> pads = {0, 0, 0, 0};
+  std::vector<int> dilations = {1, 1, 1, 1};
+  param.paddings = std::make_shared<std::vector<int>>(pads);
+  param.dilations = std::make_shared<std::vector<int>>(dilations);
   param.x = &x;
   param.filter = &filter;
   param.output = &y;
@@ -203,12 +209,10 @@ TEST(conv_compute, int8_int8_out) {
   std::cout << "input" << std::endl;
   for (int i = 0; i < x_cpu.numel(); i++) {
     x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
-    std::cout << float(x_cpu_data[i]) << std::endl;
   }
   std::cout << "filter" << std::endl;
   for (int i = 0; i < filter_cpu.numel(); i++) {
     filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
-    std::cout << float(filter_cpu_data[i]) << std::endl;
   }
   for (int i = 0; i < bias_cpu.numel(); i++) {
     bias_cpu_data[i] = i + 1.0;
@@ -221,6 +225,10 @@ TEST(conv_compute, int8_int8_out) {
   bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
                                                 filter_cpu.dims());
 
+  std::vector<int> pads = {0, 0, 0, 0};
+  std::vector<int> dilations = {1, 1, 1, 1};
+  param.paddings = std::make_shared<std::vector<int>>(pads);
+  param.dilations = std::make_shared<std::vector<int>>(dilations);
   param.x = &x;
   param.filter = &filter;
   param.output = &y;
diff --git a/lite/kernels/cuda/elementwise_add_compute.cu b/lite/kernels/cuda/elementwise_add_compute.cu
deleted file mode 100644
index 4bacf532a2b67168679449200b1af721b7a282c8..0000000000000000000000000000000000000000
--- a/lite/kernels/cuda/elementwise_add_compute.cu
+++ /dev/null
@@ -1,139 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <vector>
-#include "lite/backends/cuda/math/elementwise.h"
-#include "lite/core/op_registry.h"
-#include "lite/kernels/cuda/elementwise_add_compute.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace cuda {
-
-void ElementwiseAddCompute::Run() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->template As<CUDAContext>();
-  auto stream = ctx.exec_stream();
-
-  const lite::Tensor* x = param.X;
-  const lite::Tensor* y = param.Y;
-  lite::Tensor* out = param.Out;
-
-  CHECK(x->dims().production() == y->dims().production());
-
-  auto* x_data = x->data<float>();
-  auto* y_data = y->data<float>();
-  auto out_data = out->mutable_data<float>(TARGET(kCUDA));
-
-  int pixel_num = x->numel();
-  lite::cuda::math::elementwise_add(
-      pixel_num, x_data, y_data, out_data, stream);
-
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
-}
-
-void ElementwiseAddComputeNHWC::Run() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->template As<CUDAContext>();
-  auto stream = ctx.exec_stream();
-
-  const lite::Tensor* x = param.X;
-  const lite::Tensor* y = param.Y;
-  lite::Tensor* out = param.Out;
-
-  CHECK(x->dims().production() == y->dims().production());
-
-  auto* x_data = x->data<float>();
-  auto* y_data = y->data<float>();
-  auto out_data = out->mutable_data<float>(TARGET(kCUDA));
-
-  int pixel_num = x->numel();
-  lite::cuda::math::elementwise_add(
-      pixel_num, x_data, y_data, out_data, stream);
-
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
-}
-
-void ElementwiseAddComputeInt8::Run() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->template As<CUDAContext>();
-  auto stream = ctx.exec_stream();
-
-  const lite::Tensor* x = param.X;
-  const lite::Tensor* y = param.Y;
-  lite::Tensor* out = param.Out;
-
-  CHECK(x->dims().production() == y->dims().production());
-
-  const int c = x->dims()[3];
-
-  auto* x_data = x->data<float>();
-  auto* y_data = y->data<float>();
-  auto out_data = out->mutable_data<int8_t>(TARGET(kCUDA));
-
-  int pixel_num = x->numel();
-  float output_scale = param.output_scale;
-  if (c % 4 == 0) {
-    lite::cuda::math::elementwise_add_nhwc4_int8(
-        pixel_num / 4,
-        static_cast<const float4*>(x_data),
-        static_cast<const float4*>(y_data),
-        1. / output_scale,
-        static_cast<char4*>(out_data),
-        stream);
-  } else {
-    lite::cuda::math::elementwise_add_int8(
-        pixel_num, x_data, y_data, 1. / output_scale, out_data, stream);
-  }
-
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
-}
-
-}  // namespace cuda
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
-
-REGISTER_LITE_KERNEL(elementwise_add,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::ElementwiseAddCompute,
-                     def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .Finalize();
-
-REGISTER_LITE_KERNEL(elementwise_add,
-                     kCUDA,
-                     kFloat,
-                     kNHWC,
-                     paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
-                     nhwc_format)
-    .BindInput("X",
-               {LiteType::GetTensorTy(TARGET(kCUDA),
-                                      PRECISION(kFloat),
-                                      DATALAYOUT(kNHWC))})
-    .BindInput("Y",
-               {LiteType::GetTensorTy(TARGET(kCUDA),
-                                      PRECISION(kFloat),
-                                      DATALAYOUT(kNHWC))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kCUDA),
-                                       PRECISION(kFloat),
-                                       DATALAYOUT(kNHWC))})
-    .Finalize();
diff --git a/lite/kernels/cuda/elementwise_compute.cu b/lite/kernels/cuda/elementwise_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..64759f86f5df85f9855b9c1f186bbc9c039a044c
--- /dev/null
+++ b/lite/kernels/cuda/elementwise_compute.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <map>
+#include <vector>
+#include "lite/backends/cuda/math/elementwise.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/elementwise_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+inline DDim trim_trailing_singular_dims(const DDim& dims) {
+  // Remove trailing dimensions of size 1 for y
+  auto actual_dims_size = dims.size();
+  for (; actual_dims_size != 0; --actual_dims_size) {
+    if (dims[actual_dims_size - 1] != 1) break;
+  }
+
+  std::vector<int64_t> trim_dims;
+  trim_dims.resize(actual_dims_size);
+  for (int i = 0; i < actual_dims_size; ++i) {
+    trim_dims[i] = dims[i];
+  }
+  if (trim_dims.size() == 0) {
+    return DDim();
+  }
+  return DDim(trim_dims);
+}
+
+inline bool is_broadcast(const DDim& x_dims,
+                         const DDim& y_dims,
+                         int axis,
+                         int* pre,
+                         int* n,
+                         int* post) {
+  if (axis < 0) {
+    axis = x_dims.size() - y_dims.size();
+  }
+  DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
+  axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
+  if (x_dims.size() == y_dim_trim.size()) {
+    return false;
+  }
+  *pre = 1;
+  *n = 1;
+  *post = 1;
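+  // pre/n/post count the elements before, inside, and after the broadcast
+  // window, so output element i pairs with y element (i / post) % n.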
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= x_dims[i];
+  }
+  for (int i = 0; i < y_dim_trim.size(); ++i) {
+    CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
+        << "Broadcast dimension mismatch.";
+    (*n) *= y_dim_trim[i];
+  }
+  for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
+    (*post) *= x_dims[i];
+  }
+  return true;
+}
+
+#define ELEMENTWISE_COMPUTE(OP, WITH_RELU)                           \
+  auto& param = this->Param<param_t>();                              \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                \
+  auto stream = ctx.exec_stream();                                   \
+  const lite::Tensor* x = param.X;                                   \
+  const lite::Tensor* y = param.Y;                                   \
+  lite::Tensor* out = param.Out;                                     \
+  int axis = param.axis;                                             \
+  auto* x_data = x->data<float>();                                   \
+  auto* y_data = y->data<float>();                                   \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));           \
+  int pixel_num = x->numel();                                        \
+  int pre = 1;                                                       \
+  int n = pixel_num;                                                 \
+  int post = 1;                                                      \
+  if (WITH_RELU) {                                                   \
+    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
+      lite::cuda::math::elementwise_relu(                            \
+          x_data, y_data, out_data, pre, n, post, OP, stream);       \
+    } else {                                                         \
+      lite::cuda::math::elementwise_relu(                            \
+          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
+    }                                                                \
+  } else {                                                           \
+    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
+      lite::cuda::math::elementwise(                                 \
+          x_data, y_data, out_data, pre, n, post, OP, stream);       \
+    } else {                                                         \
+      lite::cuda::math::elementwise(                                 \
+          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
+    }                                                                \
+  }
+
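+// The NHWC variant receives `axis` in NCHW terms, so it is remapped through
+// pos_map before the broadcast decomposition is computed on the NHWC dims.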
+#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU)                      \
+  std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};     \
+  auto& param = this->Param<param_t>();                              \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                \
+  auto stream = ctx.exec_stream();                                   \
+  const lite::Tensor* x = param.X;                                   \
+  const lite::Tensor* y = param.Y;                                   \
+  lite::Tensor* out = param.Out;                                     \
+  int axis = param.axis;                                             \
+  if (axis < 0) axis = x->dims().size() - y->dims().size();          \
+  CHECK(axis >= 0) << "invalid axis of elementwise op";              \
+  axis = pos_map[axis];                                              \
+  auto* x_data = x->data<float>();                                   \
+  auto* y_data = y->data<float>();                                   \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));           \
+  int pixel_num = x->numel();                                        \
+  int pre = 1;                                                       \
+  int n = pixel_num;                                                 \
+  int post = 1;                                                      \
+  if (WITH_RELU) {                                                   \
+    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
+      lite::cuda::math::elementwise_relu(                            \
+          x_data, y_data, out_data, pre, n, post, OP, stream);       \
+    } else {                                                         \
+      lite::cuda::math::elementwise_relu(                            \
+          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
+    }                                                                \
+  } else {                                                           \
+    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
+      lite::cuda::math::elementwise(                                 \
+          x_data, y_data, out_data, pre, n, post, OP, stream);       \
+    } else {                                                         \
+      lite::cuda::math::elementwise(                                 \
+          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
+    }                                                                \
+  }
+
+void ElementwiseAddCompute::Run() {
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseAddComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseMulCompute::Run() {
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseMulComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseAddReluCompute::Run() {
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseAddReluComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseMulReluCompute::Run() {
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseMulReluComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(elementwise_add,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ElementwiseAddCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_add,
+                     kCUDA,
+                     kFloat,
+                     kNHWC,
+                     paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
+                     nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_mul,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ElementwiseMulCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_mul,
+                     kCUDA,
+                     kFloat,
+                     kNHWC,
+                     paddle::lite::kernels::cuda::ElementwiseMulComputeNHWC,
+                     nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
+                     kCUDA,
+                     kFloat,
+                     kNHWC,
+                     paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
+                     nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
+                     kCUDA,
+                     kFloat,
+                     kNHWC,
+                     paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
+                     nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
diff --git a/lite/kernels/cuda/elementwise_add_compute.h b/lite/kernels/cuda/elementwise_compute.h
similarity index 52%
rename from lite/kernels/cuda/elementwise_add_compute.h
rename to lite/kernels/cuda/elementwise_compute.h
index 5c3fecc5d894aeea2bc5260b1815bbfa718eb5c6..986a4db2272d9a6607090babd937747f861f49c7 100644
--- a/lite/kernels/cuda/elementwise_add_compute.h
+++ b/lite/kernels/cuda/elementwise_compute.h
@@ -38,13 +38,58 @@ class ElementwiseAddComputeNHWC
   virtual ~ElementwiseAddComputeNHWC() = default;
 };
 
-class ElementwiseAddComputeInt8
+class ElementwiseMulCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ElementwiseParam;
+
+  void Run() override;
+  virtual ~ElementwiseMulCompute() = default;
+};
+
+class ElementwiseMulComputeNHWC
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
  public:
   using param_t = operators::ElementwiseParam;
 
   void Run() override;
-  virtual ~ElementwiseAddComputeInt8() = default;
+  virtual ~ElementwiseMulComputeNHWC() = default;
+};
+
+class ElementwiseAddReluCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseAddReluCompute() = default;
+};
+
+class ElementwiseAddReluComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseAddReluComputeNHWC() = default;
+};
+
+class ElementwiseMulReluCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseMulReluCompute() = default;
+};
+
+class ElementwiseMulReluComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseMulReluComputeNHWC() = default;
 };
 
 }  // namespace cuda
diff --git a/lite/kernels/cuda/elementwise_add_compute_test.cc b/lite/kernels/cuda/elementwise_compute_test.cc
similarity index 55%
rename from lite/kernels/cuda/elementwise_add_compute_test.cc
rename to lite/kernels/cuda/elementwise_compute_test.cc
index cc63f1470b65de37eb73c71701a83146e12778ae..9fd0b7754f2d3209137b5f4862dfe1e90279f3be 100644
--- a/lite/kernels/cuda/elementwise_add_compute_test.cc
+++ b/lite/kernels/cuda/elementwise_compute_test.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/kernels/cuda/elementwise_add_compute.h" +#include "lite/kernels/cuda/elementwise_compute.h" #include #include #include @@ -31,6 +31,14 @@ static void ElementwiseAddRef(float* x, float* y, float* out, int num) { } } +static void ElementwiseBroadcastRef( + float* x, float* y, float* out, int pre, int n, int post) { + for (int i = 0; i < pre * n * post; ++i) { + int idx = (i / post) % n; + out[i] = x[i] + y[idx]; + } +} + TEST(elementwise_add, normal) { ElementwiseAddCompute elementwise_add_kernel; std::unique_ptr ctx(new KernelContext); @@ -99,38 +107,117 @@ TEST(elementwise_add, normal) { } } -TEST(elementwise_add, int8_out) { - ElementwiseAddComputeInt8 elementwise_add_kernel; +TEST(elementwise_add, bias) { + ElementwiseAddCompute elementwise_add_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ElementwiseParam param; + Tensor x, y, out; + Tensor x_cpu, y_cpu, out_cpu; + Tensor x_ref, y_ref, out_ref; + + const int n = 1; + const int c = 3; + const int h = 2000; + const int w = 2000; + + x.Resize({n, c, h, w}); + y.Resize({c, 1, 1}); + out.Resize({n, c, h, w}); + x_cpu.Resize({n, c, h, w}); + y_cpu.Resize({c, 1, 1}); + out_cpu.Resize({n, c, h, w}); + x_ref.Resize({n, c, h, w}); + y_ref.Resize({c, 1, 1}); + out_ref.Resize({n, c, h, w}); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* y_cpu_data = y_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + + auto* x_ref_data = x_ref.mutable_data(); + auto* y_ref_data = y_ref.mutable_data(); + auto* out_ref_data = out_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + x_ref_data[i] = i + 5.0; + } + for (int i = 0; i < y_cpu.numel(); ++i) { + y_cpu_data[i] = i - 5.0; + y_ref_data[i] = i - 5.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + y.Assign(y_cpu_data, y_cpu.dims()); + + param.X = &x; + param.Y = &y; + param.Out = &out; + param.axis = -1; + elementwise_add_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + elementwise_add_kernel.SetContext(std::move(ctx)); + elementwise_add_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + ElementwiseBroadcastRef(x_ref_data, y_ref_data, out_ref_data, n, c, h * w); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +TEST(elementwise_add_nhwc, bias) { + ElementwiseAddComputeNHWC elementwise_add_kernel; std::unique_ptr ctx(new KernelContext); auto& context = ctx->As(); operators::ElementwiseParam param; Tensor x, y, out; Tensor x_cpu, y_cpu, out_cpu; + Tensor x_ref, y_ref, out_ref; const int n = 1; - const int h = 36; - const int w = 36; - const int c = 125; + const int c = 3; + const int h = 2000; + const int w = 2000; x.Resize({n, h, w, c}); - y.Resize({n, h, w, c}); + y.Resize({c, 1, 1}); out.Resize({n, h, w, c}); x_cpu.Resize({n, h, w, c}); - y_cpu.Resize({n, h, w, c}); + y_cpu.Resize({c, 1, 1}); out_cpu.Resize({n, h, w, c}); + x_ref.Resize({n, h, w, c}); + y_ref.Resize({c, 1, 1}); + out_ref.Resize({n, h, w, c}); - auto* out_data = out.mutable_data(TARGET(kCUDA)); + auto* out_data = out.mutable_data(TARGET(kCUDA)); auto* x_cpu_data = x_cpu.mutable_data(); auto* y_cpu_data = y_cpu.mutable_data(); - auto* out_cpu_data = out_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + + auto* x_ref_data = x_ref.mutable_data(); 
+  auto* y_ref_data = y_ref.mutable_data<float>();
+  auto* out_ref_data = out_ref.mutable_data<float>();
 
   for (int i = 0; i < x_cpu.numel(); ++i) {
     x_cpu_data[i] = i + 5.0;
+    x_ref_data[i] = i + 5.0;
   }
   for (int i = 0; i < y_cpu.numel(); ++i) {
-    y_cpu_data[i] = i;
+    y_cpu_data[i] = i - 5.0;
+    y_ref_data[i] = i - 5.0;
   }
 
   x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
@@ -139,7 +226,7 @@ TEST(elementwise_add, int8_out) {
   param.X = &x;
   param.Y = &y;
   param.Out = &out;
-  param.output_scale = 50 / 127.;
+  param.axis = -1;
   elementwise_add_kernel.SetParam(param);
 
   cudaStream_t stream;
@@ -147,16 +234,15 @@ TEST(elementwise_add, int8_out) {
   context.SetExecStream(stream);
 
   elementwise_add_kernel.SetContext(std::move(ctx));
-  auto start = GetCurrentUS();
-  for (int i = 0; i < 1000000; i++) {
-    elementwise_add_kernel.Launch();
-  }
-  LOG(INFO) << "time: " << (GetCurrentUS() - start) / 1000000.;
+  elementwise_add_kernel.Launch();
+  cudaDeviceSynchronize();
 
   CopySync<TARGET(kCUDA)>(
-      out_cpu_data, out_data, sizeof(int8_t) * out.numel(), IoDirection::DtoH);
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+  ElementwiseBroadcastRef(
+      x_ref_data, y_ref_data, out_ref_data, n * h * w, c, 1);
   for (int i = 0; i < out.numel(); i++) {
-    // LOG(INFO) << float(out_cpu_data[i]);
+    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
   }
 }
 
diff --git a/lite/kernels/cuda/layout_compute.cc b/lite/kernels/cuda/layout_compute.cc
index e2d0ae4f2ef10b29247a2f823988e8098aa33795..6b56d9e1de28cbec57b4b45aff1d1b237b1784b9 100644
--- a/lite/kernels/cuda/layout_compute.cc
+++ b/lite/kernels/cuda/layout_compute.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "lite/kernels/cuda/layout_compute.h"
+#include <vector>
 #include "lite/backends/cuda/math/transpose.h"
 #include "lite/core/op_registry.h"
 
@@ -21,11 +22,32 @@ namespace lite {
 namespace kernels {
 namespace cuda {
 
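+// Collapse trailing singleton dims: a tensor that is effectively 1-D after
+// trimming needs no NCHW<->NHWC transform and can simply be copied through.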
+inline DDim trim_singular_dims(const DDim& dims) {
+  auto actual_dims_size = dims.size();
+  for (; actual_dims_size != 0; --actual_dims_size) {
+    if (dims[actual_dims_size - 1] != 1) break;
+  }
+  std::vector<int64_t> trim_dims;
+  trim_dims.resize(actual_dims_size);
+  for (int i = 0; i < actual_dims_size; ++i) {
+    trim_dims[i] = dims[i];
+  }
+  if (trim_dims.size() == 0) {
+    return DDim();
+  }
+  return DDim(trim_dims);
+}
+
 #define NCHWTONHWC(type)                                                   \
   auto& param = this->template Param<param_t>();                           \
   auto& ctx = this->ctx_->template As<CUDAContext>();                      \
   auto input = param.x->template data<type>();                             \
   auto input_dim = param.x->dims();                                        \
+  DDim input_trim_dim = trim_singular_dims(input_dim);                     \
+  if (input_trim_dim.size() == 1) {                                        \
+    param.y->CopyDataFrom(*param.x);                                       \
+    return;                                                                \
+  }                                                                        \
   CHECK(input_dim.size() == 4)                                             \
       << "NCHW to NHWC should guarantee that the input dims should be 4"; \
   int n = input_dim[0];                                                    \
@@ -41,6 +63,11 @@ namespace cuda {
   auto& ctx = this->ctx_->template As<CUDAContext>();                      \
   auto input = param.x->template data<type>();                             \
   auto input_dim = param.x->dims();                                        \
+  DDim input_trim_dim = trim_singular_dims(input_dim);                     \
+  if (input_trim_dim.size() == 1) {                                        \
+    param.y->CopyDataFrom(*param.x);                                       \
+    return;                                                                \
+  }                                                                        \
   CHECK(input_dim.size() == 4)                                             \
       << "NHWC to NCHW should guarantee that the input dims should be 4"; \
   int n = input_dim[0];                                                    \
diff --git a/lite/kernels/cuda/mul_compute_test.cc b/lite/kernels/cuda/mul_compute_test.cc
index d1c1d63e7dcd46f84cd128fc5b855da2098e179d..f521a12e2dddcf854b3982ae37f4da7631f6acf3 100644
--- a/lite/kernels/cuda/mul_compute_test.cc
+++ b/lite/kernels/cuda/mul_compute_test.cc
@@ -16,6 +16,7 @@
 #include <gtest/gtest.h>
 #include <memory>
 #include <utility>
+#include "lite/backends/cuda/blas.h"
 
 namespace paddle {
 namespace lite {
@@ -26,6 +27,7 @@ TEST(mul_compute, normal) {
   MulCompute mul_kernel;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   auto& context = ctx->As<CUDAContext>();
+  context.InitOnce();
 
   Tensor x, y, out, x_cpu, y_cpu, out_cpu;
 
   int x_h = 2, x_w_y_h = 3, y_w = 4;
diff --git a/lite/kernels/cuda/pool_compute.cu b/lite/kernels/cuda/pool_compute.cu
index 456a2ce91105c7bc7822d78486633c19b124ff24..d7e3739ddbb59a624e1911b8178e96053dacc0d1 100644
--- a/lite/kernels/cuda/pool_compute.cu
+++ b/lite/kernels/cuda/pool_compute.cu
@@ -358,6 +358,61 @@ void PoolCompute::Run() {
   if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
 }
 
+inline int PoolOutputSize(
+    int input_size, int filter_size, int padding, int stride, bool ceil_mode) {
+  int output_size;
+  if (!ceil_mode) {
+    output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  } else {
+    output_size =
+        (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
+  }
+  return output_size;
+}
+
+void PoolComputeNHWC::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  pool_impl_.reset(new lite::cuda::math::CudnnPool2DNHWC<PRECISION(kFloat)>);
+  pool_impl_->init(param, &ctx);
+}
+
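+// Recompute the NHWC output shape from the current input dims on every run
+// (shapes may change between batches), then delegate the pooling to cuDNN.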
+void PoolComputeNHWC::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+  const auto x_dims = param.x->dims();
+  std::vector<int>& ksize = param.ksize;
+  if (param.global_pooling) {
+    ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
+    for (size_t i = 0; i < ksize.size(); ++i) {
+      (*param.paddings)[i] = 0;
+      ksize[i] = static_cast<int>(x_dims[i + 1]);
+    }
+  }
+
+  std::vector<int64_t> output_shape({x_dims[0]});
+  if (param.adaptive) {
+    output_shape.insert(
+        output_shape.end(), param.ksize.begin(), param.ksize.end());
+  } else {
+    for (size_t i = 0; i < param.ksize.size(); ++i) {
+      output_shape.push_back(PoolOutputSize(x_dims[i + 1],
+                                            param.ksize[i],
+                                            (*param.paddings)[i],
+                                            param.strides[i],
+                                            param.ceil_mode));
+    }
+  }
+  output_shape.push_back(x_dims[3]);
+  param.output->Resize(lite::DDim(output_shape));
+
+  pool_impl_->run(param);
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
+}
+
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
@@ -374,3 +429,19 @@ REGISTER_LITE_KERNEL(
                        PRECISION(kFloat),
                        DATALAYOUT(kNCHW))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(pool2d,
+                     kCUDA,
+                     kFloat,
+                     kNHWC,
+                     paddle::lite::kernels::cuda::PoolComputeNHWC,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
diff --git a/lite/kernels/cuda/pool_compute.h b/lite/kernels/cuda/pool_compute.h
index 55b346bfaf4ac139c8d22bff2ac64f0e78bc6023..5c3a1bc2b93d3a03a40515fff6f14e604a11c0a1 100644
--- a/lite/kernels/cuda/pool_compute.h
+++ b/lite/kernels/cuda/pool_compute.h
@@ -13,6 +13,9 @@
 // limitations under the License.
 
 #pragma once
+#include <memory>
+#include <vector>
+#include "lite/backends/cuda/math/cudnn_pool.h"
 #include "lite/core/kernel.h"
 
 namespace paddle {
@@ -29,6 +32,20 @@ class PoolCompute
   virtual ~PoolCompute() = default;
 };
 
+class PoolComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::PoolParam;
+
+  void PrepareForRun() override;
+  void Run() override;
+  virtual ~PoolComputeNHWC() = default;
+
+ private:
+  std::unique_ptr<lite::cuda::math::CudnnPool2DNHWC<PRECISION(kFloat)>>
+      pool_impl_;
+};
+
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/kernels/cuda/pool_compute_test.cc b/lite/kernels/cuda/pool_compute_test.cc
index 308905c1d01c591a32328f47d909b26af3b6d5d2..0e5aeec8c0133f1f61b469437e3e9a602096133f 100644
--- a/lite/kernels/cuda/pool_compute_test.cc
+++ b/lite/kernels/cuda/pool_compute_test.cc
@@ -27,6 +27,71 @@ namespace cuda {
 using Tensor = lite::Tensor;
 using DDim = lite::DDim;
 
+#define IN(n, c, h, w)                                 \
+  input_data[w + h * input_w + c * input_h * input_w + \
+             n * input_c * input_h * input_w]
+#define OUT(n, c, h, w)                                    \
+  output_data[w + h * output_w + c * output_h * output_w + \
+              n * output_c * output_h * output_w]
+
+template <typename Dtype>
+void nchw2nhwc_ref(lite::Tensor* input, lite::Tensor* output) {
+  auto* input_data = input->data<Dtype>();
+  auto* output_data = output->mutable_data<Dtype>();
+
+  int input_n = input->dims()[0];
+  int input_c = input->dims()[1];
+  int input_h = input->dims()[2];
+  int input_w = input->dims()[3];
+  int output_c = output->dims()[1];
+  int output_h = output->dims()[2];
+  int output_w = output->dims()[3];
+
+  for (int n = 0; n < input_n; ++n) {
+    for (int c = 0; c < input_c; ++c) {
+      for (int h = 0; h < input_h; ++h) {
+        for (int w = 0; w < input_w; ++w) {
+          OUT(n, h, w, c) = IN(n, c, h, w);
+        }
+      }
+    }
+  }
+}
+
+#undef IN
+#undef OUT
+
+#define IN(n, h, w, c)                                 \
+  input_data[c + w * input_c + h * input_w * input_c + \
+             n * input_h * input_w * input_c]
+#define OUT(n, h, w, c)                                    \
+  output_data[c + w * output_c + h * output_w * output_c + \
+              n * output_h * output_w * output_c]
+
+template <typename Dtype>
+void nhwc2nchw_ref(lite::Tensor* input, lite::Tensor* output) {
+  auto* input_data = input->data<Dtype>();
+  auto* output_data = output->mutable_data<Dtype>();
+
+  int input_n = input->dims()[0];
+  int input_h = input->dims()[1];
+  int input_w = input->dims()[2];
+  int input_c = input->dims()[3];
+  int output_h = output->dims()[1];
+  int output_w = output->dims()[2];
+  int output_c = output->dims()[3];
+
+  for (int n = 0; n < input_n; ++n) {
+    for (int c = 0; c < input_c; ++c) {
+      for (int h = 0; h < input_h; ++h) {
+        for (int w = 0; w < input_w; ++w) {
+          OUT(n, c, h, w) = IN(n, h, w, c);
+        }
+      }
+    }
+  }
+}
+
 static int PoolOutputSize(int input_size,
                           int filter_size,
                           int pad_left,
@@ -46,7 +111,10 @@ static int PoolOutputSize(int input_size,
   return output_size;
 }
 
-static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
+static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_,
+                                                 bool is_nchw) {
+  int axis = 2;
+  if (!is_nchw) axis = 1;
   const auto x_dims = param_->x->dims();
   std::vector<int>& ksize = param_->ksize;
   if (param_->global_pooling) {
@@ -59,13 +127,15 @@ static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_,
     }
   }
 
-  std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
+  std::vector<int64_t> output_shape({x_dims[0]});
+  if (is_nchw) output_shape.push_back(x_dims[1]);
   if (param_->adaptive) {
     output_shape.insert(
         output_shape.end(), param_->ksize.begin(), param_->ksize.end());
   } else {
+    auto paddings = *param_->paddings;
     for (size_t i = 0; i < param_->ksize.size(); ++i) {
-      output_shape.push_back(PoolOutputSize(x_dims[i + 2],
+      output_shape.push_back(PoolOutputSize(x_dims[i + axis],
                                             param_->ksize[i],
                                             paddings[2 * i],
                                             paddings[2 * i + 1],
                                             param_->strides[i],
@@ -73,6 +143,7 @@ static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_,
                                             param_->ceil_mode));
     }
   }
+  if (!is_nchw) output_shape.push_back(x_dims[3]);
   return output_shape;
 }
 
@@ -205,15 +276,15 @@ TEST(pool_cuda, compute) {
             for (auto pad : {0, 1}) {
               for (auto n : {1, 2}) {
                 for (auto c : {1, 3}) {
-                  for (auto h : {2, 3, 4, 11}) {
-                    for (auto w : {2, 3, 4, 11}) {
-                      VLOG(3) << "n:" << n << " c:" << c << " h:" << h
-                              << " w:" << w << " ksize:" << ksize
-                              << " stride:" << stride << " pad:" << pad
-                              << " exclusive:" << exclusive
-                              << " global_pooling:" << global_pooling
-                              << " ceil_mode: " << ceil_mode
-                              << " pooling_type:" << pooling_type;
+                  for (auto h : {3}) {
+                    for (auto w : {3}) {
+                      LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
+                                << " w:" << w << " ksize:" << ksize
+                                << " stride:" << stride << " pad:" << pad
+                                << " exclusive:" << exclusive
+                                << " global_pooling:" << global_pooling
+                                << " ceil_mode: " << ceil_mode
+                                << " pooling_type:" << pooling_type;
 
                       // init x, output
                       x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
@@ -245,7 +316,7 @@ TEST(pool_cuda, compute) {
                       param.use_quantizer = false;
 
                       const std::vector<int64_t>& output_shape =
-                          compute_output_shape(&param);
+                          compute_output_shape(&param, true);
                       if (output_shape[2] * output_shape[3] == 0) continue;
                       output.Resize(DDim(output_shape));
                       output_ref.Resize(DDim(output_shape));
@@ -289,6 +360,131 @@ TEST(pool_cuda, compute) {
     }
   }
 }
+
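+// End-to-end NHWC check: inputs are generated in NCHW, transposed to NHWC
+// for the CUDA kernel, and the NCHW reference output is transposed to NHWC
+// before comparison.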
+TEST(pool_cuda, nhwc) {
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  context.SetExecStream(stream);
+
+  PoolComputeNHWC pool;
+  operators::PoolParam param;
+  pool.SetContext(std::move(ctx));
+
+  lite::Tensor x, temp;
+  lite::Tensor x_cpu;
+  lite::Tensor output;
+  lite::Tensor output_cpu, output_temp;
+  lite::Tensor output_ref;
+  for (auto pooling_type : {"max", "avg"}) {
+    for (auto ceil_mode : {false}) {
+      for (auto global_pooling : {true, false}) {
+        for (auto exclusive : {false, true}) {
+          for (auto ksize : {3}) {
+            for (auto stride : {3}) {
+              for (auto pad : {1}) {
+                for (auto n : {1}) {
+                  for (auto c : {3}) {
+                    for (auto h : {8}) {
+                      for (auto w : {8}) {
+                        LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
+                                  << " w:" << w << " ksize:" << ksize
+                                  << " stride:" << stride << " pad:" << pad
+                                  << " exclusive:" << exclusive
+                                  << " global_pooling:" << global_pooling
+                                  << " ceil_mode: " << ceil_mode
+                                  << " pooling_type:" << pooling_type;
+
+                        // init x, output
+                        x.Resize(DDim(std::vector<int64_t>({n, h, w, c})));
+                        temp.Resize(DDim(std::vector<int64_t>({n, h, w, c})));
+                        x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
+
+                        auto* x_cpu_data = x_cpu.mutable_data<float>();
+                        for (int i = 0; i < x_cpu.dims().production(); ++i) {
+                          float sign = i % 3 == 0 ? -0.03 : 0.05f;
+                          x_cpu_data[i] = sign * (i % 128);
+                        }
+
+                        nchw2nhwc_ref<float>(&x_cpu, &temp);
+                        auto* temp_cpu_data = temp.mutable_data<float>();
+
+                        x.Assign<float, lite::DDim, TARGET(kCUDA)>(
+                            temp_cpu_data, temp.dims());
+                        // fill param
+                        param.x = &x;
+                        param.output = &output;
+                        param.pooling_type = pooling_type;
+                        if (global_pooling) {
+                          param.ksize = {h, w};
+                        } else {
+                          param.ksize = {ksize, ksize};
+                        }
+                        param.global_pooling = global_pooling;
+                        param.strides = {stride, stride};
+                        std::vector<int> paddings = {pad, pad, pad, pad};
+                        param.paddings =
+                            std::make_shared<std::vector<int>>(paddings);
+                        param.exclusive = exclusive;
+                        param.ceil_mode = ceil_mode;
+                        param.adaptive = false;
+                        param.use_quantizer = false;
+
+                        const std::vector<int64_t>& output_shape =
+                            compute_output_shape(&param, false);
+                        if (output_shape[2] * output_shape[3] == 0) continue;
+                        output.Resize(DDim(output_shape));
+                        output_temp.Resize(DDim(output_shape));
+                        output_cpu.Resize(DDim(output_shape));
+
+                        auto* output_data =
+                            output.mutable_data<float>(TARGET(kCUDA));
+                        auto* output_cpu_data =
+                            output_cpu.mutable_data<float>();
+
+                        // compute
+                        pool.SetParam(param);
+                        pool.Launch();
+
+                        // compute ref: run the NCHW reference implementation,
+                        // then transpose its result to NHWC for comparison
+                        param.x = &x_cpu;
+                        const std::vector<int64_t>& output_shape_ref =
+                            compute_output_shape(&param, true);
+                        output_ref.Resize(DDim(output_shape_ref));
+                        param.output = &output_ref;
+                        pool_compute_ref(param);
+                        nchw2nhwc_ref<float>(&output_ref, &output_temp);
+                        auto* output_temp_data =
+                            output_temp.mutable_data<float>();
+
+                        cudaDeviceSynchronize();
+                        CopySync<TARGET(kCUDA)>(output_cpu_data,
+                                                output_data,
+                                                sizeof(float) * output.numel(),
+                                                IoDirection::DtoH);
+                        // compare
+                        for (int i = 0; i < output.dims().production(); i++) {
+                          EXPECT_NEAR(
+                              output_cpu_data[i], output_temp_data[i], 1e-4);
+                        }
+                        VLOG(3) << "compare pass";
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc b/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc
index 66b3478f64f1144aa09404cb943c3de49e549b0d..f08333b3103973f99d37e39e7e7babeb52b335f1 100644
--- a/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc
+++ b/lite/kernels/cuda/search_aligned_mat_mul_compute_test.cc
@@ -57,7 +57,7 @@ void search_aligned_mat_mul_compute_ref(const operators::MatMulParam& param) {
   auto x_data = x->data<T>();
   auto y_data = y->data<T>();
   auto out_data = out->mutable_data<T>();
-#pragma omp parallel for
+
   for (int seq = 0; seq < seq_num; seq++) {
     auto a = x_data + seq * x_stride;
     auto b = y_data + seq * y_stride;
diff --git a/lite/kernels/cuda/search_seq_fc_compute_test.cc b/lite/kernels/cuda/search_seq_fc_compute_test.cc
index 0b9beb7b09290e81f17ff2580ff68f4592c9b132..354d1bb5bc3b0f3ee4d102fb2ebce176041ba91b 100644
--- a/lite/kernels/cuda/search_seq_fc_compute_test.cc
+++ b/lite/kernels/cuda/search_seq_fc_compute_test.cc
@@ -49,7 +49,6 @@ void search_seq_fc_compute_ref(const operators::SearchSeqFcParam& param) {
   auto w_data = w->data<T>();
   auto out_data = out->mutable_data<T>();
 
-#pragma omp parallel for
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
       auto sum = static_cast<T>(0);