// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #ifdef __NVCC__ #include #include #include #else #include #endif #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; template struct SequenceMaskForRangeFunctor { HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen) : x_(x), y_(y), maxlen_(maxlen) {} HOSTDEVICE void operator()(int y_idx) const { int x_idx = y_idx / maxlen_; int j = y_idx % maxlen_; y_[y_idx] = static_cast(j < x_[x_idx] ? 1 : 0); } private: const Tx *x_; Ty *y_; int maxlen_; }; template struct SequenceMaskFunctor { SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y, int limits, int maxlen) : ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {} template void apply() const { auto *y_data = y_->mutable_data(ctx_.GetPlace()); platform::ForRange for_range(ctx_, limits_); for_range(SequenceMaskForRangeFunctor(x_, y_data, maxlen_)); } private: const DeviceContext &ctx_; const Tx *x_; Tensor *y_; int limits_; int maxlen_; }; template class SequenceMaskKernel : public framework::OpKernel { using Tensor = framework::LoDTensor; public: void Compute(const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); auto *y = ctx.Output("Y"); int maxlen = ctx.Attr("maxlen"); if (ctx.HasInput("MaxLenTensor")) { auto max_len_tensor = ctx.Input("MaxLenTensor"); PADDLE_ENFORCE(max_len_tensor != NULL, "MaxLenTensor is NULL"); if (platform::is_gpu_place(max_len_tensor->place())) { framework::Tensor temp; TensorCopySync(*max_len_tensor, platform::CPUPlace(), &temp); maxlen = *temp.data(); } else { maxlen = *max_len_tensor->data(); } auto y_dim = framework::vectorize2int(x->dims()); y_dim.push_back(maxlen); y->Resize(framework::make_ddim(y_dim)); PADDLE_ENFORCE_GT(maxlen, 0, "MaxLenTensor value should be greater than 0"); } auto *x_data = x->data(); auto x_numel = x->numel(); if (maxlen < 0) { #ifdef __NVCC__ VLOG(10) << "SequenceMaskOp on GPU may be slow when maxlen is not provided."; maxlen = static_cast( thrust::reduce(thrust::device_pointer_cast(x_data), thrust::device_pointer_cast(x_data) + x_numel, static_cast(0), thrust::maximum())); #else maxlen = static_cast(*std::max_element(x_data, x_data + x_numel)); #endif auto y_dim = framework::vectorize2int(x->dims()); y_dim.push_back(maxlen); y->Resize(framework::make_ddim(y_dim)); } auto out_dtype = static_cast( ctx.Attr("out_dtype")); auto &dev_ctx = ctx.template device_context(); framework::VisitDataType(out_dtype, SequenceMaskFunctor( dev_ctx, x_data, y, x_numel * maxlen, maxlen)); } }; } // namespace operators } // namespace paddle