diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 312978a010b30ceb394d25269d9f4daa804702e0..2df7dd06f2cc89ba58f278952fea2b324d747610 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -240,7 +240,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, x_dims, y_dims, x_dims_array[i], y_dims_array[i], i)); if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { - out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]); } else { out_dims_array[i] = -1; } @@ -1779,7 +1779,7 @@ void CommonElementwiseBroadcastForward( const framework::Tensor *y, framework::Tensor *z, const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func, int axis, const bool is_xsize_larger = true) { - int max_dim = std::max(x_dims.size(), y_dims.size()); + int max_dim = (std::max)(x_dims.size(), y_dims.size()); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, diff --git a/paddle/fluid/operators/viterbi_decode_op.cc b/paddle/fluid/operators/viterbi_decode_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf1cdeed65a8427c19410347209faa099673cb7c --- /dev/null +++ b/paddle/fluid/operators/viterbi_decode_op.cc @@ -0,0 +1,109 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/viterbi_decode_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ViterbiDecodeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition", + "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores", + "ViterbiDecode"); + OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode"); + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of Input in ViterbiDecode must be 3. But " + "received Input's rank is %d.", + in_dims.size())); + auto length_dims = ctx->GetInputDim("Length"); + PADDLE_ENFORCE_EQ(length_dims.size(), 1, + platform::errors::InvalidArgument( + "The rank of Length in ViterbiDecode must be 1. But " + "received Length's rank is %d.", + length_dims.size())); + auto transition_dims = ctx->GetInputDim("Transition"); + PADDLE_ENFORCE_EQ( + transition_dims.size(), 2, + platform::errors::InvalidArgument( + "The rank of Transition in ViterbiDecode must be 2. But " + "received Transition's rank is %d.", + transition_dims.size())); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + in_dims[0], length_dims[0], + platform::errors::InvalidArgument( + "The batch size of Input and Length should be equal.")); + PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0], + platform::errors::InvalidArgument( + "The number of tags of Input (%d) and Transition " + "(%d) should be equal.", + transition_dims[0], in_dims[2])); + } + ctx->SetOutputDim("Scores", length_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "The unary emission tensor. The shape of Input must be (batch_size," + "sequence_length, num_tags). "); + AddInput("Transition", + "The transition matrix. The shape of Transition must be ( " + "num_tags, num_tags). "); + AddInput("Length", + "The input length tensor storing real length of each sequence for " + "correctness. The shape of Length MUST be (batch_size)."); + AddOutput("Scores", + "The scores tensor containing the score for the Viterbi " + "sequence. The shape of Scores MUST be (batch_size)."); + AddOutput("Path", + "The paths tensor containing the highest scoring tag indices. " + "The shape of Scores MUST be (batch_size, sequence_length)."); + AddAttr("include_bos_eos_tag", + "If set to True, the last row and the last column of " + "transitions will be considered as start tag.") + .SetDefault(true); + AddComment(R"DOC( + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace platform = paddle::platform; +REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp, + ops::ViterbiDecodeOpMaker); +REGISTER_OP_CPU_KERNEL( + viterbi_decode, ops::ViterbiDecodeKernel, + ops::ViterbiDecodeKernel); diff --git a/paddle/fluid/operators/viterbi_decode_op.cu b/paddle/fluid/operators/viterbi_decode_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..086ff05b08461245fc4cb1592dd107027db1f86a --- /dev/null +++ b/paddle/fluid/operators/viterbi_decode_op.cu @@ -0,0 +1,200 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/viterbi_decode_op.h" + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +namespace paddle { +namespace operators { + +#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \ + case (1 << (log2_block_dim)): { \ + constexpr auto kBlockDim = (1 << (log2_block_dim)); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM_CASE(...) \ + FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); + +int64_t ComputeBlockSize(int64_t col) { + if (col > 512) + return 1024; + else if (col > 256) + return 512; + else if (col > 128) + return 256; + else if (col > 64) + return 128; + else if (col > 32) + return 64; + else if (col > 16) + return 32; + else if (col > 8) + return 16; + else + return 8; +} + +template