Unverified commit 486d6572, authored by W Wilber, committed by GitHub

[Code Format] Update code format. (#3890)

Parent 45457074
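The change is mechanical: data members gain a trailing underscore, test helpers move to CamelCase names (`device_init` → `InitParamAndContext`, `cpu_base` → `RunBaseLine`, `float_data_init` → `InitFloatInput`), and the CUDA kernels are renamed the same way. A minimal sketch of the convention, assuming the Google C++ style the project's cpplint/clang-format setup enforces (the class below is illustrative only, not code from this commit):

```cpp
#include <vector>

// Hypothetical fixture showing the renames applied throughout this commit.
class ExampleTest {
 protected:
  void InitParamAndContext() {}  // was: device_init()
  void InitFloatInput() {}       // was: float_data_init()
  void RunBaseLine() {}          // was: cpu_base()

  int dtype_;               // was: dtype  (data members get a trailing '_')
  std::vector<int> shape_;  // was: shape
};
```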
@@ -33,91 +33,91 @@ namespace cuda {

class AssignValueTest : public ::testing::Test {
 protected:
  AssignValueTest() : dtype_(5), shape_({1}) {
    int num = std::accumulate(
        shape_.begin(), shape_.end(), 1, std::multiplies<int>());
    fp32_values_.resize(num);
    int32_values_.resize(num);
    int64_values_.resize(num);
    bool_values_.resize(num);
    for (int i = 0; i < num; ++i) {
      fp32_values_[i] = i + 5;
      int32_values_[i] = i;
      int64_values_[i] = i;
      bool_values_[i] = i;
    }
    std::vector<int64_t> out_shape(shape_.size(), 0);
    for (size_t i = 0; i < shape_.size(); ++i) out_shape[i] = shape_[i];
    out_ref_.Resize(lite::DDim(out_shape));
    out_gpu_.Resize(out_ref_.dims());
    out_cpu_.Resize(out_ref_.dims());

    RunBaseLine(&out_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.shape = shape_;
    param_.dtype = dtype_;
    param_.fp32_values = fp32_values_;
    param_.int32_values = int32_values_;
    param_.int64_values = int64_values_;
    param_.bool_values = bool_values_;
    param_.Out = &out_gpu_;
  }

  void InitFloatInput() {}

  void InitHalfInput() {}

  void RunBaseLine(lite::Tensor* out) {
    if (dtype_ == static_cast<int>(lite::core::FluidType::INT32)) {
      for (size_t i = 0; i < int32_values_.size(); ++i) {
        out->mutable_data<int>()[i] = int32_values_[i];
      }
    } else if (dtype_ == static_cast<int>(lite::core::FluidType::FP32)) {
      for (size_t i = 0; i < fp32_values_.size(); ++i) {
        out->mutable_data<float>()[i] = fp32_values_[i];
      }
    } else if (dtype_ == static_cast<int>(lite::core::FluidType::INT64)) {
      for (size_t i = 0; i < int64_values_.size(); ++i) {
        out->mutable_data<int64_t>()[i] = int64_values_[i];
      }
    } else if (dtype_ == static_cast<bool>(lite::core::FluidType::BOOL)) {
      for (size_t i = 0; i < bool_values_.size(); ++i) {
        out->mutable_data<bool>()[i] = bool_values_[i];
      }
    } else {
      LOG(FATAL) << "Unsupported dtype_ for assign_value_op:" << dtype_;
    }
  }

  int dtype_;
  std::vector<int> shape_;
  std::vector<float> fp32_values_;
  std::vector<int> int32_values_;
  std::vector<int64_t> int64_values_;
  std::vector<int> bool_values_;

  lite::Tensor out_ref_;
  lite::Tensor out_gpu_;
  lite::Tensor out_cpu_;

  operators::AssignValueParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(AssignValueTest, fp32) {
  InitFloatInput();
  AssignValueCompute kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -135,12 +135,12 @@ TEST_F(AssignValueTest, fp32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);

  for (int i = 0; i < out_gpu_.numel(); ++i) {
    EXPECT_NEAR(out_cpu_.data<float>()[i], out_ref_.data<float>()[i], 1e-5);
  }
}
...
@@ -33,7 +33,7 @@ struct FcTypeTraits<float> {
};

template <typename T>
__global__ void AddBiasV4(const int num, const T* bias, T* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const T bias_ptr = bias[bias_idx];

@@ -48,7 +48,7 @@ __global__ void bias_v4(const int num, const T* bias, T* data, int K) {
}

template <typename T>
__global__ void AddBiasReluV4(const int num, const T* bias, T* data, int K) {
  CUDA_KERNEL_LOOP(index, num) {
    int bias_idx = index % K;
    const T bias_ptr = bias[bias_idx];

@@ -63,7 +63,7 @@ __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
}

template <typename T>
__global__ void AddBias(const int num, const T* bias, T* data) {
  int offset = blockIdx.x * num;

  for (int i = threadIdx.x; i < num; i += blockDim.x) {

@@ -78,7 +78,7 @@ __global__ void general_bias(const int num, const T* bias, T* data) {
}

template <typename T>
__global__ void AddBiasRelu(const int num, const T* bias, T* data) {
  int offset = blockIdx.x * num;

  for (int i = threadIdx.x; i < num; i += blockDim.x) {

@@ -140,10 +140,10 @@ void FcCompute<T, PType>::Run() {
    const auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(b_data);
    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(out_data);
    if (activation_type == "relu") {
      AddBiasReluV4<trans_type><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else if (activation_type == "") {
      AddBiasV4<trans_type><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else {
      LOG(FATAL) << "not supported activation type: " << activation_type;

@@ -152,9 +152,9 @@ void FcCompute<T, PType>::Run() {
    const int threads = 256;
    const int blocks = M;
    if (activation_type == "relu") {
      AddBiasRelu<T><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
    } else if (activation_type == "") {
      AddBias<T><<<blocks, threads, 0, stream>>>(N, b_data, out_data);
    } else {
      LOG(FATAL) << "not supported activation type: " << activation_type;
    }
...
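For context on the renamed kernels above: the V4 variants are launched after reinterpreting the bias and output pointers as a packed four-wide type and passing `N / 4` as the row width, so each loop iteration touches four values at once. A rough, self-contained sketch of that pattern for `float4` (the function name, launch configuration, and fused ReLU here are illustrative assumptions, not the exact Paddle-Lite kernel):

```cuda
#include <cuda_runtime.h>

// Grid-stride loop over `num` packed float4 elements; adds a bias that
// repeats every K packed elements, then applies ReLU in place.
__global__ void AddBiasReluV4Sketch(int num, const float4* bias,
                                    float4* data, int K) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
       idx += blockDim.x * gridDim.x) {
    float4 b = bias[idx % K];  // K == N / 4 in the caller
    float4 v = data[idx];
    v.x = fmaxf(v.x + b.x, 0.f);
    v.y = fmaxf(v.y + b.y, 0.f);
    v.z = fmaxf(v.z + b.z, 0.f);
    v.w = fmaxf(v.w + b.w, 0.f);
    data[idx] = v;
  }
}
```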
@@ -31,101 +31,101 @@ namespace cuda {

class FcTest : public ::testing::Test {
 protected:
  FcTest()
      : m_(128),
        k_(512),
        n_(64),
        in_num_col_dims_(1),
        act_type_("relu"),
        x_shape_({m_, k_}),
        w_shape_({k_, n_}),
        b_shape_({n_}),
        out_shape_({m_, n_}) {
    x_ref_.Resize(lite::DDim(x_shape_));
    x_gpu_.Resize(lite::DDim(x_shape_));
    w_ref_.Resize(lite::DDim(w_shape_));
    w_gpu_.Resize(lite::DDim(w_shape_));
    b_ref_.Resize(lite::DDim(b_shape_));
    b_gpu_.Resize(lite::DDim(b_shape_));

    auto x_ref_data = x_ref_.mutable_data<float>();
    auto w_ref_data = w_ref_.mutable_data<float>();
    auto b_ref_data = b_ref_.mutable_data<float>();

    // prepare input
    for (int64_t i = 0; i < x_ref_.numel(); i++) {
      x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }
    for (int64_t i = 0; i < w_ref_.numel(); i++) {
      w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }
    for (int64_t i = 0; i < b_ref_.numel(); i++) {
      b_ref_data[i] = static_cast<float>(i % 10 * 0.2);
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_cpu_.Resize(out_ref_.dims());
    out_gpu_.Resize(out_ref_.dims());
    RunBaseLine(&x_ref_, &w_ref_, &b_ref_, &out_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.input = &x_gpu_;
    param_.w = &w_gpu_;
    param_.bias = &b_gpu_;
    param_.in_num_col_dims = in_num_col_dims_;
    param_.activation_type = act_type_;
    param_.output = &out_gpu_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
                                                    x_gpu_.dims());
    w_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(w_ref_.data<float>(),
                                                    w_gpu_.dims());
    b_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(b_ref_.data<float>(),
                                                    b_gpu_.dims());
  }

  void InitHalfInput() {
    x_half_.Resize(lite::DDim(x_shape_));
    auto x_half_data = x_half_.mutable_data<half>();
    for (int64_t i = 0; i < x_half_.numel(); i++) {
      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
    }
    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
    w_half_.Resize(w_ref_.dims());
    auto w_half_data = w_half_.mutable_data<half>();
    for (int64_t i = 0; i < w_half_.numel(); i++) {
      w_half_data[i] = half(lite::float16(w_ref_.data<float>()[i]));
    }
    w_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, w_gpu_.dims());
    b_half_.Resize(b_ref_.dims());
    auto b_half_data = b_half_.mutable_data<half>();
    for (int64_t i = 0; i < b_half_.numel(); i++) {
      b_half_data[i] = half(lite::float16(b_ref_.data<float>()[i]));
    }
    b_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(b_half_data, b_gpu_.dims());
  }

  void RunBaseLine(const lite::Tensor* x,
                   const lite::Tensor* w,
                   const lite::Tensor* b,
                   lite::Tensor* out) {
    const float* data_in = x->data<float>();
    const float* bias = b->data<float>();
    const float* weights = w->data<float>();
    float* data_out = out->mutable_data<float>();
    int out_rows = x->dims()[0];
    int in_cols = x->numel() / out_rows;
    int out_cols = w->numel() / in_cols;
    int index_out;
    for (int i = 0; i < out_rows; i++) {
      for (int j = 0; j < out_cols; j++) {

@@ -135,31 +135,31 @@ class FcTest : public ::testing::Test {
          data_out[index_out] +=
              data_in[i * in_cols + k] * weights[k * out_cols + j];
        }
        if (act_type_ == "relu") {
          data_out[index_out] *= static_cast<int>(data_out[index_out] > 0);
        }
      }
    }
  }

  int m_, k_, n_, in_num_col_dims_;
  std::string act_type_;
  std::vector<int64_t> x_shape_, w_shape_, b_shape_, out_shape_;
  lite::Tensor x_ref_, w_ref_, b_ref_, out_ref_;
  lite::Tensor x_gpu_, w_gpu_, b_gpu_;
  lite::Tensor x_half_, w_half_, b_half_;
  lite::Tensor out_cpu_, out_gpu_;
  operators::FcParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(FcTest, TestFP32) {
  InitFloatInput();
  FcCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -177,14 +177,14 @@ TEST_F(FcTest, TestFP32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);
  for (int i = 0; i < out_gpu_.numel(); ++i) {
    float res = out_cpu_.data<float>()[i];
    float ref = out_ref_.data<float>()[i];
    EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
  }
}
...
@@ -28,11 +28,6 @@ class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
  void Run() override;

  virtual ~SequenceMaskCompute() = default;
};

}  // namespace cuda
...
@@ -32,73 +32,73 @@ namespace cuda {

class SequenceMaskTest : public ::testing::Test {
 protected:
  SequenceMaskTest()
      : maxlen_(4),
        out_dtype_(5),
        x_data_({3, 2, 1, 0}),
        out_shape_({static_cast<int64_t>(x_data_.size()), maxlen_}) {
    x_ref_.Resize(lite::DDim({static_cast<int64_t>(x_data_.size())}));
    x_gpu_.Resize(x_ref_.dims());

    auto* x_ref_data = x_ref_.mutable_data<int64_t>();

    // prepare input
    for (size_t i = 0; i < x_data_.size(); i++) {
      x_ref_data[i] = x_data_[i];
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_gpu_.Resize(out_ref_.dims());
    out_cpu_.Resize(out_ref_.dims());

    RunBaseLine(&x_ref_, &out_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.X = &x_gpu_;
    param_.Y = &out_gpu_;
    param_.maxlen = maxlen_;
    param_.out_dtype = out_dtype_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(x_ref_.data<int64_t>(),
                                                      x_gpu_.dims());
  }

  void InitHalfInput() {}

  void RunBaseLine(const lite::Tensor* x, lite::Tensor* out) {
    auto* out_data = out->mutable_data<float>();
    for (size_t i = 0; i < x_data_.size(); ++i) {
      for (int j = 0; j < maxlen_; ++j) {
        out_data[i * maxlen_ + j] = j < x_data_[i] ? 1 : 0;
      }
    }
  }

  int maxlen_, out_dtype_;
  std::vector<int64_t> x_data_, out_shape_;
  lite::Tensor x_ref_, out_ref_;
  lite::Tensor x_gpu_, out_gpu_;
  lite::Tensor out_cpu_;
  operators::SequenceMaskParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(SequenceMaskTest, fp32) {
  InitFloatInput();
  SequenceMaskCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -116,12 +116,12 @@ TEST_F(SequenceMaskTest, fp32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);

  for (int i = 0; i < out_gpu_.numel(); ++i) {
    EXPECT_NEAR(out_cpu_.data<float>()[i], out_ref_.data<float>()[i], 1e-5);
  }
}
...
@@ -23,7 +23,7 @@
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/utils/float16.h"

namespace paddle {
namespace lite {

@@ -33,72 +33,73 @@ namespace cuda {

class SequencePadTest : public ::testing::Test {
 protected:
  SequencePadTest()
      : batch_(5),
        features_(2),
        padded_length_(3),
        x_lod_({{0, 2, 5}}),
        x_shape_({batch_, features_}),
        pad_value_shape_({features_}),
        out_shape_({static_cast<int64_t>(x_lod_[0].size() - 1),
                    padded_length_,
                    features_}) {
    x_ref_.Resize(lite::DDim(x_shape_));
    x_ref_.set_lod(x_lod_);
    x_gpu_.Resize(x_ref_.dims());

    pad_value_ref_.Resize(lite::DDim(pad_value_shape_));
    pad_value_gpu_.Resize(pad_value_ref_.dims());

    length_ref_.Resize(
        lite::DDim({static_cast<int64_t>(x_lod_[0].size() - 1)}));
    length_gpu_.Resize(length_ref_.dims());

    auto x_ref_data = x_ref_.mutable_data<float>();
    auto pad_value_ref_data = pad_value_ref_.mutable_data<float>();

    // prepare input
    for (int64_t i = 0; i < x_ref_.numel(); i++) {
      x_ref_data[i] = static_cast<float>(i);
    }
    for (int64_t i = 0; i < pad_value_ref_.numel(); i++) {
      pad_value_ref_data[i] = static_cast<float>(i);
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_gpu_.Resize(out_ref_.dims());
    out_cpu_.Resize(out_ref_.dims());

    RunBaseLine(&x_ref_, &pad_value_ref_, &out_ref_, &length_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.X = &x_gpu_;
    param_.PadValue = &pad_value_gpu_;
    param_.Length = &length_gpu_;
    param_.Out = &out_gpu_;
    param_.padded_length = padded_length_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
                                                    x_gpu_.dims());
    x_gpu_.set_lod(x_ref_.lod());
    pad_value_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(
        pad_value_ref_.data<float>(), pad_value_gpu_.dims());
  }

  void InitHalfInput() {}

  void RunBaseLine(const lite::Tensor* x,
                   const lite::Tensor* pad_value,
                   lite::Tensor* out,
                   lite::Tensor* length) {
    auto* length_data = length->mutable_data<int64_t>();
    auto* out_data = out->mutable_data<float>();
    length_data[0] = 2;
    length_data[1] = 3;

@@ -112,24 +113,24 @@ class SequencePadTest : public ::testing::Test {
    }
  }

  int batch_, features_, padded_length_;
  LoD x_lod_;
  std::vector<int64_t> x_shape_, pad_value_shape_, out_shape_;
  lite::Tensor x_ref_, pad_value_ref_, out_ref_, length_ref_;
  lite::Tensor x_gpu_, pad_value_gpu_, out_gpu_, length_gpu_;
  lite::Tensor out_cpu_, length_cpu_;
  operators::SequencePadParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(SequencePadTest, fp32) {
  InitFloatInput();
  SequencePadCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -147,20 +148,20 @@ TEST_F(SequencePadTest, fp32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);
  CopySync<TARGET(kCUDA)>(length_cpu_.mutable_data<int64_t>(),
                          length_gpu_.data<int64_t>(),
                          sizeof(int64_t) * length_gpu_.numel(),
                          IoDirection::DtoH);
  for (int i = 0; i < out_gpu_.numel(); ++i) {
    EXPECT_NEAR(out_cpu_.data<float>()[i], out_ref_.data<float>()[i], 1e-5);
  }
  for (int i = 0; i < length_gpu_.numel(); ++i) {
    EXPECT_NEAR(
        length_cpu_.data<int64_t>()[i], length_ref_.data<int64_t>()[i], 1e-5);
  }
}
...
@@ -23,7 +23,7 @@
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/utils/float16.h"

namespace paddle {
namespace lite {

@@ -33,66 +33,66 @@ namespace cuda {

class SequenceUnpadTest : public ::testing::Test {
 protected:
  SequenceUnpadTest()
      : batch_(5),
        features_(2),
        padded_length_(3),
        out_lod_({{0, 2, 5}}),
        x_shape_({static_cast<int64_t>(out_lod_[0].size() - 1),
                  padded_length_,
                  features_}),
        out_shape_({batch_, features_}) {
    x_ref_.Resize(lite::DDim(x_shape_));
    x_gpu_.Resize(x_ref_.dims());

    length_ref_.Resize(
        lite::DDim({static_cast<int64_t>(out_lod_[0].size() - 1)}));
    length_gpu_.Resize(length_ref_.dims());

    auto* x_ref_data = x_ref_.mutable_data<float>();
    auto* length_ref_data = length_ref_.mutable_data<int64_t>();

    // prepare input
    for (int64_t i = 0; i < x_ref_.numel(); i++) {
      x_ref_data[i] = static_cast<float>(i);
    }
    for (size_t i = 0; i < out_lod_[0].size() - 1; ++i) {
      length_ref_data[i] = out_lod_[0][i + 1] - out_lod_[0][i];
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_ref_.set_lod(out_lod_);
    out_gpu_.Resize(out_ref_.dims());
    out_gpu_.set_lod(out_ref_.lod());
    out_cpu_.Resize(out_ref_.dims());
    out_cpu_.set_lod(out_ref_.lod());

    RunBaseLine(&x_ref_, &length_ref_, &out_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.X = &x_gpu_;
    param_.Length = &length_gpu_;
    param_.Out = &out_gpu_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
                                                    x_gpu_.dims());
    length_gpu_.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(
        length_ref_.data<int64_t>(), length_gpu_.dims());
  }

  void InitHalfInput() {}

  void RunBaseLine(const lite::Tensor* X,
                   const lite::Tensor* Length,
                   lite::Tensor* Out) {
    auto* out_data = Out->mutable_data<float>();
    for (size_t i = 0; i < 4; ++i) {

@@ -103,24 +103,24 @@ class SequenceUnpadTest : public ::testing::Test {
    }
  }

  int batch_, features_, padded_length_;
  LoD out_lod_;
  std::vector<int64_t> x_shape_, out_shape_;
  lite::Tensor x_ref_, out_ref_, length_ref_;
  lite::Tensor x_gpu_, out_gpu_, length_gpu_;
  lite::Tensor out_cpu_, length_cpu_;
  operators::SequencePadParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(SequenceUnpadTest, fp32) {
  InitFloatInput();
  SequenceUnpadCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -138,12 +138,12 @@ TEST_F(SequenceUnpadTest, fp32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);
  for (int i = 0; i < out_gpu_.numel(); ++i) {
    EXPECT_NEAR(out_cpu_.data<float>()[i], out_ref_.data<float>()[i], 1e-5);
  }
}
...
@@ -43,7 +43,7 @@ void TransposeCompute<T, Ptype>::Run() {
  // NCHW -> NHWC
  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
      axes[3] == 1) {
    trans_.NCHW2NHWC(dims[0], dims[1], dims[2] * dims[3], in, out, &stream);
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
    return;

@@ -52,13 +52,13 @@ void TransposeCompute<T, Ptype>::Run() {
  // NHWC -> NCHW
  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
      axes[3] == 2) {
    trans_.NHWC2NCHW(dims[0], dims[3], dims[1] * dims[2], in, out, &stream);
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
    return;
  }

  trans_.transpose(out, in, dims, axes, &stream);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
...
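The two fast paths above fire only for the 4-D NCHW→NHWC and NHWC→NCHW axis patterns; any other permutation falls through to the generic `transpose`. A small host-side sketch of the NCHW→NHWC index mapping those paths implement (a standalone illustration under assumed dense layouts, not the CUDA kernel itself):

```cpp
#include <vector>

// Copies a dense NCHW buffer into NHWC order on the host.
void NchwToNhwcSketch(const std::vector<float>& in, std::vector<float>* out,
                      int n, int c, int h, int w) {
  out->resize(in.size());
  for (int b = 0; b < n; ++b)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          (*out)[((b * h + ih) * w + iw) * c + ic] =
              in[((b * c + ic) * h + ih) * w + iw];
}
```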
@@ -30,7 +30,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), Ptype> {
  virtual ~TransposeCompute() = default;

 private:
  lite::cuda::math::Transpose<Dtype> trans_;
};

}  // namespace cuda
...
@@ -36,9 +36,9 @@ namespace {
#define OUT(n, c, h, w)                                    \
  output_data[w + h * output_w + c * output_h * output_w + \
              n * output_c * output_h * output_w]
void Nchw2nhwcBaseLine(lite::Tensor* input,
                       lite::Tensor* output,
                       const std::vector<int> axies) {
  auto* input_data = input->data<float>();
  auto* output_data = output->mutable_data<float>();

@@ -69,9 +69,9 @@ void nchw2nhwc_ref(lite::Tensor* input,
#define OUT(n, h, w, c)                                    \
  output_data[c + w * output_c + h * output_w * output_c + \
              n * output_h * output_w * output_c]
void Nhwc2nchwBaseLine(lite::Tensor* input,
                       lite::Tensor* output,
                       const std::vector<int> axies) {
  auto* input_data = input->data<float>();
  auto* output_data = output->mutable_data<float>();

@@ -94,7 +94,7 @@ void nhwc2nchw_ref(lite::Tensor* input,
  }
}

void TransBaseLine(const lite::Tensor* input,
                   lite::Tensor* output,
                   const std::vector<int> axes) {
  auto* input_data = input->data<float>();

@@ -173,9 +173,9 @@ TEST(transpose_nchw, normal) {
  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);

  Nchw2nhwcBaseLine(&x_ref, &out_ref, axes);
  auto* out_ref_data = out_ref.mutable_data<float>();
  // TransBaseLine(&x_ref, &out_ref, axes);
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
  }

@@ -225,8 +225,8 @@ TEST(transpose_nhwc, normal) {
  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
  Nhwc2nchwBaseLine(&x_ref, &out_ref, axes);
  // TransBaseLine(&x_ref, &out_ref, axes);
  auto* out_ref_data = out_ref.mutable_data<float>();
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);

@@ -236,77 +236,77 @@ TEST(transpose_nhwc, normal) {

class TransposeTest : public ::testing::Test {
 protected:
  TransposeTest()
      : C_(3),
        H_(128),
        W_(64),
        axes_({1, 2, 0}),
        x_shape_({C_, H_, W_}),
        out_shape_({H_, W_, C_}) {
    x_ref_.Resize(lite::DDim(x_shape_));
    x_gpu_.Resize(x_ref_.dims());

    auto X_ref__data = x_ref_.mutable_data<float>();

    // prepare input
    for (int64_t i = 0; i < x_ref_.numel(); i++) {
      X_ref__data[i] = static_cast<float>(i);
    }

    out_ref_.Resize(lite::DDim(out_shape_));
    out_gpu_.Resize(out_ref_.dims());
    out_cpu_.Resize(out_ref_.dims());

    RunBaseLine(&x_ref_, &out_ref_);

    InitParamAndContext();
  }

  void InitParamAndContext() {
    ctx_.reset(new KernelContext);
    cudaStreamCreate(&stream_);
    auto& context = ctx_->As<CUDAContext>();
    context.SetExecStream(stream_);
    param_.x = &x_gpu_;
    param_.output = &out_gpu_;
    param_.axis = axes_;
  }

  void InitFloatInput() {
    x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
                                                    x_gpu_.dims());
  }

  void InitHalfInput() {
    x_half_.Resize(lite::DDim(x_ref_.dims()));
    auto X_half__data = x_half_.mutable_data<half>();
    for (int64_t i = 0; i < x_half_.numel(); i++) {
      X_half__data[i] = half(lite::float16(x_ref_.data<float>()[i]));
    }
    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(X_half__data, x_gpu_.dims());
  }

  void RunBaseLine(const lite::Tensor* x, lite::Tensor* out) {
    TransBaseLine(x, out, axes_);
  }

  int C_, H_, W_;
  std::vector<int> axes_;
  std::vector<int64_t> x_shape_, out_shape_;
  lite::Tensor x_ref_, out_ref_;
  lite::Tensor x_gpu_, out_gpu_;
  lite::Tensor x_half_;
  lite::Tensor out_cpu_;
  operators::TransposeParam param_;
  std::unique_ptr<KernelContext> ctx_;
  cudaStream_t stream_;
};

TEST_F(TransposeTest, fp32) {
  InitFloatInput();
  TransposeCompute<float, PRECISION(kFloat)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -324,20 +324,20 @@ TEST_F(TransposeTest, fp32) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
                          out_gpu_.data<float>(),
                          sizeof(float) * out_gpu_.numel(),
                          IoDirection::DtoH);
  for (int i = 0; i < out_gpu_.numel(); ++i) {
    EXPECT_NEAR(out_cpu_.data<float>()[i], out_ref_.data<float>()[i], 1e-5);
  }
}

TEST_F(TransposeTest, TestFP16) {
  InitHalfInput();
  TransposeCompute<half, PRECISION(kFP16)> kernel;
  kernel.SetParam(param_);
  kernel.SetContext(std::move(ctx_));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();

@@ -355,16 +355,16 @@ TEST_F(TransposeTest, TestFP16) {
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  const half* Out_gpu__data = out_gpu_.data<half>();
  half* Out_cpu__data = out_cpu_.mutable_data<half>();
  CopySync<TARGET(kCUDA)>(Out_cpu__data,
                          Out_gpu__data,
                          sizeof(half) * out_gpu_.numel(),
                          IoDirection::DtoH);
  for (int i = 0; i < out_cpu_.numel(); ++i) {
    float res = static_cast<float>(lite::float16(Out_cpu__data[i]));
    float ref = out_ref_.data<float>()[i];
    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
  }
}
...