From b97e6d13fd552df98bda8156e7851d21399c6579 Mon Sep 17 00:00:00 2001
From: Linjie Chen <40840292+linjieccc@users.noreply.github.com>
Date: Wed, 9 Mar 2022 22:38:14 +0800
Subject: [PATCH] [phi] move viterbi_decode to phi (#40186)

* move viterbi to phi

* move infershape to phi

* update infershape

* fix

* resolve conflicts
---
 paddle/fluid/operators/viterbi_decode_op.cc   |  53 +--
 paddle/fluid/operators/viterbi_decode_op.cu   | 206 --------
 paddle/fluid/operators/viterbi_decode_op.h    | 438 ------------------
 paddle/phi/infermeta/ternary.cc               |  47 ++
 paddle/phi/infermeta/ternary.h                |   8 +
 .../phi/kernels/cpu/viterbi_decode_kernel.cc  | 319 +++++++++++++
 .../kernels/funcs/viterbi_decode_functor.h    | 140 ++++++
 .../phi/kernels/gpu/viterbi_decode_kernel.cu  | 402 ++++++++++++++++
 paddle/phi/kernels/viterbi_decode_kernel.h    |  30 ++
 9 files changed, 953 insertions(+), 690 deletions(-)
 delete mode 100644 paddle/fluid/operators/viterbi_decode_op.cu
 delete mode 100644 paddle/fluid/operators/viterbi_decode_op.h
 create mode 100644 paddle/phi/kernels/cpu/viterbi_decode_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/viterbi_decode_functor.h
 create mode 100644 paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
 create mode 100644 paddle/phi/kernels/viterbi_decode_kernel.h

diff --git a/paddle/fluid/operators/viterbi_decode_op.cc b/paddle/fluid/operators/viterbi_decode_op.cc
index bf1cdeed65..602376d54e 100644
--- a/paddle/fluid/operators/viterbi_decode_op.cc
+++ b/paddle/fluid/operators/viterbi_decode_op.cc
@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/viterbi_decode_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -19,47 +21,6 @@ class ViterbiDecodeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode");
-    OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
-                   "ViterbiDecode");
-    OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode");
-    OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores",
-                   "ViterbiDecode");
-    OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode");
-    auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 3,
-                      platform::errors::InvalidArgument(
-                          "The rank of Input in ViterbiDecode must be 3. But "
-                          "received Input's rank is %d.",
-                          in_dims.size()));
-    auto length_dims = ctx->GetInputDim("Length");
-    PADDLE_ENFORCE_EQ(length_dims.size(), 1,
-                      platform::errors::InvalidArgument(
-                          "The rank of Length in ViterbiDecode must be 1. But "
-                          "received Length's rank is %d.",
-                          length_dims.size()));
-    auto transition_dims = ctx->GetInputDim("Transition");
-    PADDLE_ENFORCE_EQ(
-        transition_dims.size(), 2,
-        platform::errors::InvalidArgument(
-            "The rank of Transition in ViterbiDecode must be 2. But "
-            "received Transition's rank is %d.",
-            transition_dims.size()));
-    if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(
-          in_dims[0], length_dims[0],
-          platform::errors::InvalidArgument(
-              "The batch size of Input and Length should be equal."));
-      PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0],
-                        platform::errors::InvalidArgument(
-                            "The number of tags of Input (%d) and Transition "
-                            "(%d) should be equal.",
-                            transition_dims[0], in_dims[2]));
-    }
-    ctx->SetOutputDim("Scores", length_dims);
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -102,8 +63,8 @@ class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
 
 namespace ops = paddle::operators;
 namespace platform = paddle::platform;
+DECLARE_INFER_SHAPE_FUNCTOR(viterbi_decode, ViterbiDecodeInferShapeFunctor,
+                            PD_INFER_META(phi::ViterbiDecodeInferMeta));
 REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
-                             ops::ViterbiDecodeOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
-    ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
+                             ops::ViterbiDecodeOpMaker,
+                             ViterbiDecodeInferShapeFunctor);
diff --git a/paddle/fluid/operators/viterbi_decode_op.cu b/paddle/fluid/operators/viterbi_decode_op.cu
deleted file mode 100644
index 68628fb274..0000000000
--- a/paddle/fluid/operators/viterbi_decode_op.cu
+++ /dev/null
@@ -1,206 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/viterbi_decode_op.h"
-#include "paddle/phi/kernels/funcs/gather.cu.h"
-
-#ifdef __NVCC__
-#include "cub/cub.cuh"
-#endif
-#ifdef __HIPCC__
-#include <hipcub/hipcub.hpp>
-namespace cub = hipcub;
-#endif
-
-namespace paddle {
-namespace operators {
-
-#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)    \
-  case (1 << (log2_block_dim)): {                         \
-    constexpr auto kBlockDim = (1 << (log2_block_dim));   \
-    __VA_ARGS__;                                          \
-  } break
-
-#define FIXED_BLOCK_DIM_CASE(...)               \
-  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
-  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
-
-int64_t ComputeBlockSize(int64_t col) {
-  if (col > 512)
-    return 1024;
-  else if (col > 256)
-    return 512;
-  else if (col > 128)
-    return 256;
-  else if (col > 64)
-    return 128;
-  else if (col > 32)
-    return 64;
-  else if (col > 16)
-    return 32;
-  else if (col > 8)
-    return 16;
-  else
-    return 8;
-}
-
-template