未验证 提交 6072aecb 编写于 作者: J Jack Zhou 提交者: GitHub

Add viterbi decode (#35778)

* add viterbi decode cpu kernel

* add viterbi decoder api in paddle.text

* add a data buffer once to avoid create many small pieces of data buffer frequently

* fix viterbi max_seq_length bug

* fix seq_len=1 bug

* fix device context

* move split out of for loop

* remove INVERSE_SUB

* remove 2 GET_CAST_MASK

* remove 1 loop

* remove Functor

* add to_static deploy code

* use MAX_FUNC instead of ELE_MAX

* add MaxFunctor

* impl max_func

* remove MaxFunctor

* remove cast op

* use REGISTER_OP_WITHOUT_GRADIENT

* add viterbi cuda kernel

* add FIX_BLOCKDIM_CASE macro

* add MKL add, mul; add get data mask

* add arange mkl impl

* add CPU Argmax

* add cpu gather

* use EXECUTE_MKL_ELEMENT_BINARY_OP instead of some ADD, MUL

* use SameDimsBinaryOP instead of EXECUTE_MKL_ELEMENT_BINARY_OP

* use SAME_DIMS_ELEMENT_BINARY_OP

* add SimpleBroadcastBinaryOP

* use int instead of int64_t to accelerate

* optimize SimpleBroadcastBinaryOP

* optimize SimpleBroadcastBinaryOP

* optimize performance in both single thread and multithread situation

* remove useless line

* remove useless code

* add CREATE_TENSOR_BUFFER macro

* add INIT_REQUIRED_TENSOR macro

* add comment

* fix windows ci

* add viterbi unittest

* remove cuda add functor

* remove cuda equal

* remove a template function

* fix windows ci

* fix windows dtype

* remove some template instance

* remove useless header file

* remove some blockdim

* remove transpose impl

* accelerate cpu performance on single thread situation

* viterbi_decode->crf_decode

* rename crf params name

* add viterbi api test

* remove useless import

* add enable_static

* use viterbi decoder

* fix viterbi len=1

* fix  viterbi unittest

* remove useless comments

* reconstruct viterbi decode

* remove ADD,SUB,MUL structure

* fix coverage

* remove CREATE_TENSOR

* add name args

* crf.py->ops.py; with_start_stop_tag->include_start_end_tag

* update crf_decode en docs

* fix viterbi decode en docs

* fix some review comments

* add FIXED_BLOCK_DIM_CASE in cuda

* push_back->emplace_back

* crf_decode->viterbi_decode; include_start_end_tag->include_bos_eos_tag

* paddle.text.ops.viterbi_decode->paddle.text.viterbi_decode

* fix viterbi_decode en docs
上级 66f4b292
......@@ -240,7 +240,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
} else {
out_dims_array[i] = -1;
}
......@@ -1779,7 +1779,7 @@ void CommonElementwiseBroadcastForward(
const framework::Tensor *y, framework::Tensor *z,
const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
int axis, const bool is_xsize_larger = true) {
int max_dim = std::max(x_dims.size(), y_dims.size());
int max_dim = (std::max)(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/viterbi_decode_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ViterbiDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores",
"ViterbiDecode");
OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in ViterbiDecode must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
auto length_dims = ctx->GetInputDim("Length");
PADDLE_ENFORCE_EQ(length_dims.size(), 1,
platform::errors::InvalidArgument(
"The rank of Length in ViterbiDecode must be 1. But "
"received Length's rank is %d.",
length_dims.size()));
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(
transition_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Transition in ViterbiDecode must be 2. But "
"received Transition's rank is %d.",
transition_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dims[0], length_dims[0],
platform::errors::InvalidArgument(
"The batch size of Input and Length should be equal."));
PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0],
platform::errors::InvalidArgument(
"The number of tags of Input (%d) and Transition "
"(%d) should be equal.",
transition_dims[0], in_dims[2]));
}
ctx->SetOutputDim("Scores", length_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"The unary emission tensor. The shape of Input must be (batch_size,"
"sequence_length, num_tags). ");
AddInput("Transition",
"The transition matrix. The shape of Transition must be ( "
"num_tags, num_tags). ");
AddInput("Length",
"The input length tensor storing real length of each sequence for "
"correctness. The shape of Length MUST be (batch_size).");
AddOutput("Scores",
"The scores tensor containing the score for the Viterbi "
"sequence. The shape of Scores MUST be (batch_size).");
AddOutput("Path",
"The paths tensor containing the highest scoring tag indices. "
"The shape of Scores MUST be (batch_size, sequence_length).");
AddAttr<bool>("include_bos_eos_tag",
"If set to True, the last row and the last column of "
"transitions will be considered as start tag.")
.SetDefault(true);
AddComment(R"DOC(
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp,
ops::ViterbiDecodeOpMaker);
REGISTER_OP_CPU_KERNEL(
viterbi_decode, ops::ViterbiDecodeKernel<platform::CPUDeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/viterbi_decode_op.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
namespace paddle {
namespace operators {
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case (1 << (log2_block_dim)): { \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);
int64_t ComputeBlockSize(int64_t col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
}
template <template <typename T> typename BinaryFunctor, typename T>
struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* output) {
std::vector<const Tensor*> ins{&lhs, &rhs};
std::vector<Tensor*> outs{output};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, BinaryFunctor<T>());
}
};
template <template <typename T> typename CompareFunctor, typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
std::vector<const Tensor*> ins = {&lhs, &rhs};
std::vector<Tensor*> outs = {mask};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, int64_t, T>(
dev_ctx, ins, &outs, CompareFunctor<int64_t>());
}
};
template <typename T, typename IndType, size_t BlockDim>
__global__ void ArgmaxCUDAKernel(const int64_t height, // n * h
const int64_t width, // c
const int64_t post_size, // h
const T* in, IndType* out_idx, T* out) {
typedef cub::BlockReduce<cub::KeyValuePair<int, T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
cub::ArgMax reducer;
T init = (std::numeric_limits<T>::lowest)(); // for windows compile
for (int idx = blockIdx.x; idx < height; idx += gridDim.x) {
cub::KeyValuePair<int, T> kv_pair = {-1, init};
int h = idx / post_size;
int w = idx % post_size;
for (int k = threadIdx.x; k < width; k += blockDim.x) {
kv_pair =
reducer({k, in[h * width * post_size + k * post_size + w]}, kv_pair);
}
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
if (threadIdx.x == 0) {
// return max, argmax
if (out_idx != nullptr) out_idx[idx] = static_cast<IndType>(kv_pair.key);
if (out != nullptr) out[idx] = kv_pair.value;
}
__syncthreads();
}
}
__global__ void ARangeKernel(int64_t* data, int num, int64_t scale) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int start = idx; idx < num; idx += gridDim.x) {
data[idx] = idx * scale;
}
}
template <>
struct ARange<platform::CUDADeviceContext> {
void operator()(const platform::CUDADeviceContext& dev_ctx, int64_t* data,
int num, int64_t scale) {
int64_t kBlockDim = ComputeBlockSize(num);
// kBlockDim > num at most of time, so we can set grid = 1
ARangeKernel<<<1, kBlockDim, 0, dev_ctx.stream()>>>(data, num, scale);
}
};
template <typename T, typename IndType>
struct Argmax<platform::CUDADeviceContext, T, IndType> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* out_idx, Tensor* out, int axis) {
framework::DDim input_dims = input.dims();
int64_t numel = input.numel();
int64_t groups = numel / input_dims[axis];
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
const auto& dev_ctx = ctx.cuda_device_context();
auto cu_stream = dev_ctx.stream();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
int64_t height = pre * post;
int64_t width = n;
int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;
const T* in_data = input.data<T>();
IndType* out_idx_data = out_idx->data<IndType>();
T* out_data = out->data<T>();
switch (ComputeBlockSize(width)) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, IndType,
kBlockDim><<<grid_size, kBlockDim, 0, cu_stream>>>(
height, width, post, in_data, out_idx_data, out_data));
}
}
};
template <typename T>
struct GetMaxValue<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const Tensor& input, T* max_value) {
Tensor out_data;
out_data.Resize(framework::make_ddim({1}));
out_data.mutable_data<T>(platform::CUDAPlace());
switch (ComputeBlockSize(input.numel())) {
FIXED_BLOCK_DIM_CASE(
ArgmaxCUDAKernel<T, T,
kBlockDim><<<1, kBlockDim, 0, dev_ctx.stream()>>>(
1, input.numel(), 1, input.data<int64_t>(), nullptr,
out_data.data<int64_t>()));
}
Tensor max_value_tensor;
framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor);
*max_value = max_value_tensor.data<T>()[0];
}
};
template <typename T, typename IndexT>
struct Gather<platform::CUDADeviceContext, T, IndexT> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
GPUGather<T, IndexT>(ctx, src, index, output);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace platform = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
viterbi_decode,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, float>,
ops::ViterbiDecodeKernel<platform::CUDADeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T, typename IndType>
struct Argmax {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* out_idx, Tensor* out, int axis) {
framework::DDim input_dims = input.dims();
int64_t pre = 1;
int64_t post = 1;
int64_t n = input_dims[axis];
for (int i = 0; i < axis; i++) {
pre *= input_dims[i];
}
for (int i = axis + 1; i < input_dims.size(); i++) {
post *= input_dims[i];
}
int64_t height = pre * post;
int64_t width = n;
const T* in_data = input.data<T>();
IndType* out_idx_data = out_idx->data<IndType>();
T* out_data = out->data<T>();
// Reduce
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < height; ++i) {
int64_t h = i / post;
int64_t w = i % post;
IndType max_idx = -1;
T max_value = (std::numeric_limits<T>::lowest)(); // for windows compile
for (int64_t j = 0; j < width; ++j) {
if (in_data[h * width * post + j * post + w] > max_value) {
max_value = in_data[h * width * post + j * post + w];
max_idx = j;
}
}
out_data[i] = max_value;
out_idx_data[i] = max_idx;
}
}
};
template <typename DeviceContext>
struct ARange {
void operator()(const DeviceContext& dev_ctx, int64_t* data, int end,
int64_t scale) {
for (int i = 0; i < end; ++i) {
data[i] = i * scale;
}
}
};
template <typename DeviceContext, typename T>
struct GetMaxValue {
void operator()(const DeviceContext& dev_ctx, const Tensor& input,
T* max_value) {
auto input_ptr = input.data<T>();
auto num = input.numel();
*max_value = *std::max_element(input_ptr, input_ptr + num);
}
};
template <typename DeviceContext, typename T, typename IndexT = int>
struct Gather {
void operator()(const DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
CPUGather<T, IndexT>(ctx, src, index, output);
}
};
template <typename T, typename Functor, typename OutT = T>
void SameDimsBinaryOP(const Tensor& lhs, const Tensor& rhs, Tensor* out) {
const T* lhs_ptr = lhs.data<T>();
const T* rhs_ptr = rhs.data<T>();
OutT* out_ptr = out->data<OutT>();
Functor functor;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < out->numel(); ++i) {
out_ptr[i] = functor(lhs_ptr[i], rhs_ptr[i]);
}
}
template <typename DeviceContext, template <typename T> typename CompareFunctor,
typename T>
struct GetMask {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
SameDimsBinaryOP<int64_t, CompareFunctor<int64_t>, T>(lhs, rhs, mask);
}
};
template <bool is_multi_threads>
struct GetInputIndex {
void operator()(const std::vector<int>& lhs_dims,
const std::vector<int>& rhs_dims,
const std::vector<int>& output_dims,
const std::vector<int>& lhs_strides,
const std::vector<int>& rhs_strides,
const std::vector<int>& output_strides, int output_idx,
int* index_array, int* lhs_idx, int* rhs_idx) {
int out_dims_size = output_strides.size();
for (int j = 0; j < out_dims_size; ++j) {
int curr_idx = output_idx / output_strides[j];
output_idx %= output_strides[j];
*lhs_idx += (lhs_dims[j] > 1) ? curr_idx * lhs_strides[j] : 0;
*rhs_idx += (rhs_dims[j] > 1) ? curr_idx * rhs_strides[j] : 0;
}
}
};
template <>
struct GetInputIndex<false> {
void operator()(const std::vector<int>& lhs_dims,
const std::vector<int>& rhs_dims,
const std::vector<int>& output_dims,
const std::vector<int>& lhs_strides,
const std::vector<int>& rhs_strides,
const std::vector<int>& output_strides, int output_idx,
int* index_array, int* lhs_idx, int* rhs_idx) {
int out_dims_size = output_strides.size();
*lhs_idx = GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array);
*rhs_idx = GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array);
UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, index_array);
}
};
template <typename T, typename Functor, bool is_multi_threads = false>
void SimpleBroadcastBinaryOP(const Tensor& lhs, const Tensor& rhs,
Tensor* out) {
const T* lhs_ptr = lhs.data<T>();
const T* rhs_ptr = rhs.data<T>();
T* out_ptr = out->data<T>();
int out_size = static_cast<int>(out->dims().size());
std::vector<int> out_dims(out_size);
std::vector<int> lhs_dims(out_size);
std::vector<int> rhs_dims(out_size);
std::copy(lhs.dims().Get(), lhs.dims().Get() + out_size, lhs_dims.data());
std::copy(rhs.dims().Get(), rhs.dims().Get() + out_size, rhs_dims.data());
std::copy(out->dims().Get(), out->dims().Get() + out_size, out_dims.data());
std::vector<int> output_strides(out_size, 1);
std::vector<int> lhs_strides(out_size, 1);
std::vector<int> rhs_strides(out_size, 1);
std::vector<int> index_array(out_size, 0);
// calculate strides
for (int i = out_size - 2; i >= 0; --i) {
output_strides[i] = output_strides[i + 1] * out_dims[i + 1];
lhs_strides[i] = lhs_strides[i + 1] * lhs_dims[i + 1];
rhs_strides[i] = rhs_strides[i + 1] * rhs_dims[i + 1];
}
Functor functor;
GetInputIndex<is_multi_threads> get_input_index;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < out->numel(); ++i) {
int lhs_idx = 0;
int rhs_idx = 0;
get_input_index(lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides,
output_strides, i, index_array.data(), &lhs_idx, &rhs_idx);
out_ptr[i] = functor(lhs_ptr[lhs_idx], rhs_ptr[rhs_idx]);
}
}
template <typename DeviceContext, template <typename T> typename BinaryFunctor,
typename T>
struct BinaryOperation {
void operator()(const DeviceContext& dev_ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* output) {
if (lhs.dims() == rhs.dims()) {
SameDimsBinaryOP<T, BinaryFunctor<T>>(lhs, rhs, output);
} else {
bool is_multi_threads = false;
#ifdef PADDLE_WITH_MKLML
if (omp_get_max_threads() > 1) {
is_multi_threads = true;
}
#endif
if (is_multi_threads) {
SimpleBroadcastBinaryOP<T, BinaryFunctor<T>, true>(lhs, rhs, output);
} else {
SimpleBroadcastBinaryOP<T, BinaryFunctor<T>, false>(lhs, rhs, output);
}
}
}
};
class TensorBuffer {
public:
explicit TensorBuffer(const LoDTensor& in) : buffer_(in), offset_(0) {
buffer_.Resize({buffer_.numel()});
}
Tensor GetBufferBlock(std::initializer_list<int64_t> shape) {
int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
Tensor block = buffer_.Slice(offset_, offset_ + size);
offset_ += size;
block.Resize(shape);
return block;
}
private:
LoDTensor buffer_; // need to resize 1-D Tensor
int offset_;
};
template <typename DeviceContext, typename T>
class ViterbiDecodeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool include_bos_eos_tag = ctx.Attr<bool>("include_bos_eos_tag");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto curr_place = ctx.GetPlace();
auto* input = ctx.Input<Tensor>("Input");
auto batch_size = static_cast<int>(input->dims()[0]);
auto seq_len = static_cast<int>(input->dims()[1]);
auto n_labels = static_cast<int>(input->dims()[2]);
math::SetConstant<DeviceContext, T> float_functor;
math::SetConstant<DeviceContext, int64_t> int_functor;
std::vector<Tensor> historys;
// We create tensor buffer in order to avoid allocating memory frequently
// 10 means allocate 10*batch_size bytes memory, such as int_mask, zero...
int buffer_size = batch_size * (n_labels + 1) * seq_len + 10 * batch_size;
LoDTensor int_buffer;
int_buffer.Resize(framework::make_ddim({buffer_size}));
int_buffer.mutable_data<int64_t>(ctx.GetPlace());
TensorBuffer int_tensor_buffer(int_buffer);
// create float tensor buffer
// 10 means allocate 10*batch_size*n_labels bytes, such as alpha, alpha_max
buffer_size = batch_size * (seq_len + 10) * n_labels +
(batch_size + 2) * n_labels * n_labels;
LoDTensor float_buffer;
float_buffer.Resize(framework::make_ddim({buffer_size}));
float_buffer.mutable_data<T>(ctx.GetPlace());
TensorBuffer float_tensor_buffer(float_buffer);
auto* length = ctx.Input<Tensor>("Length");
Tensor left_length = int_tensor_buffer.GetBufferBlock({batch_size, 1});
framework::TensorCopy(*length, curr_place, dev_ctx, &left_length);
int64_t max_seq_len = 0;
GetMaxValue<DeviceContext, int64_t> get_max_value;
get_max_value(dev_ctx, left_length, &max_seq_len);
auto* scores = ctx.Output<Tensor>("Scores");
scores->mutable_data<T>(curr_place);
auto* path = ctx.Output<Tensor>("Path");
path->Resize({batch_size, max_seq_len});
path->mutable_data<int64_t>(curr_place);
Tensor tpath = int_tensor_buffer.GetBufferBlock({max_seq_len, batch_size});
auto batch_path = Unbind(tpath);
for (auto it = batch_path.begin(); it != batch_path.end(); ++it) {
it->Resize({batch_size});
}
// create and init required tensor
Tensor input_exp =
float_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels});
TransCompute<DeviceContext, T>(3, dev_ctx, *input, &input_exp, {1, 0, 2});
auto* transition = ctx.Input<Tensor>("Transition");
Tensor trans_exp = float_tensor_buffer.GetBufferBlock({n_labels, n_labels});
framework::TensorCopy(*transition, curr_place, dev_ctx, &trans_exp);
trans_exp.Resize({1, n_labels, n_labels});
Tensor alpha = float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor zero = int_tensor_buffer.GetBufferBlock({batch_size, 1});
int_functor(dev_ctx, &zero, 0);
Tensor one = int_tensor_buffer.GetBufferBlock({batch_size, 1});
int_functor(dev_ctx, &one, 1);
Tensor float_one = float_tensor_buffer.GetBufferBlock({batch_size, 1});
float_functor(dev_ctx, &float_one, static_cast<T>(1.0));
Tensor alpha_trn_sum =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels, n_labels});
Tensor alpha_max =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor alpha_argmax =
int_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels});
auto alpha_argmax_unbind = Unbind(alpha_argmax);
Tensor alpha_nxt =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor int_mask = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor zero_len_mask = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor float_mask = float_tensor_buffer.GetBufferBlock({batch_size, 1});
Tensor stop_trans = float_tensor_buffer.GetBufferBlock({1, 1, n_labels});
Tensor start_trans = float_tensor_buffer.GetBufferBlock({1, 1, n_labels});
Tensor rest_trans =
float_tensor_buffer.GetBufferBlock({1, n_labels - 2, n_labels});
Tensor last_ids = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor last_ids_tmp = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor batch_offset = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor gather_idx = int_tensor_buffer.GetBufferBlock({batch_size});
std::vector<const Tensor*> shape{&rest_trans, &stop_trans, &start_trans};
std::vector<Tensor*> outputs{&rest_trans, &stop_trans, &start_trans};
math::SplitFunctor<DeviceContext, T> split_functor;
split_functor(dev_ctx, trans_exp, shape, 1, &outputs);
stop_trans.Resize({1, n_labels});
start_trans.Resize({1, n_labels});
auto logit0 = input_exp.Slice(0, 1);
logit0.Resize({batch_size, n_labels});
BinaryOperation<DeviceContext, AddFunctor, T> AddFloat;
BinaryOperation<DeviceContext, AddFunctor, int64_t> AddInt;
BinaryOperation<DeviceContext, MulFunctor, T> MulFloat;
BinaryOperation<DeviceContext, MulFunctor, int64_t> MulInt;
BinaryOperation<DeviceContext, SubFunctor, T> SubFloat;
BinaryOperation<DeviceContext, SubFunctor, int64_t> SubInt;
if (include_bos_eos_tag) {
AddFloat(dev_ctx, logit0, start_trans, &alpha);
GetMask<DeviceContext, EqualFunctor, T>()(ctx, left_length, one,
&float_mask);
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
} else {
alpha = logit0;
}
SubInt(dev_ctx, left_length, one, &left_length);
Argmax<DeviceContext, T, int64_t> argmax;
for (int64_t i = 1; i < max_seq_len; ++i) {
Tensor logit = input_exp.Slice(i, i + 1);
logit.Resize({batch_size, n_labels});
Tensor& alpha_exp = alpha.Resize({batch_size, n_labels, 1});
AddFloat(dev_ctx, alpha_exp, trans_exp, &alpha_trn_sum);
auto alpha_argmax_temp = alpha_argmax_unbind[i - 1];
alpha_argmax_temp.Resize({batch_size, n_labels});
argmax(ctx, alpha_trn_sum, &alpha_argmax_temp, &alpha_max, 1);
historys.emplace_back(alpha_argmax_temp);
AddFloat(dev_ctx, alpha_max, logit, &alpha_nxt);
alpha.Resize({batch_size, n_labels});
// mask = paddle.cast((left_length > 0), dtype='float32')
// alpha = mask * alpha_nxt + (1 - mask) * alpha
GetMask<DeviceContext, GreaterThanFunctor, T>()(ctx, left_length, zero,
&float_mask);
// alpha_nxt = mask * alpha_nxt
MulFloat(dev_ctx, alpha_nxt, float_mask, &alpha_nxt);
// inv_mask = 1 - mask
SubFloat(dev_ctx, float_one, float_mask, &float_mask);
// alpha = (1 - mask) * alpha
MulFloat(dev_ctx, alpha, float_mask, &alpha);
// alpha += alpha_nxt
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
if (include_bos_eos_tag) {
GetMask<DeviceContext, EqualFunctor, T>()(ctx, left_length, one,
&float_mask);
// alpha += mask * trans_exp[:, self.stop_idx]
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
}
SubInt(dev_ctx, left_length, one, &left_length);
}
argmax(ctx, alpha, &last_ids, scores, 1);
left_length.Resize({batch_size});
GetMask<DeviceContext, GreaterEqualFunctor, int64_t>()(ctx, left_length,
zero, &int_mask);
// last_ids_update = last_ids * tag_mask
int last_ids_index = 1;
int actual_len = (std::min)(seq_len, static_cast<int>(max_seq_len));
MulInt(dev_ctx, last_ids, int_mask,
&batch_path[actual_len - last_ids_index]);
// The algorithm below can refer to
// https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/layers/crf.py#L438
ARange<DeviceContext> arange;
arange(dev_ctx, batch_offset.data<int64_t>(), batch_size, n_labels);
Gather<DeviceContext, int64_t, int64_t> gather;
for (auto hist = historys.rbegin(); hist != historys.rend(); ++hist) {
++last_ids_index;
AddInt(dev_ctx, left_length, one, &left_length);
AddInt(dev_ctx, batch_offset, last_ids, &gather_idx);
Tensor& last_ids_update = batch_path[actual_len - last_ids_index];
hist->Resize({batch_size * n_labels});
gather(dev_ctx, *hist, gather_idx, &last_ids_update);
GetMask<DeviceContext, GreaterThanFunctor, int64_t>()(ctx, left_length,
zero, &int_mask);
MulInt(dev_ctx, last_ids_update, int_mask, &last_ids_update);
GetMask<DeviceContext, EqualFunctor, int64_t>()(ctx, left_length, zero,
&zero_len_mask);
MulInt(dev_ctx, last_ids, zero_len_mask, &last_ids_tmp);
SubInt(dev_ctx, one, zero_len_mask, &zero_len_mask);
MulInt(dev_ctx, last_ids_update, zero_len_mask, &last_ids_update);
AddInt(dev_ctx, last_ids_update, last_ids_tmp, &last_ids_update);
GetMask<DeviceContext, LessThanFunctor, int64_t>()(ctx, left_length, zero,
&int_mask);
MulInt(dev_ctx, last_ids, int_mask, &last_ids);
AddInt(dev_ctx, last_ids_update, last_ids, &last_ids);
}
TransCompute<DeviceContext, int64_t>(2, dev_ctx, tpath, path, {1, 0});
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import core
import unittest
import paddle
paddle.enable_static()
class Decoder(object):
def __init__(self, transitions, use_tag=True):
self.transitions = transitions
self.use_tag = use_tag
self.start_idx, self.stop_idx = -1, -2
def __call__(self, inputs, length):
bs, seq_len, n_label = inputs.shape
inputs_t = np.transpose(inputs, (1, 0, 2))
trans_exp = np.expand_dims(self.transitions, axis=0)
historys = []
left_length = np.array(length)
max_seq_len = np.amax(left_length)
left_length = np.expand_dims(left_length, 1)
alpha = np.full((bs, n_label), -1e4, dtype='float32') if self.use_tag \
else np.zeros((bs, n_label), dtype='float32')
alpha[:, -1] = 0
for i, logit in enumerate(inputs_t[:max_seq_len]):
if i == 0 and not self.use_tag:
alpha = logit
left_length = left_length - 1
continue
alpha_exp = np.expand_dims(alpha, 2)
alpha_trn_sum = alpha_exp + trans_exp
max_res = np.amax(alpha_trn_sum, 1), np.argmax(alpha_trn_sum, 1)
historys = historys + [max_res[1]] if i >= 1 else []
alpha_nxt = max_res[0] + logit
mask = (left_length > 0)
alpha = mask * alpha_nxt + (1 - mask) * alpha
if self.use_tag:
alpha += (left_length == 1) * trans_exp[:, self.stop_idx]
left_length = left_length - 1
scores, last_ids = np.amax(alpha, 1), np.argmax(alpha, 1)
left_length = left_length[:, 0]
last_ids_update = last_ids * (left_length >= 0)
batch_path = [last_ids_update]
batch_offset = np.arange(bs) * n_label
for hist in reversed(historys):
left_length = left_length + 1
gather_idx = batch_offset + last_ids
last_ids_update = np.take(hist, gather_idx) * (left_length > 0)
mask = (left_length == 0)
last_ids_update = last_ids_update * (1 - mask) + last_ids * mask
batch_path.insert(0, last_ids_update)
last_ids = last_ids_update + (left_length < 0) * last_ids
batch_path = np.stack(batch_path, 1)
return scores, batch_path
class TestViterbiOp(OpTest):
def set_attr(self):
self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
self.use_tag = True
self.bz, self.len, self.ntags = 4, 8, 10
def setUp(self):
self.op_type = "viterbi_decode"
self.set_attr()
bz, length, ntags = self.bz, self.len, self.ntags
self.input = np.random.randn(bz, length, ntags).astype(self.dtype)
self.trans = np.random.randn(ntags, ntags).astype(self.dtype)
self.length = np.random.randint(1, length + 1, [bz]).astype('int64')
decoder = Decoder(self.trans, self.use_tag)
scores, path = decoder(self.input, self.length)
self.inputs = {
'Input': self.input,
'Transition': self.trans,
'Length': self.length
}
self.attrs = {'include_bos_eos_tag': self.use_tag, }
self.outputs = {'Scores': scores, 'Path': path}
def test_output(self):
self.check_output()
class TestViterbiAPI(unittest.TestCase):
def set_attr(self):
self.use_tag = True
self.bz, self.len, self.ntags = 4, 8, 10
self.places = [fluid.CPUPlace(), fluid.CUDAPlace(0)] \
if core.is_compiled_with_cuda() else [fluid.CPUPlace()]
def setUp(self):
self.set_attr()
bz, length, ntags = self.bz, self.len, self.ntags
self.input = np.random.randn(bz, length, ntags).astype('float32')
self.transitions = np.random.randn(ntags, ntags).astype('float32')
self.length = np.random.randint(1, length + 1, [bz]).astype('int64')
decoder = Decoder(self.transitions, self.use_tag)
self.scores, self.path = decoder(self.input, self.length)
def check_static_result(self, place):
bz, length, ntags = self.bz, self.len, self.ntags
with fluid.program_guard(fluid.Program(), fluid.Program()):
Input = fluid.data(
name="Input", shape=[bz, length, ntags], dtype="float32")
Transition = fluid.data(
name="Transition", shape=[ntags, ntags], dtype="float32")
Length = fluid.data(name="Length", shape=[bz], dtype="int64")
decoder = paddle.text.ViterbiDecoder(Transition, self.use_tag)
score, path = decoder(Input, Length)
exe = fluid.Executor(place)
feed_list = {
"Input": self.input,
"Transition": self.transitions,
"Length": self.length
}
fetches = exe.run(feed=feed_list, fetch_list=[score, path])
np.testing.assert_allclose(fetches[0], self.scores, rtol=1e-5)
np.testing.assert_allclose(fetches[1], self.path)
def test_static_net(self):
for place in self.places:
self.check_static_result(place)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .viterbi_decode import ViterbiDecoder, viterbi_decode
from .datasets import Conll05st # noqa: F401
from .datasets import Imdb # noqa: F401
from .datasets import Imikolov # noqa: F401
......@@ -20,7 +21,6 @@ from .datasets import UCIHousing # noqa: F401
from .datasets import WMT14 # noqa: F401
from .datasets import WMT16 # noqa: F401
__all__ = [ #noqa
'Conll05st',
'Imdb',
......@@ -28,5 +28,7 @@ __all__ = [ #noqa
'Movielens',
'UCIHousing',
'WMT14',
'WMT16'
'WMT16',
'ViterbiDecoder',
'viterbi_decode'
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..nn import Layer
from ..fluid.framework import core, in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type
__all__ = ['viterbi_decode', 'ViterbiDecoder']
def viterbi_decode(potentials,
transition_params,
lengths,
include_bos_eos_tag=True,
name=None):
"""
Decode the highest scoring sequence of tags computed by transitions and potentials and get the viterbi path.
Args:
potentials (Tensor): The input tensor of unary emission. This is a 3-D
tensor with shape of [batch_size, sequence_length, num_tags]. The data type is float32 or float64.
transition_params (Tensor): The input tensor of transition matrix. This is a 2-D
tensor with shape of [num_tags, num_tags]. The data type is float32 or float64.
lengths (Tensor): The input tensor of length of each sequence. This is a 1-D tensor with shape of [batch_size]. The data type is int64.
include_bos_eos_tag (`bool`, optional): If set to True, the last row and the last column of transitions will be considered
as start tag, the second to last row and the second to last column of transitions will be considered as stop tag. Defaults to ``True``.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
scores(Tensor): The output tensor containing the score for the Viterbi sequence. The shape is [batch_size]
and the data type is float32 or float64.
paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length]
and the data type is int64.
Example:
.. code-block:: python
import paddle
paddle.seed(102)
batch_size, seq_len, num_tags = 2, 4, 3
emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32')
length = paddle.randint(1, seq_len + 1, [batch_size])
tags = paddle.randint(0, num_tags, [batch_size, seq_len])
transition = paddle.rand((num_tags, num_tags), dtype='float32')
scores, path = paddle.text.viterbi_decode(emission, transition, length, False) # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]]
"""
if in_dygraph_mode():
return core.ops.viterbi_decode(potentials, transition_params, lengths,
'include_bos_eos_tag',
include_bos_eos_tag)
check_variable_and_dtype(potentials, 'input', ['float32', 'float64'],
'viterbi_decode')
check_variable_and_dtype(transition_params, 'transitions',
['float32', 'float64'], 'viterbi_decode')
check_variable_and_dtype(lengths, 'length', 'int64', 'viterbi_decode')
check_type(include_bos_eos_tag, 'include_tag', bool, 'viterbi_decode')
helper = LayerHelper('viterbi_decode', **locals())
attrs = {'include_bos_eos_tag': include_bos_eos_tag}
scores = helper.create_variable_for_type_inference(potentials.dtype)
path = helper.create_variable_for_type_inference('int64')
helper.append_op(
type='viterbi_decode',
inputs={
'Input': potentials,
'Transition': transition_params,
'Length': lengths
},
outputs={'Scores': scores,
'Path': path},
attrs=attrs)
return scores, path
class ViterbiDecoder(Layer):
"""
Decode the highest scoring sequence of tags computed by transitions and potentials and get the viterbi path.
Args:
transitions (`Tensor`): The transition matrix. Its dtype is float32 and has a shape of `[num_tags, num_tags]`.
include_bos_eos_tag (`bool`, optional): If set to True, the last row and the last column of transitions will be considered
as start tag, the second to last row and the second to last column of transitions will be considered as stop tag. Defaults to ``True``.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Shape:
potentials (Tensor): The input tensor of unary emission. This is a 3-D tensor with shape of
[batch_size, sequence_length, num_tags]. The data type is float32 or float64.
lengths (Tensor): The input tensor of length of each sequence. This is a 1-D tensor with shape of
[batch_size]. The data type is int64.
Returns:
scores(Tensor): The output tensor containing the score for the Viterbi sequence. The shape is [batch_size]
and the data type is float32 or float64.
paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length]
and the data type is int64.
Example:
.. code-block:: python
import paddle
paddle.seed(102)
batch_size, seq_len, num_tags = 2, 4, 3
emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32')
length = paddle.randint(1, seq_len + 1, [batch_size])
tags = paddle.randint(0, num_tags, [batch_size, seq_len])
transition = paddle.rand((num_tags, num_tags), dtype='float32')
decoder = paddle.text.ViterbiDecoder(transition, include_bos_eos_tag=False)
scores, path = decoder(emission, length) # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]]
"""
def __init__(self, transitions, include_bos_eos_tag=True, name=None):
super(ViterbiDecoder, self).__init__()
self.transitions = transitions
self.include_bos_eos_tag = include_bos_eos_tag
self.name = name
def forward(self, potentials, lengths):
return viterbi_decode(potentials, self.transitions, lengths,
self.include_bos_eos_tag, self.name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册