未验证 提交 bc66d2be 编写于 作者: W Wilber 提交者: GitHub

[CUDA] [Kernel] Add sequence_mask kernel. (#3868)

上级 973cca29
...@@ -37,6 +37,7 @@ add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_comput ...@@ -37,6 +37,7 @@ add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_comput
add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_mask_compute_cuda CUDA extra SRCS sequence_mask_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps}) add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
...@@ -80,6 +81,7 @@ if(LITE_BUILD_EXTRA) ...@@ -80,6 +81,7 @@ if(LITE_BUILD_EXTRA)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda) nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda) nv_test(sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda)
nv_test(sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda) nv_test(sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda)
nv_test(sequence_mask_compute_cuda_test SRCS sequence_mask_compute_test.cc DEPS sequence_mask_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda) nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) #nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda) #nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_mask_compute.h"
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void SequenceMaskKernel(T* dst,
const int64_t* src,
int count,
int maxlen) {
CUDA_KERNEL_LOOP(index, count) {
int src_idx = index / maxlen;
int inner_idx = index % maxlen;
dst[index] = static_cast<T>(inner_idx < src[src_idx] ? 1 : 0);
}
}
template <typename T, PrecisionType Ptype>
void SequenceMaskCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const auto* x = param.X;
auto* x_data = x->template data<int64_t>();
auto* y = param.Y;
int maxlen = param.maxlen;
if (param.MaxLenTensor) {
auto* len_tensor_data = param.MaxLenTensor->template data<int32_t>();
int32_t len_data{0};
TargetWrapperCuda::MemcpySync(
&len_data, len_tensor_data, sizeof(int32_t), IoDirection::DtoH);
maxlen = len_data;
}
if (maxlen < 0) {
maxlen =
thrust::reduce(x_data, x_data + x->numel(), 0, thrust::maximum<T>());
}
auto y_dim = x->dims().Vectorize();
y_dim.push_back(maxlen);
y->Resize(y_dim);
const int count = y->numel();
auto* dst_data = y->template mutable_data<float>(TARGET(kCUDA));
if (param.out_dtype == 5) {
SequenceMaskKernel<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
dst_data, x_data, count, maxlen);
} else {
LOG(FATAL) << "not supported out_dtype: " << param.out_dtype;
}
CUDA_POST_KERNEL_CHECK;
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SeqMaskFp32 =
paddle::lite::kernels::cuda::SequenceMaskCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFloat, kNCHW, SeqMaskFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
.BindInput("MaxLenTensor", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SequenceMaskCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::SequenceMaskParam;
void Run() override;
virtual ~SequenceMaskCompute() = default;
// private:
// lite::Tensor seq_offsets_;
// std::vector<int64_t> seq_len_;
// std::vector<size_t> seq_offsets_vec_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/kernels/cuda/sequence_mask_compute.h"
// #include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequenceMaskTest : public ::testing::Test {
protected:
SequenceMaskTest()
: maxlen(4),
out_dtype(5),
x_data({3, 2, 1, 0}),
out_shape({static_cast<int64_t>(x_data.size()), maxlen}) {
X_ref.Resize(lite::DDim({static_cast<int64_t>(x_data.size())}));
X_gpu.Resize(X_ref.dims());
auto* x_ref_data = X_ref.mutable_data<int64_t>();
// prepare input
for (size_t i = 0; i < x_data.size(); i++) {
x_ref_data[i] = x_data[i];
}
Out_ref.Resize(lite::DDim(out_shape));
Out_gpu.Resize(Out_ref.dims());
Out_cpu.Resize(Out_ref.dims());
cpu_base(&X_ref, &Out_ref);
device_init();
}
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
param.X = &X_gpu;
param.Y = &Out_gpu;
param.maxlen = maxlen;
param.out_dtype = out_dtype;
}
void float_data_init() {
X_gpu.Assign<int64_t, lite::DDim, TARGET(kCUDA)>(X_ref.data<int64_t>(),
X_gpu.dims());
}
void half_data_init() {}
void cpu_base(const lite::Tensor* X, lite::Tensor* Out) {
auto* out_data = Out->mutable_data<float>();
for (size_t i = 0; i < x_data.size(); ++i) {
for (int j = 0; j < maxlen; ++j) {
out_data[i * maxlen + j] = j < x_data[i] ? 1 : 0;
}
}
}
int maxlen, out_dtype;
std::vector<int64_t> x_data, out_shape;
lite::Tensor X_ref, Out_ref;
lite::Tensor X_gpu, Out_gpu;
lite::Tensor Out_cpu;
operators::SequenceMaskParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
};
TEST_F(SequenceMaskTest, fp32) {
float_data_init();
auto& context = ctx->As<CUDAContext>();
context.SetExecStream(stream);
SequenceMaskCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param);
kernel.SetContext(std::move(ctx));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
Out_gpu.data<float>(),
sizeof(float) * Out_gpu.numel(),
IoDirection::DtoH);
for (int i = 0; i < Out_gpu.numel(); ++i) {
EXPECT_NEAR(Out_cpu.data<float>()[i], Out_ref.data<float>()[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -78,6 +78,7 @@ add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS}) ...@@ -78,6 +78,7 @@ add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS})
add_operator(sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS ${op_DEPS}) add_operator(sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS ${op_DEPS})
add_operator(sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS ${op_DEPS}) add_operator(sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS ${op_DEPS})
add_operator(sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS ${op_DEPS}) add_operator(sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS ${op_DEPS})
add_operator(sequence_mask_op_lite extra SRCS sequence_mask_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op extra SRCS im2sequence_op.cc DEPS ${op_DEPS}) add_operator(im2sequence_op extra SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(gather_op extra SRCS gather_op.cc DEPS ${op_DEPS}) add_operator(gather_op extra SRCS gather_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS}) add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS})
......
...@@ -1045,6 +1045,14 @@ struct SequenceUnpadParam : ParamBase { ...@@ -1045,6 +1045,14 @@ struct SequenceUnpadParam : ParamBase {
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct SequenceMaskParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* MaxLenTensor{nullptr};
lite::Tensor* Y{};
int maxlen{-1};
int out_dtype;
};
struct SequenceExpandAsParam : ParamBase { struct SequenceExpandAsParam : ParamBase {
const lite::Tensor* x{nullptr}; const lite::Tensor* x{nullptr};
const lite::Tensor* y{nullptr}; const lite::Tensor* y{nullptr};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_mask_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceMaskOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
return true;
}
bool SequenceMaskOp::InferShapeImpl() const { return true; }
bool SequenceMaskOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
if (opdesc.HasInput("MaxLenTensor") &&
!opdesc.Input("MaxLenTensor").empty()) {
auto var = scope->FindVar(opdesc.Input("MaxLenTensor").front());
if (var != nullptr) {
param_.MaxLenTensor = var->GetMutable<lite::Tensor>();
}
}
param_.Y =
scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
param_.maxlen = opdesc.GetAttr<int>("maxlen");
param_.out_dtype = opdesc.GetAttr<int>("out_dtype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_mask, paddle::lite::operators::SequenceMaskOp);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceMaskOp : public OpLite {
public:
SequenceMaskOp() {}
explicit SequenceMaskOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_mask"; }
private:
mutable SequenceMaskParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册