未验证 提交 f7574646 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[cuda] [int8] resnet50 cuda int8 support (#2417)

* init resnet cuda int8 support
test=develop

* refine cuda unit test
test=develop

* add the forgeted file.
test=develop
上级 8a634b71
......@@ -11,6 +11,7 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
......@@ -22,6 +23,7 @@ set (
cuda_type_trans
cuda_transpose
cuda_elementwise
cudnn_pool
cuda_gemm
cuda_batched_gemm
)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/cudnn_pool.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
inline void UpdatePadding(std::vector<int>* paddings,
const bool global_pooling,
const bool adaptive,
const std::vector<int>& data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
CHECK(data_dims.size() * 2 == paddings->size())
<< "Paddings size should be the same or twice as the pooling size.";
}
if (global_pooling || adaptive) {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline void UpdateKsize(std::vector<int>* ksize,
const std::vector<int>& data_dims) {
ksize->resize(static_cast<size_t>(data_dims.size()));
for (size_t i = 0; i < ksize->size(); ++i) {
*(ksize->begin() + i) = static_cast<int>(data_dims[i]);
}
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::create(
const operators::PoolParam& param, Context<TARGET(kCUDA)>* ctx) {
return true;
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::init(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx) {
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreatePoolingDescriptor(&this->pooling_desc_);
return create(param, ctx);
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::run(
const operators::PoolParam& param) {
auto x_dims = param.x->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
const float* in_data = param.x->data<float>();
float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
int ih = x_dims[1];
int iw = x_dims[2]; // nchw
int ic = x_dims[3];
int oh = o_dims[1];
int ow = o_dims[2];
int oc = o_dims[3];
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = *(param.paddings.get());
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
std::vector<int> data_dims = {ih, iw};
UpdatePadding(&paddings, global_pooling, adaptive, data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
cudnnPoolingMode_t mode;
if (pooling_type == "max") {
mode = CUDNN_POOLING_MAX;
} else {
mode = exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
}
CUDNN_CHECK(cudnnSetPoolingNdDescriptor(this->pooling_desc_,
mode,
CUDNN_NOT_PROPAGATE_NAN,
ksize.size(),
ksize.data(),
paddings.data(),
strides.data()));
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnPoolingForward(this->handle_,
this->pooling_desc_,
&alpha,
this->input_desc_,
in_data,
&beta,
this->output_desc_,
out_data));
return true;
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <PrecisionType Ptype_out>
class CudnnPool2DBase {
public:
CudnnPool2DBase()
: handle_(NULL),
input_desc_(NULL),
output_desc_(NULL),
pooling_desc_(NULL) {}
~CudnnPool2DBase() {
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
if (input_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
}
if (output_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
}
if (pooling_desc_) {
cudnnDestroyPoolingDescriptor(pooling_desc_);
}
}
protected:
cudaStream_t stream_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
};
template <PrecisionType Ptype_out>
class CudnnPool2DNHWC : public CudnnPool2DBase<Ptype_out> {
public:
CudnnPool2DNHWC() : CudnnPool2DBase<Ptype_out>() {}
virtual ~CudnnPool2DNHWC() = default;
virtual bool init(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::PoolParam& param);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -13,13 +13,55 @@
// limitations under the License.
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
__global__ void elementwise_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
#if __CUDA_ARCH__ >= 350
out_data[tid] = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
out_data[tid] = binary_calc(x_data[tid], y_data[idx], type);
#endif
}
}
template <typename Dtype>
__global__ void elementwise_relu_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
Dtype temp;
#if __CUDA_ARCH__ >= 350
temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
temp = binary_calc(x_data[tid], y_data[idx], type);
#endif
out_data[tid] = temp > 0 ? temp : 0;
}
}
template <typename Dtype>
__global__ void elementwise_add_kernel(const size_t total,
const Dtype* x_data,
......@@ -76,6 +118,56 @@ __global__ void elementwise_add_nhwc4_int8_kernel(const size_t total,
}
}
template <typename Dtype>
void elementwise(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream) {
int num = pre * n * post;
int thread = 256;
int block = (num + thread - 1) / thread;
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
}
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream) {
int num = pre * n * post;
int thread = 256;
int block = (num + thread - 1) / thread;
elementwise_relu_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
}
template void elementwise(const float*,
const float*,
float*,
int,
int,
int,
BinaryOperation,
cudaStream_t);
template void elementwise_relu(const float*,
const float*,
float*,
int,
int,
int,
BinaryOperation,
cudaStream_t);
template <typename Dtype>
void elementwise_add(int num,
const Dtype* x_data,
......
......@@ -15,12 +15,33 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
void elementwise(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream);
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream);
template <typename Dtype>
void elementwise_add(int num,
const Dtype* x_data,
......
......@@ -25,6 +25,24 @@ namespace lite {
namespace cuda {
namespace math {
enum class BinaryOperation {
kADD = 0,
kMUL = 1,
kDIV = 2,
};
template <typename T>
__device__ T binary_calc(T x, T y, BinaryOperation type);
template <>
__device__ __forceinline__ float binary_calc(float x,
float y,
BinaryOperation type) {
if (type == BinaryOperation::kADD) return x + y;
if (type == BinaryOperation::kMUL) return x * y;
if (type == BinaryOperation::kDIV) return x / y;
}
template <typename T>
__device__ T from_float(float x);
......
......@@ -73,7 +73,7 @@ class Optimizer {
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
"lite_elementwise_add_activation_fuse_pass", //
#endif
"static_kernel_pick_pass", // pick original kernel from graph
......
......@@ -15,14 +15,15 @@ add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${li
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(concat_compute_cuda CUDA basic SRCS concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
add_kernel(elementwise_compute_cuda CUDA basic SRCS elementwise_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS
${lite_kernel_deps} cudnn_pool)
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
......@@ -47,12 +48,13 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(elementwise_compute_cuda_test SRCS elementwise_compute_test.cc DEPS elementwise_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/calib_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
......@@ -58,12 +59,7 @@ void calib_ref(const operators::CalibParam& param, bool to_float = true) {
}
TEST(calib_cuda, int8_to_fp32) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto calib = std::move(*std::next(kernels.begin(), 1));
LOG(INFO) << "get kernel: " << calib->doc();
CalibComputeInt8ToFp32 calib;
const int n = 64, c = 32, h = 18, w = 18;
Tensor x;
Tensor x_cpu;
......@@ -87,14 +83,14 @@ TEST(calib_cuda, int8_to_fp32) {
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
calib->SetContext(std::move(ctx));
calib.SetContext(std::move(ctx));
operators::CalibParam param;
param.scale = 0.013f;
param.input = &x;
param.output = &output;
calib->SetParam(param);
calib->Launch();
calib.SetParam(param);
calib.Launch();
cudaDeviceSynchronize();
// invoking ref implementation and compare results
param.input = &x_cpu;
......@@ -113,12 +109,7 @@ TEST(calib_cuda, int8_to_fp32) {
}
TEST(calib_cuda, fp32_to_int8) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto calib = std::move(kernels.front());
LOG(INFO) << "get kernel: " << calib->doc();
CalibComputeFp32ToInt8 calib;
const int n = 64, c = 32, h = 18, w = 18;
Tensor x;
Tensor x_cpu;
......@@ -142,14 +133,14 @@ TEST(calib_cuda, fp32_to_int8) {
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
calib->SetContext(std::move(ctx));
calib.SetContext(std::move(ctx));
operators::CalibParam param;
param.scale = 0.013f;
param.input = &x;
param.output = &output;
calib->SetParam(param);
calib->Launch();
calib.SetParam(param);
calib.Launch();
cudaDeviceSynchronize();
// invoking ref implementation and compare results
param.input = &x_cpu;
......
......@@ -42,7 +42,9 @@ TEST(conv_compute, fp32) {
operators::ConvParam param;
param.activation_param = act_param;
std::vector<int> pads = {1, 1, 1, 1};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
......@@ -149,6 +151,10 @@ TEST(conv_compute, int8) {
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
filter_cpu.dims());
std::vector<int> pads = {0, 0, 0, 0};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.x = &x;
param.filter = &filter;
param.output = &y;
......@@ -203,12 +209,10 @@ TEST(conv_compute, int8_int8_out) {
std::cout << "input" << std::endl;
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
std::cout << float(x_cpu_data[i]) << std::endl;
}
std::cout << "filter" << std::endl;
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
std::cout << float(filter_cpu_data[i]) << std::endl;
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
......@@ -221,6 +225,10 @@ TEST(conv_compute, int8_int8_out) {
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
filter_cpu.dims());
std::vector<int> pads = {0, 0, 0, 0};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.x = &x;
param.filter = &filter;
param.output = &y;
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/elementwise_add_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ElementwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
lite::cuda::math::elementwise_add(
pixel_num, x_data, y_data, out_data, stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeNHWC::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
lite::cuda::math::elementwise_add(
pixel_num, x_data, y_data, out_data, stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeInt8::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
const int c = x->dims()[3];
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<int8_t>(TARGET(kCUDA));
int pixel_num = x->numel();
float output_scale = param.output_scale;
if (c % 4 == 0) {
lite::cuda::math::elementwise_add_nhwc4_int8(
pixel_num / 4,
static_cast<const void*>(x_data),
static_cast<const void*>(y_data),
1. / output_scale,
static_cast<void*>(out_data),
stream);
} else {
lite::cuda::math::elementwise_add_int8(
pixel_num, x_data, y_data, 1. / output_scale, out_data, stream);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/elementwise_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
inline DDim trim_trailing_singular_dims(const DDim& dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim();
}
return DDim(trim_dims);
}
inline bool is_broadcast(const DDim& x_dims,
const DDim& y_dims,
int axis,
int* pre,
int* n,
int* post) {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
if (x_dims.size() == y_dim_trim.size()) {
return false;
}
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dim_trim.size(); ++i) {
CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
<< "Broadcast dimension mismatch.";
(*n) *= y_dim_trim[i];
}
for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
return true;
}
#define ELEMENTWISE_COMPUTE(OP, WITH_RELU) \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
} else { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU) \
std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
if (axis < 0) axis = x->dims().size() - y->dims().size(); \
CHECK(axis >= 0) << "invalid axis of elementwise op"; \
axis = pos_map[axis]; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
} else { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
void ElementwiseAddCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseMulCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseMulComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
......@@ -38,13 +38,58 @@ class ElementwiseAddComputeNHWC
virtual ~ElementwiseAddComputeNHWC() = default;
};
class ElementwiseAddComputeInt8
class ElementwiseMulCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseMulCompute() = default;
};
class ElementwiseMulComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseAddComputeInt8() = default;
virtual ~ElementwiseMulComputeNHWC() = default;
};
class ElementwiseAddReluCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluCompute() = default;
};
class ElementwiseAddReluComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluComputeNHWC() = default;
};
class ElementwiseMulReluCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluCompute() = default;
};
class ElementwiseMulReluComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluComputeNHWC() = default;
};
} // namespace cuda
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/elementwise_add_compute.h"
#include "lite/kernels/cuda/elementwise_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
......@@ -31,6 +31,14 @@ static void ElementwiseAddRef(float* x, float* y, float* out, int num) {
}
}
static void ElementwiseBroadcastRef(
float* x, float* y, float* out, int pre, int n, int post) {
for (int i = 0; i < pre * n * post; ++i) {
int idx = (i / post) % n;
out[i] = x[i] + y[idx];
}
}
TEST(elementwise_add, normal) {
ElementwiseAddCompute elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
......@@ -99,38 +107,117 @@ TEST(elementwise_add, normal) {
}
}
TEST(elementwise_add, int8_out) {
ElementwiseAddComputeInt8 elementwise_add_kernel;
TEST(elementwise_add, bias) {
ElementwiseAddCompute elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ElementwiseParam param;
Tensor x, y, out;
Tensor x_cpu, y_cpu, out_cpu;
Tensor x_ref, y_ref, out_ref;
const int n = 1;
const int c = 3;
const int h = 2000;
const int w = 2000;
x.Resize({n, c, h, w});
y.Resize({c, 1, 1});
out.Resize({n, c, h, w});
x_cpu.Resize({n, c, h, w});
y_cpu.Resize({c, 1, 1});
out_cpu.Resize({n, c, h, w});
x_ref.Resize({n, c, h, w});
y_ref.Resize({c, 1, 1});
out_ref.Resize({n, c, h, w});
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* y_ref_data = y_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0;
x_ref_data[i] = i + 5.0;
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = i - 5.0;
y_ref_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
param.X = &x;
param.Y = &y;
param.Out = &out;
param.axis = -1;
elementwise_add_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
elementwise_add_kernel.SetContext(std::move(ctx));
elementwise_add_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
ElementwiseBroadcastRef(x_ref_data, y_ref_data, out_ref_data, n, c, h * w);
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
TEST(elementwise_add_nhwc, bias) {
ElementwiseAddComputeNHWC elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ElementwiseParam param;
Tensor x, y, out;
Tensor x_cpu, y_cpu, out_cpu;
Tensor x_ref, y_ref, out_ref;
const int n = 1;
const int h = 36;
const int w = 36;
const int c = 125;
const int c = 3;
const int h = 2000;
const int w = 2000;
x.Resize({n, h, w, c});
y.Resize({n, h, w, c});
y.Resize({c, 1, 1});
out.Resize({n, h, w, c});
x_cpu.Resize({n, h, w, c});
y_cpu.Resize({n, h, w, c});
y_cpu.Resize({c, 1, 1});
out_cpu.Resize({n, h, w, c});
x_ref.Resize({n, h, w, c});
y_ref.Resize({c, 1, 1});
out_ref.Resize({n, h, w, c});
auto* out_data = out.mutable_data<int8_t>(TARGET(kCUDA));
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<int8_t>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* y_ref_data = y_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0;
x_ref_data[i] = i + 5.0;
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = i;
y_cpu_data[i] = i - 5.0;
y_ref_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
......@@ -139,7 +226,7 @@ TEST(elementwise_add, int8_out) {
param.X = &x;
param.Y = &y;
param.Out = &out;
param.output_scale = 50 / 127.;
param.axis = -1;
elementwise_add_kernel.SetParam(param);
cudaStream_t stream;
......@@ -147,16 +234,15 @@ TEST(elementwise_add, int8_out) {
context.SetExecStream(stream);
elementwise_add_kernel.SetContext(std::move(ctx));
auto start = GetCurrentUS();
for (int i = 0; i < 1000000; i++) {
elementwise_add_kernel.Launch();
}
LOG(INFO) << "time: " << (GetCurrentUS() - start) / 1000000.;
elementwise_add_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(int8_t) * out.numel(), IoDirection::DtoH);
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
ElementwiseBroadcastRef(
x_ref_data, y_ref_data, out_ref_data, n * h * w, c, 1);
for (int i = 0; i < out.numel(); i++) {
// LOG(INFO) << float(out_cpu_data[i]);
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/cuda/layout_compute.h"
#include <vector>
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/op_registry.h"
......@@ -21,11 +22,32 @@ namespace lite {
namespace kernels {
namespace cuda {
inline DDim trim_singular_dims(const DDim& dims) {
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim();
}
return DDim(trim_dims);
}
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
if (input_trim_dim.size() == 1) { \
param.y->CopyDataFrom(*param.x); \
return; \
} \
CHECK(input_dim.size() == 4) \
<< "NCHW to NHWC should guarantee that the input dims should be 4"; \
int n = input_dim[0]; \
......@@ -41,6 +63,11 @@ namespace cuda {
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
if (input_trim_dim.size() == 1) { \
param.y->CopyDataFrom(*param.x); \
return; \
} \
CHECK(input_dim.size() == 4) \
<< "NHWC to NCHW should guarantee that the input dims should be 4"; \
int n = input_dim[0]; \
......
......@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "lite/backends/cuda/blas.h"
namespace paddle {
namespace lite {
......@@ -26,6 +27,7 @@ TEST(mul_compute, normal) {
MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
context.InitOnce();
Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4;
......
......@@ -358,6 +358,61 @@ void PoolCompute::Run() {
if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
}
inline int PoolOutputSize(
int input_size, int filter_size, int padding, int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
void PoolComputeNHWC::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
pool_impl_.reset(new lite::cuda::math::CudnnPool2DNHWC<PRECISION(kFloat)>);
pool_impl_->init(param, &ctx);
}
void PoolComputeNHWC::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const auto x_dims = param.x->dims();
std::vector<int>& ksize = param.ksize;
if (param.global_pooling) {
ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
(*param.paddings)[i] = 0;
ksize[i] = static_cast<int>(x_dims[i + 1]);
}
}
std::vector<int64_t> output_shape({x_dims[0]});
if (param.adaptive) {
output_shape.insert(
output_shape.end(), param.ksize.begin(), param.ksize.end());
} else {
for (size_t i = 0; i < param.ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(x_dims[i + 1],
param.ksize[i],
(*param.paddings)[i],
param.strides[i],
param.ceil_mode));
}
}
output_shape.push_back(x_dims[3]);
param.output->Resize(lite::DDim(output_shape));
pool_impl_->run(param);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(FATAL) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......@@ -374,3 +429,19 @@ REGISTER_LITE_KERNEL(
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(pool2d,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::PoolComputeNHWC,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
......@@ -13,6 +13,9 @@
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "lite/backends/cuda/math/cudnn_pool.h"
#include "lite/core/kernel.h"
namespace paddle {
......@@ -29,6 +32,20 @@ class PoolCompute
virtual ~PoolCompute() = default;
};
class PoolComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::PoolParam;
void PrepareForRun() override;
void Run() override;
virtual ~PoolComputeNHWC() = default;
private:
std::unique_ptr<lite::cuda::math::CudnnPool2DNHWC<PRECISION(kFloat)>>
pool_impl_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
......
......@@ -27,6 +27,71 @@ namespace cuda {
using Tensor = lite::Tensor;
using DDim = lite::DDim;
#define IN(n, c, h, w) \
input_data[w + h * input_w + c * input_h * input_w + \
n * input_c * input_h * input_w]
#define OUT(n, c, h, w) \
output_data[w + h * output_w + c * output_h * output_w + \
n * output_c * output_h * output_w]
template <typename Dtype>
void nchw2nhwc_ref(lite::Tensor* input, lite::Tensor* output) {
auto* input_data = input->data<Dtype>();
auto* output_data = output->mutable_data<Dtype>();
int input_n = input->dims()[0];
int input_c = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, h, w, c) = IN(n, c, h, w);
}
}
}
}
}
#undef IN
#undef OUT
#define IN(n, h, w, c) \
input_data[c + w * input_c + h * input_w * input_c + \
n * input_h * input_w * input_c]
#define OUT(n, h, w, c) \
output_data[c + w * output_c + h * output_w * output_c + \
n * output_h * output_w * output_c]
template <typename Dtype>
void nhwc2nchw_ref(lite::Tensor* input, lite::Tensor* output) {
auto* input_data = input->data<Dtype>();
auto* output_data = output->mutable_data<Dtype>();
int input_n = input->dims()[0];
int input_h = input->dims()[1];
int input_w = input->dims()[2];
int input_c = input->dims()[3];
int output_h = output->dims()[1];
int output_w = output->dims()[2];
int output_c = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, c, h, w) = IN(n, h, w, c);
}
}
}
}
}
static int PoolOutputSize(int input_size,
int filter_size,
int pad_left,
......@@ -46,7 +111,10 @@ static int PoolOutputSize(int input_size,
return output_size;
}
static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_,
bool is_nchw) {
int axis = 2;
if (!is_nchw) axis = 1;
const auto x_dims = param_->x->dims();
std::vector<int>& ksize = param_->ksize;
if (param_->global_pooling) {
......@@ -59,13 +127,15 @@ static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
}
}
std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
std::vector<int64_t> output_shape({x_dims[0]});
if (is_nchw) output_shape.push_back(x_dims[1]);
if (param_->adaptive) {
output_shape.insert(
output_shape.end(), param_->ksize.begin(), param_->ksize.end());
} else {
auto paddings = *param_->paddings;
for (size_t i = 0; i < param_->ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(x_dims[i + 2],
output_shape.push_back(PoolOutputSize(x_dims[i + axis],
param_->ksize[i],
paddings[2 * i],
paddings[2 * i + 1],
......@@ -73,6 +143,7 @@ static std::vector<int64_t> compute_output_shape(operators::PoolParam* param_) {
param_->ceil_mode));
}
}
if (!is_nchw) output_shape.push_back(x_dims[3]);
return output_shape;
}
......@@ -205,15 +276,15 @@ TEST(pool_cuda, compute) {
for (auto pad : {0, 1}) {
for (auto n : {1, 2}) {
for (auto c : {1, 3}) {
for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) {
VLOG(3) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
for (auto h : {3}) {
for (auto w : {3}) {
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
......@@ -245,7 +316,7 @@ TEST(pool_cuda, compute) {
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
compute_output_shape(&param, true);
if (output_shape[2] * output_shape[3] == 0) continue;
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
......@@ -289,6 +360,131 @@ TEST(pool_cuda, compute) {
}
}
}
TEST(pool_cuda, nhwc) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
PoolComputeNHWC pool;
operators::PoolParam param;
pool.SetContext(std::move(ctx));
lite::Tensor x, temp;
lite::Tensor x_cpu;
lite::Tensor output;
lite::Tensor output_cpu, output_temp;
lite::Tensor output_ref;
for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {false}) {
for (auto global_pooling : {true, false}) {
for (auto exclusive : {false, true}) {
for (auto ksize : {3}) {
for (auto stride : {3}) {
for (auto pad : {1}) {
for (auto n : {1}) {
for (auto c : {3}) {
for (auto h : {8}) {
for (auto w : {8}) {
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, h, w, c})));
temp.Resize(DDim(std::vector<int64_t>({n, h, w, c})));
x_cpu.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.dims().production(); ++i) {
float sign = i % 3 == 0 ? -0.03 : 0.05f;
x_cpu_data[i] = sign * (i % 128);
}
nchw2nhwc_ref<float>(&x_cpu, &temp);
auto* temp_cpu_data = temp.mutable_data<float>();
x.Assign<float, DDim, TARGET(kCUDA)>(temp_cpu_data,
temp.dims());
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
if (global_pooling) {
param.ksize = {h, w};
} else {
param.ksize = {ksize, ksize};
}
param.global_pooling = global_pooling;
param.strides = {stride, stride};
std::vector<int> paddings = {pad, pad, pad, pad};
param.paddings =
std::make_shared<std::vector<int>>(paddings);
param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param, false);
if (output_shape[2] * output_shape[3] == 0) continue;
output.Resize(DDim(output_shape));
output_temp.Resize(DDim(output_shape));
output_cpu.Resize(DDim(output_shape));
auto* output_data =
output.mutable_data<float>(TARGET(kCUDA));
auto* output_cpu_data =
output_cpu.mutable_data<float>();
// compute
pool.SetParam(param);
pool.Launch();
// compute ref
param.x = &x_cpu;
// nchw
const std::vector<int64_t>& output_shape_ref =
compute_output_shape(&param, true);
output_ref.Resize(DDim(output_shape_ref));
// auto* output_ref_data =
// output_ref.mutable_data<float>();
param.output = &output_ref;
pool_compute_ref(param);
nchw2nhwc_ref<float>(&output_ref, &output_temp);
auto* output_temp_data =
output_temp.mutable_data<float>();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(output_cpu_data,
output_data,
sizeof(float) * output.numel(),
IoDirection::DtoH);
// compare
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(
output_cpu_data[i], output_temp_data[i], 1e-4);
}
VLOG(3) << "compare pass";
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......
......@@ -57,7 +57,7 @@ void search_aligned_mat_mul_compute_ref(const operators::MatMulParam& param) {
auto x_data = x->data<T>();
auto y_data = y->data<T>();
auto out_data = out->mutable_data<T>();
#pragma omp parallel for
for (int seq = 0; seq < seq_num; seq++) {
auto a = x_data + seq * x_stride;
auto b = y_data + seq * y_stride;
......
......@@ -49,7 +49,6 @@ void search_seq_fc_compute_ref(const operators::SearchSeqFcParam& param) {
auto w_data = w->data<T>();
auto out_data = out->mutable_data<T>();
#pragma omp parallel for
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
auto sum = static_cast<T>(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册