提交 4048c261 编写于 作者: Z zhupengyang 提交者: GitHub

[X86][CUDA] add attention_padding_mask op, x86 kernel, cuda kernel and unit tests (#2437)

* [X86] add attention_padding_mask op, x86 kernel and unit test

test=develop

* [CUDA] add attention_padding_mask cuda kernel and unit test

test=develop
上级 4db40afc
......@@ -28,6 +28,7 @@ add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_comput
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
......@@ -49,6 +50,7 @@ nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc D
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/attention_padding_mask_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
#define CUDA_NUM_THREADS 256
inline int CUDA_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void ker_attention_padding_mask(T* out_data,
const T* attn_data,
const int* src_offset,
const int attn_seq_num,
const int attn_seq_len,
const int src_seq_num,
const int src_seq_len,
const T mask,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
int src_word_id = tid % src_seq_len;
int tmp_tid = tid / src_seq_len;
int attn_seq_id = tmp_tid / attn_seq_len;
int attn_word_id = tmp_tid % attn_seq_len;
int src_seq_id = attn_seq_id % src_seq_num;
int cur_len = src_offset[src_seq_id + 1] - src_offset[src_seq_id];
if (src_word_id >= cur_len) {
out_data[tid] = mask;
} else {
out_data[tid] = attn_data[tid];
}
}
}
void AttentionPaddingMaskCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto attn = param.X;
auto src = param.Y;
const int count = attn->numel();
auto attn_offset = attn->lod()[0];
auto src_offset = src->lod()[0];
const int attn_seq_num = attn_offset.size() - 1;
const int attn_seq_len = attn_offset[1];
const int src_seq_num = src_offset.size() - 1;
const int src_seq_len = count / attn->dims()[0];
auto out = param.Out;
out->Resize(attn->dims());
out->set_lod(attn->lod());
auto attn_data = attn->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
std::vector<int> src_offset_cpu(src_offset.size(), 0);
for (int i = 0; i < src_offset.size(); i++) {
src_offset_cpu[i] = src_offset[i];
}
src_offset_cuda.Resize({static_cast<int64_t>(src_offset.size())});
auto src_offset_cuda_data = src_offset_cuda.mutable_data<int>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(src_offset_cuda_data,
src_offset_cpu.data(),
sizeof(int) * src_offset.size(),
IoDirection::HtoD,
stream);
ker_attention_padding_mask<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
out_data,
attn_data,
src_offset_cuda_data,
attn_seq_num,
attn_seq_len,
src_seq_num,
src_seq_len,
param.mask,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(attention_padding_mask,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::AttentionPaddingMaskCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("pad_begin", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class AttentionPaddingMaskCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::AttentionPaddingMaskParam;
void Run() override;
virtual ~AttentionPaddingMaskCompute() = default;
private:
lite::Tensor src_offset_cuda;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/attention_padding_mask_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void attention_padding_mask_ref(
const Tensor& x,
const Tensor& y,
Tensor* out,
Tensor* pad_begin,
const operators::AttentionPaddingMaskParam& param) {
auto attn_offset = x.lod()[0];
auto src_offset = y.lod()[0];
int attn_seq_num = attn_offset.size() - 1;
int src_seq_num = src_offset.size() - 1;
int attn_seq_len = attn_offset[1];
int src_seq_len = x.dims()[1];
CHECK_EQ(attn_seq_num % src_seq_num, 0);
auto count = x.numel();
auto attn_data = x.data<float>();
out->Resize(x.dims());
out->set_lod(x.lod());
auto out_data = out->mutable_data<float>();
memcpy(out_data, attn_data, count * sizeof(float));
for (int i = 0; i < attn_seq_num; ++i) {
for (int j = 0; j < attn_seq_len; ++j) {
auto tmp_out_data = out_data + src_seq_len * (attn_seq_len * i + j);
int src_seq_idx = i % src_seq_num;
int cur_len = src_offset[src_seq_idx + 1] - src_offset[src_seq_idx];
for (int k = cur_len; k < src_seq_len; k++) {
tmp_out_data[k] = param.mask;
}
}
}
}
void prepare_input(Tensor* x, const LoD& lod, int64_t dim2rd) {
std::vector<int64_t> x_dims{static_cast<int64_t>(lod[0].back()), dim2rd};
x->Resize(x_dims);
x->set_lod(lod);
auto x_data = x->mutable_data<float>();
auto x_num = x->numel();
for (int i = 0; i < x_num; i++) {
x_data[i] = (i - x_num) * 1.1;
}
}
int get_max_len(const LoD& lod) {
int max_len = 0;
auto offset = lod[0];
for (int i = 0; i < offset.size() - 1; i++) {
int cur_len = offset[i + 1] - offset[i];
max_len = max_len < cur_len ? cur_len : max_len;
}
return max_len;
}
TEST(attention_padding_mask_cuda, run_test) {
lite::Tensor x, y, x_cpu, y_cpu;
lite::Tensor out, pad_begin, out_cpu, out_ref, pad_begin_ref;
LoD x_lod{{0, 3, 6, 9, 12}}, y_lod{{0, 4, 6}};
prepare_input(&x_cpu, x_lod, get_max_len(y_lod));
prepare_input(&y_cpu, y_lod, 1);
x.Resize(x_cpu.dims());
x.set_lod(x_cpu.lod());
auto x_cpu_data = x_cpu.mutable_data<float>();
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Resize(y_cpu.dims());
y.set_lod(y_cpu.lod());
operators::AttentionPaddingMaskParam param;
param.X = &x;
param.Y = &y;
param.pad_id = 12800001;
param.mask = -90000000.f;
param.Out = &out;
param.pad_begin = &pad_begin;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
AttentionPaddingMaskCompute attention_padding_mask_kernel;
attention_padding_mask_kernel.SetParam(param);
attention_padding_mask_kernel.SetContext(std::move(ctx));
attention_padding_mask_kernel.Run();
cudaDeviceSynchronize();
auto out_data = out.mutable_data<float>(TARGET(kCUDA));
out_cpu.Resize(out.dims());
auto out_cpu_data = out_cpu.mutable_data<float>();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
attention_padding_mask_ref(x_cpu, y_cpu, &out_ref, &pad_begin_ref, param);
auto out_ref_data = out_ref.data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -47,6 +47,7 @@ add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_
add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
add_kernel(attention_padding_mask_compute_x86 X86 basic SRCS attention_padding_mask_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_x86 X86 basic SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps})
# for content-dnn specific
......@@ -90,4 +91,5 @@ endif()
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/attention_padding_mask_compute.h"
REGISTER_LITE_KERNEL(
attention_padding_mask,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::AttentionPaddingMaskCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("pad_begin", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <random>
#include <string>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/attention_padding_mask_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class AttentionPaddingMaskCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::AttentionPaddingMaskParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto src = param.Y;
auto attn = param.X;
auto src_offset = src->lod()[0];
auto attn_offset = attn->lod()[0];
int attn_seq_num = attn_offset.size() - 1;
int src_seq_num = src_offset.size() - 1;
int attn_seq_len = attn_offset[1];
int src_seq_len = attn->numel() / attn->dims()[0];
size_t count = attn->numel();
auto attn_data = attn->data<T>();
auto out = param.Out;
out->Resize(attn->dims());
out->set_lod(attn->lod());
auto out_data = out->mutable_data<T>();
memcpy(out_data, attn_data, count * sizeof(T));
for (int i = 0; i < attn_seq_num; ++i) {
for (int j = 0; j < attn_seq_len; ++j) {
auto tmp_out_data = out_data + src_seq_len * (attn_seq_len * i + j);
int src_seq_idx = i % src_seq_num;
int cur_len = src_offset[src_seq_idx + 1] - src_offset[src_seq_idx];
for (int k = cur_len; k < src_seq_len; k++) {
tmp_out_data[k] = param.mask;
}
}
}
}
virtual ~AttentionPaddingMaskCompute() = default;
private:
lite::Tensor src_offset_;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
void attention_padding_mask_ref(
const Tensor& x,
const Tensor& y,
Tensor* out,
Tensor* pad_begin,
const operators::AttentionPaddingMaskParam& param) {
auto attn_offset = x.lod()[0];
auto src_offset = y.lod()[0];
int attn_seq_num = attn_offset.size() - 1;
int src_seq_num = src_offset.size() - 1;
int attn_seq_len = attn_offset[1];
int src_seq_len = x.dims()[1];
CHECK_EQ(attn_seq_num % src_seq_num, 0);
auto count = x.numel();
auto attn_data = x.data<float>();
out->Resize(x.dims());
out->set_lod(x.lod());
auto out_data = out->mutable_data<float>();
memcpy(out_data, attn_data, count * sizeof(float));
for (int i = 0; i < attn_seq_num; ++i) {
for (int j = 0; j < attn_seq_len; ++j) {
auto tmp_out_data = out_data + src_seq_len * (attn_seq_len * i + j);
int src_seq_idx = i % src_seq_num;
int cur_len = src_offset[src_seq_idx + 1] - src_offset[src_seq_idx];
for (int k = cur_len; k < src_seq_len; k++) {
tmp_out_data[k] = param.mask;
}
}
}
}
void prepare_input(Tensor* x, const LoD& lod, int64_t dim2rd) {
std::vector<int64_t> x_dims{static_cast<int64_t>(lod[0].back()), dim2rd};
x->Resize(x_dims);
x->set_lod(lod);
auto x_data = x->mutable_data<float>();
auto x_num = x->numel();
for (int i = 0; i < x_num; i++) {
x_data[i] = (i - x_num) * 1.1;
}
}
int get_max_len(const LoD& lod) {
int max_len = 0;
auto offset = lod[0];
for (int i = 0; i < offset.size() - 1; i++) {
int cur_len = offset[i + 1] - offset[i];
max_len = max_len < cur_len ? cur_len : max_len;
}
return max_len;
}
TEST(attention_padding_mask_x86, retrive_op) {
auto attention_padding_mask =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"attention_padding_mask");
ASSERT_FALSE(attention_padding_mask.empty());
ASSERT_TRUE(attention_padding_mask.front());
}
TEST(attention_padding_mask_x86, init) {
AttentionPaddingMaskCompute<float> attention_padding_mask;
ASSERT_EQ(attention_padding_mask.precision(), PRECISION(kFloat));
ASSERT_EQ(attention_padding_mask.target(), TARGET(kX86));
}
TEST(attention_padding_mask_x86, run_test) {
lite::Tensor x, y;
lite::Tensor out, pad_begin, out_ref, pad_begin_ref;
LoD x_lod{{0, 3, 6, 9, 12}}, y_lod{{0, 4, 6}};
prepare_input(&x, x_lod, get_max_len(y_lod));
prepare_input(&y, y_lod, 1);
operators::AttentionPaddingMaskParam param;
param.X = &x;
param.Y = &y;
param.pad_id = 12800001;
param.mask = -90000000.f;
param.Out = &out;
param.pad_begin = &pad_begin;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
AttentionPaddingMaskCompute<float> attention_padding_mask_kernel;
attention_padding_mask_kernel.SetParam(param);
attention_padding_mask_kernel.SetContext(std::move(ctx));
attention_padding_mask_kernel.Run();
attention_padding_mask_ref(x, y, &out_ref, &pad_begin_ref, param);
auto out_data = out.data<float>();
auto out_ref_data = out_ref.data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(attention_padding_mask, kX86, kFloat, kNCHW, def);
......@@ -86,6 +86,7 @@ add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc
add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op.cc DEPS ${op_DEPS})
add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS})
# for OCR specific
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/attention_padding_mask_op.h"
#include "lite/core/op_registry.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
bool AttentionPaddingMaskOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out);
CHECK_OR_FALSE(param_.pad_begin);
return true;
}
bool AttentionPaddingMaskOp::InferShape() const {
auto src_len = param_.X->lod()[0][1];
CHECK_EQ(src_len, param_.X->dims()[1])
<< "Mismatch source length, expect: " << src_len
<< ", get: " << param_.X->lod()[0][1];
auto att_batch = param_.X->lod()[0].size() - 1;
auto src_batch = param_.Y->lod()[0].size() - 1;
CHECK_EQ(att_batch % src_batch, 0)
<< "Mismatch batch size, bottom0: " << att_batch
<< ", bottom1: " << src_batch;
param_.pad_begin->Resize({static_cast<int64_t>(src_batch)});
param_.Out->Resize(param_.X->dims());
param_.Out->set_lod(param_.X->lod());
return true;
}
bool AttentionPaddingMaskOp::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
param_.X = scope->FindTensor(op_desc.Input("X").front());
param_.Y = scope->FindTensor(op_desc.Input("Y").front());
param_.Out = scope->FindMutableTensor(op_desc.Input("Out").front());
param_.pad_begin =
scope->FindMutableTensor(op_desc.Input("pad_begin").front());
param_.pad_id = op_desc.GetAttr<int>("pad_id");
param_.mask = op_desc.GetAttr<float>("mask");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(attention_padding_mask,
paddle::lite::operators::AttentionPaddingMaskOp);
REGISTER_LITE_OP(search_attention_padding_mask,
paddle::lite::operators::AttentionPaddingMaskOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class AttentionPaddingMaskOp : public OpLite {
public:
AttentionPaddingMaskOp() {}
explicit AttentionPaddingMaskOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "attention_padding_mask"; }
private:
mutable AttentionPaddingMaskParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -777,6 +777,15 @@ struct SequenceConcatParam {
lite::Tensor* Out{};
};
struct AttentionPaddingMaskParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
int pad_id;
float mask;
lite::Tensor* Out{};
lite::Tensor* pad_begin{};
};
struct SequenceArithmeticParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册