Commit 7c6f2350 authored by zhoukunsheng, committed by Tao Luo

support Tensor input for edit_distance op (#18162)

Parent 85f5e9e2
......@@ -139,7 +139,7 @@ paddle.fluid.layers.sequence_slice (ArgSpec(args=['input', 'offset', 'length', '
paddle.fluid.layers.dropout (ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')), ('document', '558d13133596209190df9a624264f28f'))
paddle.fluid.layers.split (ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '78cf3a7323d1a7697658242e13f63759'))
paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2bc3a59efa9d52b628a6255422d9f0e8'))
paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)), ('document', 'f2c252aa2f83f8e503ffaf79668eaa28'))
paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(True, None, None, None)), ('document', '77cbfb28cd2fc589f589c7013c5086cd'))
paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', 'c1df110ea65998984f564c5c10abc54a'))
paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', 'fa2081f6e731bb9de7cd535ca07f523a'))
paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e50940f3ce5a08cc477b72f517491bf3'))
......
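The spec change above adds two optional arguments, `input_length` and `label_length`, to `paddle.fluid.layers.edit_distance`. In call form the difference is roughly as follows (the data variables here are hypothetical placeholders):

```python
# Before: sequence boundaries must be carried by LoD information.
out, seq_num = fluid.layers.edit_distance(input=x, label=y)

# After: padded 2-D Tensors plus explicit per-sequence lengths also work.
out, seq_num = fluid.layers.edit_distance(
    input=x_pad, label=y_pad, input_length=x_len, label_length=y_len)
```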
......@@ -29,12 +29,30 @@ class EditDistanceOp : public framework::OperatorWithKernel {
"Output(SequenceNum) shouldn't be null.");
auto hyp_dims = ctx->GetInputDim("Hyps");
auto ref_dims = ctx->GetInputDim("Refs");
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) {
auto hyp_length_dims = ctx->GetInputDim("HypsLength");
auto ref_length_dims = ctx->GetInputDim("RefsLength");
PADDLE_ENFORCE(hyp_dims.size() == 2 && ref_dims.size() == 2 &&
hyp_dims[0] == ref_dims[0],
"Input(Hyps) and Input(Refs) must be 2-D Tensors with "
"identical first dimension");
PADDLE_ENFORCE(hyp_length_dims[0] == ref_length_dims[0] &&
hyp_length_dims[0] == hyp_dims[0],
"Input(HypsLength), Input(RefsLength) and Input(Hyps) "
"should have identical first dimension");
} else {
PADDLE_ENFORCE(
hyp_dims.size() == 2 && hyp_dims[1] == 1,
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
PADDLE_ENFORCE(
ref_dims.size() == 2 && ref_dims[1] == 1,
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
ctx->SetOutputDim("SequenceNum", {1});
}
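The two branches above encode two shape contracts: with `HypsLength`/`RefsLength` present, `Hyps` and `Refs` are padded `[batch_size, max_len]` Tensors; without them, they are `[total_tokens, 1]` LoDTensors. A numpy sketch of inputs that would pass each check (shapes and names are illustrative only):

```python
import numpy as np

batch_size, max_hyp_len, max_ref_len = 4, 12, 10

# Padded-Tensor mode: one row per sequence, plus per-row lengths.
hyps = np.zeros((batch_size, max_hyp_len), dtype=np.int64)
refs = np.zeros((batch_size, max_ref_len), dtype=np.int64)
hyps_length = np.array([5, 12, 7, 9], dtype=np.int64)  # shape [batch_size]
refs_length = np.array([4, 10, 6, 8], dtype=np.int64)  # shape [batch_size]

# LoDTensor mode: all tokens concatenated into an [N, 1] column; sequence
# boundaries travel in the LoD offsets rather than in a length tensor.
hyps_flat = np.zeros((int(hyps_length.sum()), 1), dtype=np.int64)
refs_flat = np.zeros((int(refs_length.sum()), 1), dtype=np.int64)
```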
......@@ -51,11 +69,21 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
"dimension being 1. "
"The indices for hypothesis strings.");
AddInput("Refs",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
"dimension being 1. "
"The indices for reference strings.");
AddInput("HypsLength",
"1-D Tensor<int64_t>. "
"Sequence length for hyps when hyps is a tensor")
.AsDispensable();
AddInput("RefsLength",
"1-D Tensor<int64_t>. "
"Sequence length for refs when refs is a tensor")
.AsDispensable();
AddOutput("SequenceNum", "The sequence count of current batch");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize "
......@@ -78,12 +106,11 @@ insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information.
Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
And the `batch_size` reference strings are arranged in order in the same way in the
LoDTensor Input(Refs).
Input(Refs).
Output(Out) contains the `batch_size` results and each stands for the edit stance
Output(Out) contains the `batch_size` results and each stands for the edit distance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of reference string.
)DOC");
......
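For reference, the recurrence behind the operator can be sketched in plain Python; it reproduces the kitten -> sitting distance of 3 cited in the DOC string (a sequential sketch, not the parallel kernel below):

```python
def levenshtein(hyp, ref):
    """Plain O(m*n) edit-distance DP: dist[i][j] is the distance
    between hyp[:i] and ref[:j]."""
    m, n = len(hyp), len(ref)
    dist = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dist[i][0] = i                       # delete all of hyp[:i]
    for j in range(n + 1):
        dist[0][j] = j                       # insert all of ref[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution
    return dist[m][n]

assert levenshtein("kitten", "sitting") == 3
```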
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/edit_distance_op.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -76,20 +77,43 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
sequence_num->mutable_data<int64_t>(ctx.GetPlace());
auto batch_size = x1_t->dims()[0];
auto normalized = ctx.Attr<bool>("normalized");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
auto hyp_lod = x1_t->lod()[0];
auto ref_lod = x2_t->lod()[0];
PADDLE_ENFORCE(
hyp_lod.size() == ref_lod.size(),
"Input(Hyps) and Input(Refs) must have the same batch size.");
for (size_t i = 1; i < ref_lod.size(); ++i) {
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
"Reference string %d is empty.", i);
framework::Vector<size_t> hyp_lod(batch_size + 1);
framework::Vector<size_t> ref_lod(batch_size + 1);
bool use_length = ctx.HasInput("HypsLength");
if (use_length) {
// build lod when using padding
auto* hyp_length = ctx.Input<framework::Tensor>("HypsLength");
auto* ref_length = ctx.Input<framework::Tensor>("RefsLength");
framework::Tensor hyp_length_cpu;
framework::Tensor ref_length_cpu;
framework::TensorCopy(*hyp_length, platform::CPUPlace(), &hyp_length_cpu);
framework::TensorCopy(*ref_length, platform::CPUPlace(), &ref_length_cpu);
for (auto i = 0; i < batch_size; i++) {
hyp_lod[i + 1] = hyp_lod[i] + hyp_length_cpu.data<int64_t>()[i];
ref_lod[i + 1] = ref_lod[i] + ref_length_cpu.data<int64_t>()[i];
}
} else {
hyp_lod = x1_t->lod()[0];
ref_lod = x2_t->lod()[0];
}
if (normalized) {
for (size_t i = 1; i < ref_lod.size(); ++i) {
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
"Reference string %d is empty.", i);
}
}
const size_t num_strs = hyp_lod.size() - 1;
......@@ -108,10 +132,6 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
if (m == 0 || n == 0) {
distance = std::max(m, n);
if (normalized) {
PADDLE_ENFORCE(n > 0,
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.",
n);
distance = distance / n;
}
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
......@@ -121,14 +141,17 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int64_t>() + ref_lod[num];
auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num];
auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num];
auto x1 = x1_t->data<int64_t>() + hyp_offset;
auto x2 = x2_t->data<int64_t>() + ref_offset;
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
    // Compute the elements of the distance matrix in the anti-diagonal direction
for (int64_t slice = 2; slice < m + n + 1; ++slice) {
int z_m = slice < m + 1 ? 0 : slice - m;
......
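The anti-diagonal order used by the GPU kernel works because cell (i, j) of the distance matrix depends only on (i-1, j), (i, j-1), and (i-1, j-1): all cells sharing the same i + j are mutually independent and can be updated in one parallel launch. A sequential Python stand-in for that traversal (`update_cell` is a hypothetical callback):

```python
def fill_anti_diagonals(m, n, update_cell):
    """Visit interior DP cells (1 <= i <= m, 1 <= j <= n) in
    anti-diagonal order: all cells with equal i + j are independent."""
    for s in range(2, m + n + 1):           # s = i + j
        i_lo, i_hi = max(1, s - n), min(m, s - 1)
        for i in range(i_lo, i_hi + 1):     # parallelizable within one slice
            update_cell(i, s - i)
```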
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -29,17 +30,37 @@ class EditDistanceKernel : public framework::OpKernel<T> {
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
int64_t* seq_num_data = sequence_num->mutable_data<int64_t>(ctx.GetPlace());
auto batch_size = x1_t->dims()[0];
auto normalized = ctx.Attr<bool>("normalized");
auto hyp_lod = x1_t->lod()[0];
auto ref_lod = x2_t->lod()[0];
PADDLE_ENFORCE(
hyp_lod.size() == ref_lod.size(),
"Input(Hyps) and Input(Refs) must have the same batch size.");
for (size_t i = 1; i < ref_lod.size(); ++i) {
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
"Reference string %d is empty.", i);
framework::Vector<size_t> hyp_lod(batch_size + 1);
framework::Vector<size_t> ref_lod(batch_size + 1);
bool use_length = ctx.HasInput("HypsLength");
if (use_length) {
// build lod when using padding
auto hyp_length_ptr =
ctx.Input<framework::Tensor>("HypsLength")->data<int64_t>();
auto ref_length_ptr =
ctx.Input<framework::Tensor>("RefsLength")->data<int64_t>();
for (auto i = 0; i < batch_size; i++) {
hyp_lod[i + 1] = hyp_lod[i] + hyp_length_ptr[i];
ref_lod[i + 1] = ref_lod[i] + ref_length_ptr[i];
}
} else {
hyp_lod = x1_t->lod()[0];
ref_lod = x2_t->lod()[0];
}
if (normalized) {
for (size_t i = 1; i < ref_lod.size(); ++i) {
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
"Reference string %d is empty.", i);
}
}
auto num_strs = hyp_lod.size() - 1;
*seq_num_data = static_cast<int64_t>(num_strs);
......@@ -62,8 +83,10 @@ class EditDistanceKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int64_t>() + ref_lod[num];
auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num];
auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num];
auto x1 = x1_t->data<int64_t>() + hyp_offset;
auto x2 = x2_t->data<int64_t>() + ref_offset;
for (int64_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i;
}
......
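Both the CPU and GPU kernels reduce the padded case to the LoD case by turning per-sequence lengths into cumulative offsets, which is exactly the running sum `hyp_lod[i + 1] = hyp_lod[i] + hyp_length[i]` in the loops above. The same conversion in numpy terms:

```python
import numpy as np

def lengths_to_lod(lengths):
    # lod[i] is the start offset of sequence i; lod[-1] is the total
    # token count, mirroring the kernels' prefix-sum loop.
    lod = np.zeros(len(lengths) + 1, dtype=np.int64)
    lod[1:] = np.cumsum(lengths)
    return lod

assert lengths_to_lod([2, 4]).tolist() == [0, 2, 6]
```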
......@@ -5353,7 +5353,12 @@ def topk(input, k, name=None):
return values, indices
def edit_distance(input, label, normalized=True, ignored_tokens=None):
def edit_distance(input,
label,
normalized=True,
ignored_tokens=None,
input_length=None,
label_length=None):
"""
Edit distance operator computes the edit distances between a batch of
hypothesis strings and their references. Edit distance, also called
......@@ -5367,52 +5372,49 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
"kitten" -> "sitten" -> "sittin" -> "sitting"
The input is a LoDTensor consisting of all the hypothesis strings with
The input is a LoDTensor/Tensor consisting of all the hypothesis strings with
the total number denoted by `batch_size`, and the separation is specified
by the LoD information. And the `batch_size` reference strings are arranged
in order in the same way in the input LoDTensor.
by the LoD information or input_length. And the `batch_size` reference strings are arranged
in order in the same way as `input`.
The output contains the `batch_size` results and each stands for the edit
distance for a pair of strings respectively. If Attr(normalized) is true,
the edit distance will be divided by the length of reference string.
Args:
input(Variable): The indices for hypothesis strings.
label(Variable): The indices for reference strings.
input(Variable): The indices for hypothesis strings; it should have rank 2 and dtype int64.
label(Variable): The indices for reference strings; it should have rank 2 and dtype int64.
normalized(bool, default True): Indicated whether to normalize the edit distance by
the length of reference string.
ignored_tokens(list<int>, default None): Tokens that should be removed before
calculating edit distance.
name (str): The name of this layer. It is optional.
input_length(Variable): The length for each sequence in `input` when it is a padded Tensor; it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each sequence in `label` when it is a padded Tensor; it should have shape `[batch_size]` and dtype int64.
Returns:
Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
edit_distance_out(Variable): edit distance result in shape [batch_size, 1]. \n
sequence_num(Variable): sequence number in shape [].
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[1], dtype='int64')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost, _ = fluid.layers.edit_distance(input=x, label=y)
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())
# using LoDTensor
x_lod = fluid.layers.data(name='x_lod', shape=[1], dtype='int64', lod_level=1)
y_lod = fluid.layers.data(name='y_lod', shape=[1], dtype='int64', lod_level=1)
distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod)
import numpy
x_ = numpy.random.randint(5, size=(2, 1)).astype('int64')
y_ = numpy.random.randint(5, size=(2, 1)).astype('int64')
print(x_)
print(y_)
x = fluid.create_lod_tensor(x_, [[2]], cpu)
y = fluid.create_lod_tensor(y_, [[2]], cpu)
# using Tensor
x_seq_len = 5
y_seq_len = 6
x_pad = fluid.layers.data(name='x_pad', shape=[x_seq_len], dtype='int64')
y_pad = fluid.layers.data(name='y_pad', shape=[y_seq_len], dtype='int64')
x_len = fluid.layers.data(name='x_len', shape=[], dtype='int64')
y_len = fluid.layers.data(name='y_len', shape=[], dtype='int64')
distance_pad, seq_num_pad = fluid.layers.edit_distance(input=x_pad, label=y_pad, input_length=x_len, label_length=y_len)
outs = exe.run(feed={'x':x, 'y':y}, fetch_list=[cost.name])
print(outs)
"""
helper = LayerHelper("edit_distance", **locals())
......@@ -5435,13 +5437,17 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
attrs={"tokens": ignored_tokens})
label = erased_label
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length and label_length:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs={"Hyps": [input],
"Refs": [label]},
inputs=this_inputs,
outputs={"Out": [edit_distance_out],
"SequenceNum": [sequence_num]},
attrs={"normalized": normalized})
......
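Pieced together, the LoDTensor fragments of the docstring run end to end roughly as follows (a sketch assuming the fluid API of this era; `create_lod_tensor` takes length-based `recursive_seq_lens`):

```python
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=1)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=1)
distance, seq_num = fluid.layers.edit_distance(input=x, label=y)

cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())

x_ = np.random.randint(5, size=(2, 1)).astype('int64')
y_ = np.random.randint(5, size=(2, 1)).astype('int64')
# One sequence of length 2 in each batch, expressed via LoD.
x_t = fluid.create_lod_tensor(x_, [[2]], cpu)
y_t = fluid.create_lod_tensor(y_, [[2]], cpu)

outs = exe.run(feed={'x': x_t, 'y': y_t},
               fetch_list=[distance, seq_num])
print(outs)
```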
......@@ -89,27 +89,31 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
def reset_config(self):
pass
def post_config(self):
pass
def setUp(self):
self.op_type = "edit_distance"
normalized = True
x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64")
x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64")
x1 = np.transpose(x1)
x2 = np.transpose(x2)
self.x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64")
self.x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64")
self.x1_lod = [3, 0, 3]
self.x2_lod = [2, 1, 2]
self.x1 = np.transpose(self.x1)
self.x2 = np.transpose(self.x2)
self.reset_config()
num_strs = len(self.x1_lod)
distance = np.zeros((num_strs, 1)).astype("float32")
sequence_num = np.array(3).astype("int64")
sequence_num = np.array(num_strs).astype("int64")
x1_offset = 0
x2_offset = 0
for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=x1[x1_offset:(x1_offset + self.x1_lod[i])],
ref=x2[x2_offset:(x2_offset + self.x2_lod[i])])
hyp=self.x1[x1_offset:(x1_offset + self.x1_lod[i])],
ref=self.x2[x2_offset:(x2_offset + self.x2_lod[i])])
x1_offset += self.x1_lod[i]
x2_offset += self.x2_lod[i]
if normalized is True:
......@@ -117,9 +121,14 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized}
self.inputs = {'Hyps': (x1, [self.x1_lod]), 'Refs': (x2, [self.x2_lod])}
self.inputs = {
'Hyps': (self.x1, [self.x1_lod]),
'Refs': (self.x2, [self.x2_lod])
}
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
self.post_config()
def test_check_output(self):
self.check_output()
......@@ -136,5 +145,43 @@ class TestEditDistanceOpNormalizedCase2(TestEditDistanceOpNormalizedCase0):
self.x2_lod = [2, 2, 1]
class TestEditDistanceOpNormalizedTensor(OpTest):
def reset_config(self):
self.x1 = np.array([[10, 3, 0, 0], [6, 5, 8, 2]], dtype=np.int64)
self.x2 = np.array([[10, 4, 0], [6, 7, 8]], dtype=np.int64)
self.x1_lod = np.array([2, 4], dtype=np.int64)
self.x2_lod = np.array([2, 3], dtype=np.int64)
def setUp(self):
self.op_type = "edit_distance"
normalized = True
self.reset_config()
num_strs = len(self.x1_lod)
distance = np.zeros((num_strs, 1)).astype("float32")
sequence_num = np.array(num_strs).astype("int64")
for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=self.x1[i][0:self.x1_lod[i]],
ref=self.x2[i][0:self.x2_lod[i]])
if normalized is True:
len_ref = self.x2_lod[i]
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized}
self.inputs = {
'Hyps': self.x1,
'Refs': self.x2,
'HypsLength': self.x1_lod,
'RefsLength': self.x2_lod
}
self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()