未验证 提交 11cbd50e 编写于 作者: W Wilber 提交者: GitHub

update transpose cuda kernel. test=develop (#3879)

上级 f35f8dac
......@@ -174,24 +174,9 @@ void Transpose<T>::transpose(T* dst,
TransposeCUDAImpl<T>(src_dims, axes, src, dst, &Y_dims_, &strides_, stream);
}
// template <typename T>
// void Transpose<T>::transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream) {
// std::vector<int64_t> _src_dims(src_dims.size(), 0);
// std::transform(
// src_dims.begin(),
// src_dims.end(),
// _src_dims.begin(),
// [](int data) -> int64_t { return static_cast<int64_t>(data); });
// TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
// stream);
//}
template class Transpose<int8_t>;
template class Transpose<float>;
template class Transpose<half>;
} // namespace math
} // namespace cuda
......
......@@ -60,6 +60,8 @@ class AssignValueTest : public ::testing::Test {
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
param.shape = shape;
param.dtype = dtype;
param.fp32_values = fp32_values;
......@@ -113,8 +115,6 @@ class AssignValueTest : public ::testing::Test {
TEST_F(AssignValueTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
AssignValueCompute kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
......
......@@ -57,6 +57,8 @@ class SequenceMaskTest : public ::testing::Test {
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
param.X = &X_gpu;
param.Y = &Out_gpu;
param.maxlen = maxlen;
......@@ -94,8 +96,6 @@ class SequenceMaskTest : public ::testing::Test {
TEST_F(SequenceMaskTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequenceMaskCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
......
......@@ -74,6 +74,8 @@ class SequencePadTest : public ::testing::Test {
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
param.X = &X_gpu;
param.PadValue = &PadValue_gpu;
param.Length = &Length_gpu;
......@@ -125,8 +127,6 @@ class SequencePadTest : public ::testing::Test {
TEST_F(SequencePadTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequencePadCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
......
......@@ -74,6 +74,8 @@ class SequenceUnpadTest : public ::testing::Test {
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
param.X = &X_gpu;
param.Length = &Length_gpu;
param.Out = &Out_gpu;
......@@ -116,8 +118,6 @@ class SequenceUnpadTest : public ::testing::Test {
TEST_F(SequenceUnpadTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequenceUnpadCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
......
......@@ -13,17 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "lite/kernels/cuda/transpose_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/transpose_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
template <typename T, PrecisionType Ptype>
void TransposeCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
......@@ -31,8 +34,8 @@ void TransposeCompute::Run() {
lite::Tensor* Out = param.output;
std::vector<int> axes = param.axis;
const float* in = X->data<float>();
float* out = Out->mutable_data<float>(TARGET(kCUDA));
const T* in = X->template data<T>();
T* out = Out->mutable_data<T>(TARGET(kCUDA));
int ndim = X->dims().size();
std::vector<int64_t> dims = X->dims().data();
......@@ -65,34 +68,31 @@ void TransposeCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(transpose,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::TransposeCompute,
def)
using TransFp32 =
paddle::lite::kernels::cuda::TransposeCompute<float, PRECISION(kFloat)>;
using TransFp16 =
paddle::lite::kernels::cuda::TransposeCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(transpose, kCUDA, kFloat, kNCHW, TransFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::TransposeCompute,
def)
REGISTER_LITE_KERNEL(transpose2, kCUDA, kFloat, kNCHW, TransFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// REGISTER_LITE_KERNEL(transpose2,
// kCUDA,
// kFloat,
// kNCHW,
// paddle::lite::kernels::cuda::TransposeCompute,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .Finalize();
// FP16 registrations: both `transpose` and `transpose2` are bound to the
// TransFp16 kernel alias, with every tensor declared as a kCUDA tensor of
// kFP16 precision. `transpose2` additionally declares the "XShape" output.
REGISTER_LITE_KERNEL(transpose, kCUDA, kFP16, kNCHW, TransFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2, kCUDA, kFP16, kNCHW, TransFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
......@@ -21,7 +21,8 @@ namespace lite {
namespace kernels {
namespace cuda {
class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
template <typename Dtype, PrecisionType Ptype>
class TransposeCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::TransposeParam;
......@@ -29,7 +30,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual ~TransposeCompute() = default;
private:
lite::cuda::math::Transpose<float> trans;
lite::cuda::math::Transpose<Dtype> trans;
};
} // namespace cuda
......
......@@ -13,11 +13,16 @@
// limitations under the License.
#include "lite/kernels/cuda/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -89,7 +94,7 @@ void nhwc2nchw_ref(lite::Tensor* input,
}
}
void transpose_ref(lite::Tensor* input,
void transpose_ref(const lite::Tensor* input,
lite::Tensor* output,
const std::vector<int> axes) {
auto* input_data = input->data<float>();
......@@ -123,7 +128,7 @@ void transpose_ref(lite::Tensor* input,
} // namespace
TEST(transpose_nchw, normal) {
TransposeCompute transpose_kernel;
TransposeCompute<float, PRECISION(kFloat)> transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
......@@ -177,7 +182,7 @@ TEST(transpose_nchw, normal) {
}
TEST(transpose_nhwc, normal) {
TransposeCompute transpose_kernel;
TransposeCompute<float, PRECISION(kFloat)> transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
......@@ -228,54 +233,139 @@ TEST(transpose_nhwc, normal) {
}
}
TEST(transpose, normal) {
TransposeCompute transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
class TransposeTest : public ::testing::Test {
protected:
TransposeTest()
: C(3),
H(128),
W(64),
axes({1, 2, 0}),
x_shape({C, H, W}),
out_shape({H, W, C}) {
X_ref.Resize(lite::DDim(x_shape));
X_gpu.Resize(X_ref.dims());
auto x_ref_data = X_ref.mutable_data<float>();
// prepare input
for (int64_t i = 0; i < X_ref.numel(); i++) {
x_ref_data[i] = static_cast<float>(i);
}
operators::TransposeParam param;
Out_ref.Resize(lite::DDim(out_shape));
Out_gpu.Resize(Out_ref.dims());
Out_cpu.Resize(Out_ref.dims());
cpu_base(&X_ref, &Out_ref);
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref;
int C = 3, H = 128, W = 128;
std::vector<int> axes({2, 0, 1});
x.Resize({C, H, W});
out.Resize({W, C, H});
device_init();
}
x_cpu.Resize({C, H, W});
out_cpu.Resize({W, C, H});
// Creates a fresh KernelContext, binds a newly created CUDA stream to it as
// the execution stream, and points the transpose parameters at the fixture's
// device tensors and axis permutation.
void device_init() {
  ctx.reset(new KernelContext);
  cudaStreamCreate(&stream);
  ctx->As<CUDAContext>().SetExecStream(stream);

  // Parameter wiring is independent of the context setup above.
  param.axis = axes;
  param.x = &X_gpu;
  param.output = &Out_gpu;
}
x_ref.Resize({C, H, W});
out_ref.Resize({W, C, H});
void float_data_init() {
X_gpu.Assign<float, lite::DDim, TARGET(kCUDA)>(X_ref.data<float>(),
X_gpu.dims());
}
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
// Converts the FP32 reference input to half element by element into X_half,
// then assigns that half buffer to X_gpu using the kCUDA target.
void half_data_init() {
  X_half.Resize(lite::DDim(X_ref.dims()));

  const float* fp32_src = X_ref.data<float>();
  auto fp16_dst = X_half.mutable_data<half>();
  const int64_t count = X_half.numel();

  for (int64_t idx = 0; idx < count; ++idx) {
    // Go through lite::float16 for the float -> half conversion on the host.
    fp16_dst[idx] = half(lite::float16(fp32_src[idx]));
  }

  X_gpu.Assign<half, lite::DDim, TARGET(kCUDA)>(fp16_dst, X_gpu.dims());
}
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 1;
x_ref_data[i] = i + 1;
// Computes the reference result by forwarding X and Out to transpose_ref
// together with the fixture's `axes` permutation.
void cpu_base(const lite::Tensor* X, lite::Tensor* Out) {
transpose_ref(X, Out, axes);
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.x = &x;
param.output = &out;
param.axis = axes;
transpose_kernel.SetParam(param);
int C, H, W;
std::vector<int> axes;
std::vector<int64_t> x_shape, out_shape;
lite::Tensor X_ref, Out_ref;
lite::Tensor X_gpu, Out_gpu;
lite::Tensor X_half;
lite::Tensor Out_cpu;
operators::TransposeParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
transpose_kernel.SetContext(std::move(ctx));
transpose_kernel.Launch();
};
TEST_F(TransposeTest, fp32) {
float_data_init();
TransposeCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
transpose_ref(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
Out_gpu.data<float>(),
sizeof(float) * Out_gpu.numel(),
IoDirection::DtoH);
for (int i = 0; i < Out_gpu.numel(); ++i) {
EXPECT_NEAR(Out_cpu.data<float>()[i], Out_ref.data<float>()[i], 1e-5);
}
}
// Runs the FP16 transpose kernel on the fixture input and compares the result,
// converted back to float, against the FP32 reference (Out_ref) using a
// relative-error tolerance of 1e-2.
TEST_F(TransposeTest, TestFP16) {
  half_data_init();
  TransposeCompute<half, PRECISION(kFP16)> kernel;
  kernel.SetParam(param);
  kernel.SetContext(std::move(ctx));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();
    cudaDeviceSynchronize();
  }

  auto start = GetCurrentUS();
  kernel.PrepareForRun();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    kernel.Run();
  }
  // Check the synchronize result so an asynchronous kernel failure is
  // reported here rather than showing up as garbage in the comparison below.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  auto duration = (GetCurrentUS() - start) / 1000.0;
  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  const half* out_gpu_data = Out_gpu.data<half>();
  half* out_cpu_data = Out_cpu.mutable_data<half>();
  CopySync<TARGET(kCUDA)>(out_cpu_data,
                          out_gpu_data,
                          sizeof(half) * Out_gpu.numel(),
                          IoDirection::DtoH);

  const float* ref_data = Out_ref.data<float>();
  for (int64_t i = 0; i < Out_cpu.numel(); ++i) {
    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
    float ref = ref_data[i];
    // Normalize by |ref| so the denominator stays positive and the relative
    // error remains meaningful for negative or near-zero reference values.
    EXPECT_NEAR(fabs(res - ref) / (fabs(ref) + 1e-5), 0., 1e-2);
  }
}
......
......@@ -43,24 +43,9 @@ bool TransposeOp::CheckShape() const {
}
bool TransposeOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
std::vector<int> axis = param_.axis;
size_t axis_size = axis.size();
// "The input tensor's rank(%d) should be equal to the axis's size(%d)",
// x_rank, axis_size
CHECK_OR_FALSE(x_rank == axis_size);
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
// Each element of Attribute axis should be a unique value
// range from 0 to (dims - 1),
// where the dims is the axis's size
CHECK_OR_FALSE(axis[i] < static_cast<int>(axis_size) &&
++count[axis[i]] == 1);
}
lite::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]];
......@@ -113,24 +98,9 @@ bool Transpose2Op::CheckShape() const {
}
bool Transpose2Op::InferShapeImpl() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
std::vector<int> axis = param_.axis;
size_t axis_size = axis.size();
// "The input tensor's rank(%d) should be equal to the axis's size(%d)",
// x_rank, axis_size
CHECK_OR_FALSE(x_rank == axis_size);
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
// Each element of Attribute axis should be a unique value
// range from 0 to (dims - 1),
// where the dims is the axis's size
CHECK_OR_FALSE(axis[i] < static_cast<int>(axis_size) &&
++count[axis[i]] == 1);
}
lite::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册