Commit 7c6f2350, authored by zhoukunsheng, committed by Tao Luo

support Tensor input for edit_distance op (#18162)

Parent: 85f5e9e2
@@ -139,7 +139,7 @@ paddle.fluid.layers.sequence_slice (ArgSpec(args=['input', 'offset', 'length', '
 paddle.fluid.layers.dropout (ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')), ('document', '558d13133596209190df9a624264f28f'))
 paddle.fluid.layers.split (ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '78cf3a7323d1a7697658242e13f63759'))
 paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2bc3a59efa9d52b628a6255422d9f0e8'))
-paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)), ('document', 'f2c252aa2f83f8e503ffaf79668eaa28'))
+paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(True, None, None, None)), ('document', '77cbfb28cd2fc589f589c7013c5086cd'))
 paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', 'c1df110ea65998984f564c5c10abc54a'))
 paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', 'fa2081f6e731bb9de7cd535ca07f523a'))
 paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e50940f3ce5a08cc477b72f517491bf3'))
...
@@ -29,12 +29,30 @@ class EditDistanceOp : public framework::OperatorWithKernel {
                    "Output(SequenceNum) shouldn't be null.");
     auto hyp_dims = ctx->GetInputDim("Hyps");
     auto ref_dims = ctx->GetInputDim("Refs");
-    PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
-                   "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
-                   "equal to 1.");
-    PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
-                   "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
-                   "equal to 1.");
+    if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) {
+      auto hyp_length_dims = ctx->GetInputDim("HypsLength");
+      auto ref_length_dims = ctx->GetInputDim("RefsLength");
+      PADDLE_ENFORCE(hyp_dims.size() == 2 && ref_dims.size() == 2 &&
+                         hyp_dims[0] == ref_dims[0],
+                     "Input(Hyps) and Input(Refs) must be 2-D Tensors with "
+                     "identical first dimension");
+      PADDLE_ENFORCE(hyp_length_dims[0] == ref_length_dims[0] &&
+                         hyp_length_dims[0] == hyp_dims[0],
+                     "Input(HypsLength), Input(RefsLength) and Input(Hyps) "
+                     "should have identical first dimension");
+    } else {
+      PADDLE_ENFORCE(
+          hyp_dims.size() == 2 && hyp_dims[1] == 1,
+          "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
+          "equal to 1.");
+      PADDLE_ENFORCE(
+          ref_dims.size() == 2 && ref_dims[1] == 1,
+          "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
+          "equal to 1.");
+    }
     ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
     ctx->SetOutputDim("SequenceNum", {1});
   }
@@ -51,11 +69,21 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Hyps",
-             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
+             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
+             "dimension being 1. "
              "The indices for hypothesis strings.");
     AddInput("Refs",
-             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
+             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
+             "dimension being 1. "
              "The indices for reference strings.");
+    AddInput("HypsLength",
+             "1-D Tensor<int64_t>. "
+             "Sequence length for hyps when hyps is a tensor")
+        .AsDispensable();
+    AddInput("RefsLength",
+             "1-D Tensor<int64_t>. "
+             "Sequence length for refs when refs is a tensor")
+        .AsDispensable();
     AddOutput("SequenceNum", "The sequence count of current batch");
     AddAttr<bool>("normalized",
                   "(bool, default false) Indicated whether to normalize "
@@ -78,12 +106,11 @@ insertion:
 "kitten" -> "sitten" -> "sittin" -> "sitting"
-Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
-number denoted by `batch_size`, and the separation is specified by the LoD information.
+Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
 And the `batch_size` reference strings are arranged in order in the same way in the
-LoDTensor Input(Refs).
+Input(Refs).
 
-Output(Out) contains the `batch_size` results and each stands for the edit stance
+Output(Out) contains the `batch_size` results and each stands for the edit distance
 for a pair of strings respectively. If Attr(normalized) is true, the edit distance
 will be divided by the length of reference string.
 )DOC");
...
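For reference, the DOC block above describes the classic Levenshtein recurrence that the CPU and GPU kernels further down implement. A minimal NumPy sketch of that dynamic program (illustrative only, not part of this patch):

import numpy as np

def levenshtein(hyp, ref):
    # dist[i, j] = edit distance between hyp[:i] and ref[:j]
    m, n = len(hyp), len(ref)
    dist = np.zeros((m + 1, n + 1), dtype=np.float32)
    dist[:, 0] = np.arange(m + 1)  # delete everything in hyp[:i]
    dist[0, :] = np.arange(n + 1)  # insert everything in ref[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i, j] = min(dist[i - 1, j] + 1,         # deletion
                             dist[i, j - 1] + 1,         # insertion
                             dist[i - 1, j - 1] + cost)  # substitution
    return dist[m, n]

# "kitten" -> "sitting" takes 3 edits, matching the example in the DOC.
assert levenshtein(list("kitten"), list("sitting")) == 3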
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/edit_distance_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -76,20 +77,43 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
     auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
     auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
     sequence_num->mutable_data<int64_t>(ctx.GetPlace());
+    auto batch_size = x1_t->dims()[0];
     auto normalized = ctx.Attr<bool>("normalized");
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
-    auto hyp_lod = x1_t->lod()[0];
-    auto ref_lod = x2_t->lod()[0];
-    PADDLE_ENFORCE(
-        hyp_lod.size() == ref_lod.size(),
-        "Input(Hyps) and Input(Refs) must have the same batch size.");
-    for (size_t i = 1; i < ref_lod.size(); ++i) {
-      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
-                     "Reference string %d is empty.", i);
+    framework::Vector<size_t> hyp_lod(batch_size + 1);
+    framework::Vector<size_t> ref_lod(batch_size + 1);
+
+    bool use_length = ctx.HasInput("HypsLength");
+
+    if (use_length) {
+      // build lod when using padding
+      auto* hyp_length = ctx.Input<framework::Tensor>("HypsLength");
+      auto* ref_length = ctx.Input<framework::Tensor>("RefsLength");
+      framework::Tensor hyp_length_cpu;
+      framework::Tensor ref_length_cpu;
+      framework::TensorCopy(*hyp_length, platform::CPUPlace(), &hyp_length_cpu);
+      framework::TensorCopy(*ref_length, platform::CPUPlace(), &ref_length_cpu);
+      for (auto i = 0; i < batch_size; i++) {
+        hyp_lod[i + 1] = hyp_lod[i] + hyp_length_cpu.data<int64_t>()[i];
+        ref_lod[i + 1] = ref_lod[i] + ref_length_cpu.data<int64_t>()[i];
+      }
+    } else {
+      hyp_lod = x1_t->lod()[0];
+      ref_lod = x2_t->lod()[0];
+    }
+
+    if (normalized) {
+      for (size_t i = 1; i < ref_lod.size(); ++i) {
+        PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
+                       "Reference string %d is empty.", i);
+      }
     }
     const size_t num_strs = hyp_lod.size() - 1;
@@ -108,10 +132,6 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
       if (m == 0 || n == 0) {
         distance = std::max(m, n);
         if (normalized) {
-          PADDLE_ENFORCE(n > 0,
-                         "The reference string (#%d) cannot be empty "
-                         "when Attr(normalized) is enabled.",
-                         n);
           distance = distance / n;
         }
         memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
@@ -121,14 +141,17 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
         dist_t.Resize({m + 1, n + 1});
         dist_t.mutable_data<T>(ctx.GetPlace());
         auto dist = dist_t.data<T>();
-        auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
-        auto x2 = x2_t->data<int64_t>() + ref_lod[num];
+        auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num];
+        auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num];
+        auto x1 = x1_t->data<int64_t>() + hyp_offset;
+        auto x2 = x2_t->data<int64_t>() + ref_offset;
         FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
                              PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
         FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
                           PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
         // Compute the elements of distance matrix in the anti-diagonal diretion
         for (int64_t slice = 2; slice < m + n + 1; ++slice) {
           int z_m = slice < m + 1 ? 0 : slice - m;
...
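The padded-input branch added to the GPU kernel only has to turn the per-row lengths into cumulative LoD-style offsets before the existing per-pair loop runs. A small NumPy sketch of that conversion (the values are illustrative):

import numpy as np

hyp_length = np.array([2, 4], dtype=np.int64)  # length of each padded row
hyp_lod = np.zeros(len(hyp_length) + 1, dtype=np.int64)
for i in range(len(hyp_length)):
    # same recurrence as the kernel: lod[i + 1] = lod[i] + length[i]
    hyp_lod[i + 1] = hyp_lod[i] + hyp_length[i]

print(hyp_lod)  # [0 2 6]; sequence i spans offsets hyp_lod[i]..hyp_lod[i + 1]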
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
@@ -29,17 +30,37 @@ class EditDistanceKernel : public framework::OpKernel<T> {
     auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
     auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
     int64_t* seq_num_data = sequence_num->mutable_data<int64_t>(ctx.GetPlace());
+    auto batch_size = x1_t->dims()[0];
     auto normalized = ctx.Attr<bool>("normalized");
-    auto hyp_lod = x1_t->lod()[0];
-    auto ref_lod = x2_t->lod()[0];
-    PADDLE_ENFORCE(
-        hyp_lod.size() == ref_lod.size(),
-        "Input(Hyps) and Input(Refs) must have the same batch size.");
-    for (size_t i = 1; i < ref_lod.size(); ++i) {
-      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
-                     "Reference string %d is empty.", i);
+    framework::Vector<size_t> hyp_lod(batch_size + 1);
+    framework::Vector<size_t> ref_lod(batch_size + 1);
+
+    bool use_length = ctx.HasInput("HypsLength");
+
+    if (use_length) {
+      // build lod when using padding
+      auto hyp_length_ptr =
+          ctx.Input<framework::Tensor>("HypsLength")->data<int64_t>();
+      auto ref_length_ptr =
+          ctx.Input<framework::Tensor>("RefsLength")->data<int64_t>();
+      for (auto i = 0; i < batch_size; i++) {
+        hyp_lod[i + 1] = hyp_lod[i] + hyp_length_ptr[i];
+        ref_lod[i + 1] = ref_lod[i] + ref_length_ptr[i];
+      }
+    } else {
+      hyp_lod = x1_t->lod()[0];
+      ref_lod = x2_t->lod()[0];
+    }
+
+    if (normalized) {
+      for (size_t i = 1; i < ref_lod.size(); ++i) {
+        PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
+                       "Reference string %d is empty.", i);
+      }
     }
     auto num_strs = hyp_lod.size() - 1;
     *seq_num_data = static_cast<int64_t>(num_strs);
@@ -62,8 +83,10 @@ class EditDistanceKernel : public framework::OpKernel<T> {
       dist_t.Resize({m + 1, n + 1});
       dist_t.mutable_data<T>(ctx.GetPlace());
       auto dist = dist_t.data<T>();
-      auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
-      auto x2 = x2_t->data<int64_t>() + ref_lod[num];
+      auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num];
+      auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num];
+      auto x1 = x1_t->data<int64_t>() + hyp_offset;
+      auto x2 = x2_t->data<int64_t>() + ref_offset;
       for (int64_t i = 0; i < m + 1; ++i) {
         dist[i * (n + 1)] = i;
       }
...
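Once both layouts are reduced to such an offset vector, the only remaining difference is how the kernels locate the start of sequence num: num * dims()[1] for a padded Tensor versus lod[num] for a LoDTensor. A NumPy sketch of the two addressing schemes, using the same values as the new unit test further below:

import numpy as np

# Padded layout: one sequence per row, so row `num` starts at num * row_width.
x_pad = np.array([[10, 3, 0, 0],
                  [6, 5, 8, 2]], dtype=np.int64)
num = 1
pad_offset = num * x_pad.shape[1]  # 4, flat index of the second row
print(x_pad.reshape(-1)[pad_offset:pad_offset + 4])  # [6 5 8 2]

# LoD layout: sequences are concatenated, so sequence `num` starts at lod[num].
x_lod = np.array([10, 3, 6, 5, 8, 2], dtype=np.int64)
lod = [0, 2, 6]
print(x_lod[lod[num]:lod[num + 1]])  # [6 5 8 2]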
@@ -5353,7 +5353,12 @@ def topk(input, k, name=None):
     return values, indices
 
-def edit_distance(input, label, normalized=True, ignored_tokens=None):
+def edit_distance(input,
+                  label,
+                  normalized=True,
+                  ignored_tokens=None,
+                  input_length=None,
+                  label_length=None):
     """
     Edit distance operator computes the edit distances between a batch of
     hypothesis strings and their references. Edit distance, also called
@@ -5367,52 +5372,49 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
         "kitten" -> "sitten" -> "sittin" -> "sitting"
 
-    The input is a LoDTensor consisting of all the hypothesis strings with
+    The input is a LoDTensor/Tensor consisting of all the hypothesis strings with
     the total number denoted by `batch_size`, and the separation is specified
-    by the LoD information. And the `batch_size` reference strings are arranged
-    in order in the same way in the input LoDTensor.
+    by the LoD information or input_length. And the `batch_size` reference strings are arranged
+    in order in the same way as `input`.
 
     The output contains the `batch_size` results and each stands for the edit
    distance for a pair of strings respectively. If Attr(normalized) is true,
    the edit distance will be divided by the length of reference string.
 
     Args:
-        input(Variable): The indices for hypothesis strings.
-        label(Variable): The indices for reference strings.
+        input(Variable): The indices for hypothesis strings, it should have rank 2 and dtype int64.
+        label(Variable): The indices for reference strings, it should have rank 2 and dtype int64.
         normalized(bool, default True): Indicated whether to normalize the edit distance by
                                         the length of reference string.
         ignored_tokens(list<int>, default None): Tokens that should be removed before
                                                  calculating edit distance.
-        name (str): The name of this layer. It is optional.
+        input_length(Variable): The length for each sequence in `input` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
+        label_length(Variable): The length for each sequence in `label` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64.
 
     Returns:
-        Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
+        edit_distance_out(Variable): edit distance result in shape [batch_size, 1]. \n
+        sequence_num(Variable): sequence number in shape [].
 
     Examples:
         .. code-block:: python
 
            import paddle.fluid as fluid
-           x = fluid.layers.data(name='x', shape=[1], dtype='int64')
-           y = fluid.layers.data(name='y', shape=[1], dtype='int64')
-           cost, _ = fluid.layers.edit_distance(input=x, label=y)
-           cpu = fluid.core.CPUPlace()
-           exe = fluid.Executor(cpu)
-           exe.run(fluid.default_startup_program())
-           import numpy
-           x_ = numpy.random.randint(5, size=(2, 1)).astype('int64')
-           y_ = numpy.random.randint(5, size=(2, 1)).astype('int64')
-           print(x_)
-           print(y_)
-           x = fluid.create_lod_tensor(x_, [[2]], cpu)
-           y = fluid.create_lod_tensor(y_, [[2]], cpu)
-           outs = exe.run(feed={'x':x, 'y':y}, fetch_list=[cost.name])
-           print(outs)
+           # using LoDTensor
+           x_lod = fluid.layers.data(name='x_lod', shape=[1], dtype='int64', lod_level=1)
+           y_lod = fluid.layers.data(name='y_lod', shape=[1], dtype='int64', lod_level=1)
+           distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod)
+           # using Tensor
+           x_seq_len = 5
+           y_seq_len = 6
+           x_pad = fluid.layers.data(name='x_pad', shape=[x_seq_len], dtype='int64')
+           y_pad = fluid.layers.data(name='y_pad', shape=[y_seq_len], dtype='int64')
+           x_len = fluid.layers.data(name='x_len', shape=[], dtype='int64')
+           y_len = fluid.layers.data(name='y_len', shape=[], dtype='int64')
+           distance_pad, seq_num_pad = fluid.layers.edit_distance(input=x_pad, label=y_pad, input_length=x_len, label_length=y_len)
     """
     helper = LayerHelper("edit_distance", **locals())
@@ -5435,13 +5437,17 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
             attrs={"tokens": ignored_tokens})
         label = erased_label
 
+    this_inputs = {"Hyps": [input], "Refs": [label]}
+    if input_length and label_length:
+        this_inputs['HypsLength'] = [input_length]
+        this_inputs['RefsLength'] = [label_length]
+
     # edit distance op
     edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
     sequence_num = helper.create_variable_for_type_inference(dtype="int64")
     helper.append_op(
         type="edit_distance",
-        inputs={"Hyps": [input],
-                "Refs": [label]},
+        inputs=this_inputs,
         outputs={"Out": [edit_distance_out],
                  "SequenceNum": [sequence_num]},
         attrs={"normalized": normalized})
...
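On the Python side the length inputs are dispensable: they are only wired into the op when both input_length and label_length are given. A plain-Python sketch of that dispatch, with strings standing in for the Variables and a hypothetical helper name:

def build_edit_distance_inputs(input, label, input_length=None, label_length=None):
    # mirrors the this_inputs construction in the diff above
    this_inputs = {"Hyps": [input], "Refs": [label]}
    if input_length and label_length:
        this_inputs['HypsLength'] = [input_length]
        this_inputs['RefsLength'] = [label_length]
    return this_inputs

print(build_edit_distance_inputs("x_lod", "y_lod"))
# {'Hyps': ['x_lod'], 'Refs': ['y_lod']}
print(build_edit_distance_inputs("x_pad", "y_pad", "x_len", "y_len"))
# {'Hyps': ['x_pad'], 'Refs': ['y_pad'], 'HypsLength': ['x_len'], 'RefsLength': ['y_len']}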
@@ -89,27 +89,31 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
     def reset_config(self):
         pass
 
+    def post_config(self):
+        pass
+
     def setUp(self):
         self.op_type = "edit_distance"
         normalized = True
-        x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64")
-        x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64")
-        x1 = np.transpose(x1)
-        x2 = np.transpose(x2)
+        self.x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64")
+        self.x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64")
         self.x1_lod = [3, 0, 3]
         self.x2_lod = [2, 1, 2]
+        self.x1 = np.transpose(self.x1)
+        self.x2 = np.transpose(self.x2)
         self.reset_config()
 
         num_strs = len(self.x1_lod)
         distance = np.zeros((num_strs, 1)).astype("float32")
-        sequence_num = np.array(3).astype("int64")
+        sequence_num = np.array(num_strs).astype("int64")
 
         x1_offset = 0
         x2_offset = 0
         for i in range(0, num_strs):
             distance[i] = Levenshtein(
-                hyp=x1[x1_offset:(x1_offset + self.x1_lod[i])],
-                ref=x2[x2_offset:(x2_offset + self.x2_lod[i])])
+                hyp=self.x1[x1_offset:(x1_offset + self.x1_lod[i])],
+                ref=self.x2[x2_offset:(x2_offset + self.x2_lod[i])])
             x1_offset += self.x1_lod[i]
             x2_offset += self.x2_lod[i]
             if normalized is True:
@@ -117,9 +121,14 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
                 distance[i] = distance[i] / len_ref
 
         self.attrs = {'normalized': normalized}
-        self.inputs = {'Hyps': (x1, [self.x1_lod]), 'Refs': (x2, [self.x2_lod])}
+        self.inputs = {
+            'Hyps': (self.x1, [self.x1_lod]),
+            'Refs': (self.x2, [self.x2_lod])
+        }
         self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
+        self.post_config()
 
     def test_check_output(self):
         self.check_output()
@@ -136,5 +145,43 @@ class TestEditDistanceOpNormalizedCase2(TestEditDistanceOpNormalizedCase0):
         self.x2_lod = [2, 2, 1]
 
+
+class TestEditDistanceOpNormalizedTensor(OpTest):
+    def reset_config(self):
+        self.x1 = np.array([[10, 3, 0, 0], [6, 5, 8, 2]], dtype=np.int64)
+        self.x2 = np.array([[10, 4, 0], [6, 7, 8]], dtype=np.int64)
+        self.x1_lod = np.array([2, 4], dtype=np.int64)
+        self.x2_lod = np.array([2, 3], dtype=np.int64)
+
+    def setUp(self):
+        self.op_type = "edit_distance"
+        normalized = True
+        self.reset_config()
+
+        num_strs = len(self.x1_lod)
+        distance = np.zeros((num_strs, 1)).astype("float32")
+        sequence_num = np.array(num_strs).astype("int64")
+
+        for i in range(0, num_strs):
+            distance[i] = Levenshtein(
+                hyp=self.x1[i][0:self.x1_lod[i]],
+                ref=self.x2[i][0:self.x2_lod[i]])
+            if normalized is True:
+                len_ref = self.x2_lod[i]
+                distance[i] = distance[i] / len_ref
+
+        self.attrs = {'normalized': normalized}
+        self.inputs = {
+            'Hyps': self.x1,
+            'Refs': self.x2,
+            'HypsLength': self.x1_lod,
+            'RefsLength': self.x2_lod
+        }
+        self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 if __name__ == '__main__':
     unittest.main()