Unverified commit af514209, authored by Wilber, committed by GitHub

[CUDA] [Kernel] Add cuda fp16 kernel (#3903)

Parent: 62c6d5d5
@@ -134,6 +134,16 @@ template void SequencePadding(float* pad_data,
                               int step_width,
                               cudaStream_t* stream);
+template void SequencePadding(half* pad_data,
+                              const half* seq_data,
+                              const half* pad_value_data,
+                              bool is_constant_pad,
+                              const size_t* seq_offsets_data,
+                              int seq_num,
+                              int pad_seq_len,
+                              int step_width,
+                              cudaStream_t* stream);
 template void SequenceUnpadding(float* seq_data,
                                 const float* pad_data,
                                 const size_t* seq_offsets_data,
@@ -142,6 +152,14 @@ template void SequenceUnpadding(float* seq_data,
                                 int step_width,
                                 cudaStream_t* stream);
+template void SequenceUnpadding(half* seq_data,
+                                const half* pad_data,
+                                const size_t* seq_offsets_data,
+                                int seq_num,
+                                int pad_seq_len,
+                                int step_width,
+                                cudaStream_t* stream);
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
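A minimal sketch, not part of the patch, of how the explicit half instantiations above get used: SequencePadding is a template defined in the .cu file, so every element type a caller needs (float, and now half) must be instantiated there. The declaration, the namespace path (assumed from the usual Paddle-Lite layout), and the call site below are illustrative assumptions.

// Hedged sketch: declaration mirrors the instantiations above.
#include <cstddef>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
void SequencePadding(T* pad_data,
                     const T* seq_data,
                     const T* pad_value_data,
                     bool is_constant_pad,
                     const size_t* seq_offsets_data,
                     int seq_num,
                     int pad_seq_len,
                     int step_width,
                     cudaStream_t* stream);
}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle

// Hypothetical call site: pad an FP16 LoD batch with a constant pad value.
// It links only because the explicit half instantiation above exists.
void PadHalfSequences(half* pad_data,
                      const half* seq_data,
                      const half* pad_value_data,
                      const size_t* seq_offsets_data,
                      int seq_num,
                      int pad_seq_len,
                      int step_width,
                      cudaStream_t* stream) {
  paddle::lite::cuda::math::SequencePadding(pad_data,
                                            seq_data,
                                            pad_value_data,
                                            /*is_constant_pad=*/true,
                                            seq_offsets_data,
                                            seq_num,
                                            pad_seq_len,
                                            step_width,
                                            stream);
}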
@@ -57,18 +57,18 @@ void SequenceMaskCompute<T, Ptype>::Run() {
   }
   if (maxlen < 0) {
-    maxlen =
-        thrust::reduce(x_data, x_data + x->numel(), 0, thrust::maximum<T>());
+    maxlen = thrust::reduce(
+        x_data, x_data + x->numel(), 0, thrust::maximum<int64_t>());
   }
   auto y_dim = x->dims().Vectorize();
   y_dim.push_back(maxlen);
   y->Resize(y_dim);
   const int count = y->numel();
-  auto* dst_data = y->template mutable_data<float>(TARGET(kCUDA));
+  auto* dst_data = y->template mutable_data<T>(TARGET(kCUDA));
   if (param.out_dtype == 5) {
     SequenceMaskKernel<
-        float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
+        T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
         dst_data, x_data, count, maxlen);
   } else {
     LOG(FATAL) << "not supported out_dtype: " << param.out_dtype;
@@ -84,8 +84,19 @@ void SequenceMaskCompute<T, Ptype>::Run() {
 using SeqMaskFp32 =
     paddle::lite::kernels::cuda::SequenceMaskCompute<float, PRECISION(kFloat)>;
+using SeqMaskFp16 =
+    paddle::lite::kernels::cuda::SequenceMaskCompute<half, PRECISION(kFP16)>;
 
 REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFloat, kNCHW, SeqMaskFp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
-    .BindInput("MaxLenTensor", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("MaxLenTensor",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
     .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFP16, kNCHW, SeqMaskFp16, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .BindInput("MaxLenTensor",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .Finalize();
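The FP16 registration above only takes effect when graph optimization is allowed to place ops on CUDA FP16. A minimal usage sketch, assuming the standard Paddle-Lite CxxConfig/Place API (the function name, model path, and place ordering below are assumptions, not part of this patch):

#include <memory>
#include <string>

#include "lite/api/paddle_api.h"
#include "lite/api/paddle_place.h"

std::shared_ptr<paddle::lite_api::PaddlePredictor> BuildCudaFp16Predictor(
    const std::string& model_dir) {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir(model_dir);
  // Listing the CUDA FP16 place first lets the kernel-picking pass prefer the
  // newly registered SeqMaskFp16 / SeqPadFp16 / SeqUnadFp16 kernels; FP32 and
  // host places remain as fallbacks.
  config.set_valid_places({
      paddle::lite_api::Place{TARGET(kCUDA), PRECISION(kFP16)},
      paddle::lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)},
      paddle::lite_api::Place{TARGET(kHost), PRECISION(kFloat)},
  });
  return paddle::lite_api::CreatePaddlePredictor(config);
}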
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "lite/kernels/cuda/sequence_mask_compute.h"
+
 #include <gtest/gtest.h>
 #include <memory>
@@ -21,8 +23,7 @@
 #include "lite/api/test_helper.h"
 #include "lite/backends/cuda/cuda_utils.h"
-#include "lite/kernels/cuda/sequence_mask_compute.h"
-// #include "lite/utils/float16.h"
+#include "lite/utils/float16.h"
 
 namespace paddle {
 namespace lite {
@@ -70,7 +71,10 @@ class SequenceMaskTest : public ::testing::Test {
                                                       x_gpu_.dims());
   }
-  void InitHalfInput() {}
+  void InitHalfInput() {
+    x_gpu_.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(x_ref_.data<int64_t>(),
+                                                      x_gpu_.dims());
+  }
 
   void RunBaseLine(const lite::Tensor* x, lite::Tensor* out) {
     auto* out_data = out->mutable_data<float>();
@@ -125,6 +129,41 @@ TEST_F(SequenceMaskTest, fp32) {
   }
 }
+TEST_F(SequenceMaskTest, TestFP16) {
+  InitHalfInput();
+  SequenceMaskCompute<half, PRECISION(kFP16)> kernel;
+  kernel.SetParam(param_);
+  kernel.SetContext(std::move(ctx_));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    kernel.Launch();
+    cudaDeviceSynchronize();
+  }
+
+  auto start = GetCurrentUS();
+  kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    kernel.Run();
+  }
+  cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
+
+  const half* out_gpu_data = out_gpu_.data<half>();
+  half* out_cpu_data = out_cpu_.mutable_data<half>();
+  CopySync<TARGET(kCUDA)>(out_cpu_data,
+                          out_gpu_data,
+                          sizeof(half) * out_gpu_.numel(),
+                          IoDirection::DtoH);
+
+  for (int i = 0; i < out_gpu_.numel(); ++i) {
+    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
+    float ref = out_ref_.data<float>()[i];
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
+  }
+}
 
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
@@ -85,9 +85,22 @@ void SequencePadCompute<T, Ptype>::Run() {
 using SeqPadFp32 =
     paddle::lite::kernels::cuda::SequencePadCompute<float, PRECISION(kFloat)>;
+using SeqPadFp16 =
+    paddle::lite::kernels::cuda::SequencePadCompute<half, PRECISION(kFP16)>;
 
 REGISTER_LITE_KERNEL(sequence_pad, kCUDA, kFloat, kNCHW, SeqPadFp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("PadValue", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindOutput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Length",
+                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(sequence_pad, kCUDA, kFP16, kNCHW, SeqPadFp16, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindInput("PadValue",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindOutput("Length",
+                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .Finalize();
@@ -52,11 +52,11 @@ class SequencePadTest : public ::testing::Test {
     length_ref_.Resize(
         lite::DDim({static_cast<int64_t>(x_lod_[0].size() - 1)}));
     length_gpu_.Resize(length_ref_.dims());
+    length_cpu_.Resize(length_ref_.dims());
 
     auto x_ref_data = x_ref_.mutable_data<float>();
     auto pad_value_ref_data = pad_value_ref_.mutable_data<float>();
 
-    // prepare input
     for (int64_t i = 0; i < x_ref_.numel(); i++) {
       x_ref_data[i] = static_cast<float>(i);
     }
@@ -92,7 +92,23 @@ class SequencePadTest : public ::testing::Test {
         pad_value_ref_.data<float>(), pad_value_gpu_.dims());
   }
-  void InitHalfInput() {}
+  void InitHalfInput() {
+    x_half_.Resize(lite::DDim(x_shape_));
+    auto x_half_data = x_half_.mutable_data<half>();
+    for (int64_t i = 0; i < x_half_.numel(); i++) {
+      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
+    }
+    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
+    x_gpu_.set_lod(x_ref_.lod());
+    pad_value_half_.Resize(pad_value_ref_.dims());
+    auto pad_value_half_data = pad_value_half_.mutable_data<half>();
+    for (int64_t i = 0; i < pad_value_half_.numel(); i++) {
+      pad_value_half_data[i] =
+          half(lite::float16(pad_value_ref_.data<float>()[i]));
+    }
+    pad_value_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(
+        pad_value_half_data, pad_value_gpu_.dims());
+  }
 
   void RunBaseLine(const lite::Tensor* x,
                    const lite::Tensor* pad_value,
@@ -119,6 +135,7 @@ class SequencePadTest : public ::testing::Test {
   lite::Tensor x_ref_, pad_value_ref_, out_ref_, length_ref_;
   lite::Tensor x_gpu_, pad_value_gpu_, out_gpu_, length_gpu_;
+  lite::Tensor x_half_, pad_value_half_;
   lite::Tensor out_cpu_, length_cpu_;
 
   operators::SequencePadParam param_;
@@ -165,6 +182,51 @@ TEST_F(SequencePadTest, fp32) {
   }
 }
+TEST_F(SequencePadTest, TestFP16) {
+  InitHalfInput();
+  SequencePadCompute<half, PRECISION(kFP16)> kernel;
+  kernel.SetParam(param_);
+  kernel.SetContext(std::move(ctx_));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    kernel.Launch();
+    cudaDeviceSynchronize();
+  }
+
+  auto start = GetCurrentUS();
+  kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    kernel.Run();
+  }
+  cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
+
+  const half* out_gpu_data = out_gpu_.data<half>();
+  half* out_cpu_data = out_cpu_.mutable_data<half>();
+  const int64_t* length_gpu_data = length_gpu_.data<int64_t>();
+  int64_t* length_cpu_data = length_cpu_.mutable_data<int64_t>();
+  CopySync<TARGET(kCUDA)>(out_cpu_data,
+                          out_gpu_data,
+                          sizeof(half) * out_gpu_.numel(),
+                          IoDirection::DtoH);
+  CopySync<TARGET(kCUDA)>(length_cpu_data,
+                          length_gpu_data,
+                          sizeof(int64_t) * length_gpu_.numel(),
+                          IoDirection::DtoH);
+
+  for (int i = 0; i < out_gpu_.numel(); ++i) {
+    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
+    float ref = out_ref_.data<float>()[i];
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
+  }
+  for (int i = 0; i < length_gpu_.numel(); ++i) {
+    EXPECT_NEAR(
+        length_cpu_.data<int64_t>()[i], length_ref_.data<int64_t>()[i], 1e-5);
+  }
+}
 
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite
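The FP16 tests above all prepare inputs the same way: convert the fp32 reference buffer element-by-element through lite::float16, store it as CUDA half, and Assign it to the GPU tensor; results come back through the reverse cast before being compared against the fp32 baseline. A small stand-alone sketch of that pattern follows; the helper names and free-function form are assumptions, only lite::float16 and half come from the diff.

#include <cuda_fp16.h>

#include <cstdint>

#include "lite/utils/float16.h"

// fp32 -> fp16: lite::float16 does the rounding; half shares its bit layout.
inline void FloatBufferToHalf(const float* src, half* dst, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    dst[i] = half(paddle::lite::float16(src[i]));
  }
}

// fp16 -> fp32, as used when checking GPU results against the fp32 baseline.
inline void HalfBufferToFloat(const half* src, float* dst, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    dst[i] = static_cast<float>(paddle::lite::float16(src[i]));
  }
}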
@@ -74,8 +74,19 @@ void SequenceUnpadCompute<T, Ptype>::Run() {
 using SeqUnadFp32 =
     paddle::lite::kernels::cuda::SequenceUnpadCompute<float, PRECISION(kFloat)>;
+using SeqUnadFp16 =
+    paddle::lite::kernels::cuda::SequenceUnpadCompute<half, PRECISION(kFP16)>;
 
 REGISTER_LITE_KERNEL(sequence_unpad, kCUDA, kFloat, kNCHW, SeqUnadFp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindInput("Length", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Length",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(sequence_unpad, kCUDA, kFP16, kNCHW, SeqUnadFp16, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindInput("Length",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .Finalize();
@@ -88,7 +88,16 @@ class SequenceUnpadTest : public ::testing::Test {
         length_ref_.data<int64_t>(), length_gpu_.dims());
   }
-  void InitHalfInput() {}
+  void InitHalfInput() {
+    x_half_.Resize(lite::DDim(x_shape_));
+    auto x_half_data = x_half_.mutable_data<half>();
+    for (int64_t i = 0; i < x_half_.numel(); i++) {
+      x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
+    }
+    x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
+    length_gpu_.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(
+        length_ref_.data<int64_t>(), length_gpu_.dims());
+  }
 
   void RunBaseLine(const lite::Tensor* X,
                    const lite::Tensor* Length,
@@ -109,6 +118,7 @@ class SequenceUnpadTest : public ::testing::Test {
   lite::Tensor x_ref_, out_ref_, length_ref_;
   lite::Tensor x_gpu_, out_gpu_, length_gpu_;
+  lite::Tensor x_half_;
   lite::Tensor out_cpu_, length_cpu_;
 
   operators::SequencePadParam param_;
@@ -147,6 +157,41 @@ TEST_F(SequenceUnpadTest, fp32) {
   }
 }
+TEST_F(SequenceUnpadTest, TestFP16) {
+  InitHalfInput();
+  SequenceUnpadCompute<half, PRECISION(kFP16)> kernel;
+  kernel.SetParam(param_);
+  kernel.SetContext(std::move(ctx_));
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    kernel.Launch();
+    cudaDeviceSynchronize();
+  }
+
+  auto start = GetCurrentUS();
+  kernel.PrepareForRun();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    kernel.Run();
+  }
+  cudaDeviceSynchronize();
+  auto duration = (GetCurrentUS() - start) / 1000.0;
+  LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << duration / FLAGS_repeats << " ms in average.";
+
+  const half* out_gpu_data = out_gpu_.data<half>();
+  half* out_cpu_data = out_cpu_.mutable_data<half>();
+  CopySync<TARGET(kCUDA)>(out_cpu_data,
+                          out_gpu_data,
+                          sizeof(half) * out_gpu_.numel(),
+                          IoDirection::DtoH);
+
+  for (int i = 0; i < out_gpu_.numel(); ++i) {
+    float res = static_cast<float>(lite::float16(out_cpu_data[i]));
+    float ref = out_ref_.data<float>()[i];
+    EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2);
+  }
+}
 
 }  // namespace cuda
 }  // namespace kernels
 }  // namespace lite