Unverified commit d2b1e3c2, authored by Wang Xin, committed by GitHub

sequence_mask functionalization (#53478)

* sequence_mask functionalization

* fix sequence_mask test
Parent 0603777b
@@ -12,9 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -101,13 +99,3 @@ REGISTER_OPERATOR(
paddle::operators::SequenceMaskOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(sequence_mask,
CPU,
ALL_LAYOUT,
ops::SequenceMaskKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(sequence_mask,
GPU,
ALL_LAYOUT,
ops::SequenceMaskKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#else
#include <algorithm>
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen)
: x_(x), y_(y), maxlen_(maxlen) {}
HOSTDEVICE void operator()(int y_idx) const {
int x_idx = y_idx / maxlen_;
int j = y_idx % maxlen_;
y_[y_idx] = static_cast<Ty>(j < x_[x_idx] ? 1 : 0);
}
private:
const Tx *x_;
Ty *y_;
int maxlen_;
};
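The functor above flattens the 2-D mask: for a flat output index y_idx, x_idx = y_idx / maxlen picks the sequence whose length applies, j = y_idx % maxlen is the position within that row, and the element is 1 exactly while j is below the length. A minimal NumPy sketch of that index arithmetic (variable names are illustrative, not Paddle's):

```python
import numpy as np

x = np.array([0, 3, 4, 5, 7, 9])  # per-row lengths, as in the test data below
maxlen = 10

y = np.empty((x.size, maxlen), dtype=np.int64)
for y_idx in range(x.size * maxlen):
    x_idx, j = divmod(y_idx, maxlen)   # y_idx / maxlen and y_idx % maxlen
    y[x_idx, j] = 1 if j < x[x_idx] else 0

assert (y.sum(axis=1) == x).all()      # row i contains exactly x[i] ones
```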
template <typename DeviceContext, typename Tx>
struct SequenceMaskFunctor {
SequenceMaskFunctor(const DeviceContext &ctx,
const Tx *x,
phi::DenseTensor *y,
int limits,
int maxlen)
: ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}
template <typename Ty>
void apply() const {
auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx_, limits_);
for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
}
private:
const DeviceContext &ctx_;
const Tx *x_;
phi::DenseTensor *y_;
int limits_;
int maxlen_;
};
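SequenceMaskFunctor keeps the output dtype open through the templated apply<Ty>() so that VisitDataType can pick the concrete type from the out_dtype attribute at run time (see Compute below). A rough Python analogue of that dispatch, with made-up names that are not Paddle's:

```python
import numpy as np

# Made-up stand-in for the out_dtype -> concrete-type dispatch that
# VisitDataType performs over SequenceMaskFunctor::apply<Ty>().
DTYPES = {"int32": np.int32, "int64": np.int64,
          "float32": np.float32, "float64": np.float64}

def visit_data_type(out_dtype, functor):
    return functor.apply(DTYPES[out_dtype])   # instantiate for one type

class SequenceMaskFunctor:
    def __init__(self, x, maxlen):
        self.x, self.maxlen = x, maxlen

    def apply(self, ty):
        # mask[i, j] = 1 while j < x[i], materialized in the requested dtype
        return (np.arange(self.maxlen) < self.x[:, None]).astype(ty)

mask = visit_data_type("int64", SequenceMaskFunctor(np.array([1, 3]), 4))
```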
template <typename Tx, typename DeviceContext>
class SequenceMaskKernel : public framework::OpKernel<Tx> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Output<phi::DenseTensor>("Y");
int maxlen = ctx.Attr<int>("maxlen");
if (ctx.HasInput("MaxLenTensor")) {
auto max_len_tensor = ctx.Input<phi::DenseTensor>("MaxLenTensor");
PADDLE_ENFORCE_NOT_NULL(max_len_tensor,
platform::errors::InvalidArgument(
"Input(MaxLenTensor) should not be NULL."
"But received Input(MaxLenTensor) is NULL"));
if (platform::is_gpu_place(max_len_tensor->place())) {
phi::DenseTensor temp;
paddle::framework::TensorCopySync(
*max_len_tensor, platform::CPUPlace(), &temp);
maxlen = *temp.data<int32_t>();
} else {
maxlen = *max_len_tensor->data<int32_t>();
}
auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));
PADDLE_ENFORCE_GT(
maxlen,
0,
platform::errors::InvalidArgument(
"Input(MaxLenTensor) value should be greater than 0. But "
"received Input(MaxLenTensor) value = %d.",
maxlen));
}
auto *x_data = x->data<Tx>();
auto x_numel = x->numel();
if (maxlen < 0) {
if (x_numel == 0) {
maxlen = 0;
} else {
#if defined(__NVCC__) || defined(__HIPCC__)
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<Tx>(0),
thrust::maximum<Tx>()));
#else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
}
auto y_dim = phi::vectorize<int>(x->dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));
}
auto out_dtype = static_cast<framework::proto::VarType::Type>(
ctx.Attr<int>("out_dtype"));
auto &dev_ctx = ctx.template device_context<DeviceContext>();
framework::VisitDataType(out_dtype,
SequenceMaskFunctor<DeviceContext, Tx>(
dev_ctx, x_data, y, x_numel * maxlen, maxlen));
}
};
} // namespace operators
} // namespace paddle
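End to end, the kernel's maxlen handling above is: prefer MaxLenTensor when present (copying the scalar to the CPU first if it lives on the GPU), otherwise use the maxlen attribute, and if that is still negative infer it as max(x) (thrust::reduce on GPU, std::max_element on CPU); the output is then resized to x.dims() + [maxlen]. A NumPy reference of that control flow (my naming, not Paddle's API):

```python
import numpy as np

def sequence_mask_ref(x, max_len_tensor=None, maxlen=-1, out_dtype=np.int64):
    """Reference for the maxlen/shape logic above; not Paddle's API."""
    if max_len_tensor is not None:
        maxlen = int(max_len_tensor)   # the kernel copies the scalar to CPU first
        assert maxlen > 0, "Input(MaxLenTensor) value should be greater than 0"
    if maxlen < 0:                     # infer from the data, as the reduce does
        maxlen = 0 if x.size == 0 else int(x.max())
    # Output shape is x.dims() + [maxlen]
    return (np.arange(maxlen) < np.asarray(x)[..., None]).astype(out_dtype)

print(sequence_mask_ref(np.array([[0, 3, 4], [5, 7, 9]])).shape)  # (2, 3, 9)
```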
@@ -28,7 +28,6 @@ register_unity_group(
sequence_enumerate_op.cu
sequence_erase_op.cu
sequence_expand_op.cu
sequence_mask_op.cu
sequence_pad_op.cu
sequence_expand_as_op.cu
sequence_reshape_op.cu
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sequence_mask_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/sequence_mask_kernel_impl.h"
PD_REGISTER_KERNEL(sequence_mask,
CPU,
ALL_LAYOUT,
phi::SequenceMaskKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
namespace funcs {
template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen)
: x_(x), y_(y), maxlen_(maxlen) {}
HOSTDEVICE void operator()(int y_idx) const {
int x_idx = y_idx / maxlen_;
int j = y_idx % maxlen_;
y_[y_idx] = static_cast<Ty>(j < x_[x_idx] ? 1 : 0);
}
private:
const Tx *x_;
Ty *y_;
int maxlen_;
};
template <typename DeviceContext, typename Tx>
struct SequenceMaskFunctor {
SequenceMaskFunctor(const DeviceContext &ctx,
const Tx *x,
phi::DenseTensor *y,
int limits,
int maxlen)
: ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}
template <typename Ty>
void apply() const {
ctx_.template Alloc<Ty>(y_);
auto *y_data = y_->data<Ty>();
phi::funcs::ForRange<DeviceContext> for_range(ctx_, limits_);
for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
}
private:
const DeviceContext &ctx_;
const Tx *x_;
phi::DenseTensor *y_;
int limits_;
int maxlen_;
};
} // namespace funcs
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sequence_mask_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/sequence_mask_kernel_impl.h"
PD_REGISTER_KERNEL(sequence_mask,
GPU,
ALL_LAYOUT,
phi::SequenceMaskKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#else
#include <algorithm>
#endif
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/sequence_mask.h"
namespace phi {
template <typename T, typename Context>
void SequenceMaskKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& max_len_tensor,
int maxlen,
int out_dtype,
DenseTensor* y) {
if (max_len_tensor) {
bool is_gpu_place = ctx.GetPlace().GetType() == phi::AllocationType::GPU;
if (is_gpu_place) {
phi::DenseTensor temp;
phi::Copy(ctx, *max_len_tensor.get_ptr(), phi::CPUPlace(), false, &temp);
maxlen = *temp.data<int32_t>();
} else {
maxlen = *max_len_tensor.get_ptr()->data<int32_t>();
}
auto y_dim = phi::vectorize<int>(x.dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));
PADDLE_ENFORCE_GT(
maxlen,
0,
phi::errors::InvalidArgument(
"Input(MaxLenTensor) value should be greater than 0. But "
"received Input(MaxLenTensor) value = %d.",
maxlen));
}
auto* x_data = x.data<T>();
auto x_numel = x.numel();
if (maxlen < 0) {
if (x_numel == 0) {
maxlen = 0;
} else {
#if defined(__NVCC__) || defined(__HIPCC__)
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<T>(0),
thrust::maximum<T>()));
#else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
}
auto y_dim = phi::vectorize<int>(x.dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));
}
phi::VisitDataType(phi::TransToPhiDataType(out_dtype),
phi::funcs::SequenceMaskFunctor<Context, T>(
ctx, x_data, y, x_numel * maxlen, maxlen));
}
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SequenceMaskKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& max_len_tensor,
int maxlen,
int out_dtype,
DenseTensor* y);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SequenceMaskOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sequence_mask", {"X", "MaxLenTensor"}, {"maxlen", "out_dtype"}, {"Y"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(sequence_mask, phi::SequenceMaskOpArgumentMapping);
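The mapping ties the legacy operator's names to the phi kernel's parameter order: inputs {X, MaxLenTensor}, attributes {maxlen, out_dtype}, output {Y}. Spelled out informally (a hypothetical dict, not a real Paddle structure):

```python
# Hypothetical restatement of the KernelSignature above; not Paddle code.
sequence_mask_signature = {
    "kernel": "sequence_mask",
    "inputs": ["X", "MaxLenTensor"],    # -> x, max_len_tensor
    "attrs": ["maxlen", "out_dtype"],   # -> maxlen, out_dtype
    "outputs": ["Y"],                   # -> y
}
```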
@@ -28,9 +28,18 @@ sys.path.append("../../python/paddle/fluid/tests/unittests")
from eager_op_test import OpTest
def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):
if maxlen_tensor is not None:
maxlen = maxlen_tensor
return paddle.nn.functional.sequence_mask(
x, maxlen=maxlen, dtype=mask_dtype
)
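The wrapper simply forwards to the public API, substituting a concrete max-length tensor for the maxlen argument when one is provided. A typical direct call, for reference (assumes a working paddle install; the exact printed layout may differ):

```python
import paddle

lengths = paddle.to_tensor([1, 3, 2], dtype='int64')
mask = paddle.nn.functional.sequence_mask(lengths, maxlen=4, dtype='int64')
# Expected rows: [1,0,0,0], [1,1,1,0], [1,1,0,0]
print(mask.numpy())
```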
class SequenceMaskTestBase(OpTest):
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.python_api = sequence_mask_wraper
self.maxlen = 10
self.mask_dtype = 'int64'
self.x = [[0, 3, 4], [5, 7, 9]]
@@ -100,6 +109,7 @@ class SequenceMaskTest6(SequenceMaskTestBase):
class SequenceMaskTestBase_tensor_attr(OpTest):
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.python_api = sequence_mask_wraper
self.maxlen = 10
self.maxlen_tensor = np.ones((1), 'int32') * 10
self.mask_dtype = 'int64'