diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 986026e0b9364076e777e4ba66c990fbecfa83d8..550b0e5b82609750ccd318eee889313cb2d7925a 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -500,6 +500,16 @@ swish .. autofunction:: paddle.v2.fluid.layers.swish :noindex: +edit_distance +--------------- +.. autofunction:: paddle.v2.fluid.layers.edit_distance_error + :noindex: + +ctc_greedy_decoder +--------------- +.. autofunction:: paddle.v2.fluid.layers.ctc_greedy_decoder + :noindex: + l2_normalize ------------ .. autofunction:: paddle.v2.fluid.layers.l2_normalize diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 6745a8da17723d663913a29f28e5ea9eedc0372a..15f7cb6b560590f55e276fde4900d2e3c0045fb8 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -156,6 +156,7 @@ op_library(parallel_do_op DEPS executor) # Regist multiple Kernel to pybind if (WITH_GPU) op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col) +op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function) op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling) op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc conv_transpose_cudnn_op.cu.cc DEPS vol2col) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 62a1fcebe7b7222ffceafc3ca2bc74e3998225f6..7e7dfc79eba5c9a75366415e5f4b3183653a5cc6 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -25,6 +25,8 @@ class EditDistanceOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"), + "Output(SequenceNum) shouldn't be null."); auto hyp_dims = ctx->GetInputDim("Hyps"); auto ref_dims = ctx->GetInputDim("Refs"); PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, @@ -34,6 +36,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " "equal to 1."); ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); + ctx->SetOutputDim("SequenceNum", {1}); } protected: @@ -54,6 +57,7 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Refs", "(2-D LoDTensor, 2nd dim. equal to 1) " "The indices for reference strings."); + AddOutput("SequenceNum", "The sequence count of current batch"); AddAttr("normalized", "(bool, default false) Indicated whether to normalize " "the edit distance by the length of reference string.") diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu index 338fd79bcc125b86c7764645c2fd8953d4477d2a..c3e116af086627d576c7a788caebd45667d70017 100644 --- a/paddle/operators/edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" #include "paddle/platform/cuda_helper.h" #include "paddle/platform/gpu_info.h" @@ -72,6 +73,8 @@ class EditDistanceGPUKernel : public framework::OpKernel { auto* x1_t = ctx.Input("Hyps"); auto* x2_t = ctx.Input("Refs"); + auto* sequence_num = ctx.Output("SequenceNum"); + sequence_num->mutable_data(ctx.GetPlace()); auto normalized = ctx.Attr("normalized"); auto stream = reinterpret_cast( @@ -88,7 +91,11 @@ class EditDistanceGPUKernel : public framework::OpKernel { "Reference string %d is empty.", i); } - auto num_strs = hyp_lod.size() - 1; + const size_t num_strs = hyp_lod.size() - 1; + math::SetConstant set_constant; + set_constant(ctx.template device_context(), + sequence_num, static_cast(num_strs)); + out_t->Resize({static_cast(num_strs), 1}); out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h index 4c5a29813ce39e42111c0ee5f3c16d5cefac4651..974299e604d22422e9024382e85c843ad831575d 100644 --- a/paddle/operators/edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" - namespace paddle { namespace operators { @@ -28,6 +27,8 @@ class EditDistanceKernel : public framework::OpKernel { auto* x1_t = ctx.Input("Hyps"); auto* x2_t = ctx.Input("Refs"); + auto* sequence_num = ctx.Output("SequenceNum"); + int64_t* seq_num_data = sequence_num->mutable_data(ctx.GetPlace()); auto normalized = ctx.Attr("normalized"); @@ -41,6 +42,7 @@ class EditDistanceKernel : public framework::OpKernel { "Reference string %d is empty.", i); } auto num_strs = hyp_lod.size() - 1; + *seq_num_data = static_cast(num_strs); out_t->Resize({static_cast(num_strs), 1}); out_t->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index 396d56fc8b236d95d38517f3513521aa969e47be..2686a5bdfcf0e2d26ce8f58cceff1967b06d835b 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -205,3 +205,63 @@ class ChunkEvaluator(Evaluator): [precision], dtype='float32'), np.array( [recall], dtype='float32'), np.array( [f1_score], dtype='float32') + + +class EditDistance(Evaluator): + """ + Accumulate edit distance sum and sequence number from mini-batches and + compute the average edit_distance of all batches. + + Args: + input: the sequences predicted by network. + label: the target sequences which must has same sequence count + with input. + ignored_tokens(list of int): Tokens that should be removed before + calculating edit distance. + + Example: + + exe = fluid.executor(place) + distance_evaluator = fluid.Evaluator.EditDistance(input, label) + for epoch in PASS_NUM: + distance_evaluator.reset(exe) + for data in batches: + loss, sum_distance = exe.run(fetch_list=[cost] + distance_evaluator.metrics) + avg_distance = distance_evaluator.eval(exe) + pass_distance = distance_evaluator.eval(exe) + + In the above example: + 'sum_distance' is the sum of the batch's edit distance. + 'avg_distance' is the average of edit distance from the firt batch to the current batch. + 'pass_distance' is the average of edit distance from all the pass. + + """ + + def __init__(self, input, label, ignored_tokens=None, **kwargs): + super(EditDistance, self).__init__("edit_distance", **kwargs) + main_program = self.helper.main_program + if main_program.current_block().idx != 0: + raise ValueError("You can only invoke Evaluator in root block") + + self.total_error = self.create_state( + dtype='float32', shape=[1], suffix='total_error') + self.seq_num = self.create_state( + dtype='int64', shape=[1], suffix='seq_num') + error, seq_num = layers.edit_distance( + input=input, label=label, ignored_tokens=ignored_tokens) + #error = layers.cast(x=error, dtype='float32') + sum_error = layers.reduce_sum(error) + layers.sums(input=[self.total_error, sum_error], out=self.total_error) + layers.sums(input=[self.seq_num, seq_num], out=self.seq_num) + self.metrics.append(sum_error) + + def eval(self, executor, eval_program=None): + if eval_program is None: + eval_program = Program() + block = eval_program.current_block() + with program_guard(main_program=eval_program): + total_error = _clone_var_(block, self.total_error) + seq_num = _clone_var_(block, self.seq_num) + seq_num = layers.cast(x=seq_num, dtype='float32') + out = layers.elementwise_div(x=total_error, y=seq_num) + return np.array(executor.run(eval_program, fetch_list=[out])[0]) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1db16a83ecc917528be1defa781f659342edd77..f0345512f5133573f3f946878af1939ad1d7fcd3 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -28,7 +28,8 @@ __all__ = [ 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'sequence_first_step', 'sequence_last_step', 'dropout', 'split', - 'l2_normalize', 'matmul', 'warpctc', 'sequence_reshape' + 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc', + 'sequence_reshape' ] @@ -1866,6 +1867,146 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): return out +def edit_distance(input, + label, + normalized=False, + ignored_tokens=None, + name=None): + """ + EditDistance operator computes the edit distances between a batch of hypothesis strings and their references. Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion: + + "kitten" -> "sitten" -> "sittin" -> "sitting" + + Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total number denoted by `batch_size`, and the separation is specified by the LoD information. And the `batch_size` reference strings are arranged in order in the same way in the LoDTensor Input(Refs). + + Output(Out) contains the `batch_size` results and each stands for the edit stance for a pair of strings respectively. If Attr(normalized) is true, the edit distance will be divided by the length of reference string. + + Args: + + input(Variable): The indices for hypothesis strings. + + label(Variable): The indices for reference strings. + + normalized(bool): Indicated whether to normalize the edit distance by the length of reference string. + + ignored_tokens(list of int): Tokens that should be removed before calculating edit distance. + + Returns: + Variable: sequence-to-sequence edit distance in shape [batch_size, 1]. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[8], dtype='float32') + y = fluid.layers.data(name='y', shape=[7], dtype='float32') + + cost = fluid.layers.edit_distance(input=x,label=y) + """ + helper = LayerHelper("edit_distance", **locals()) + + # remove some tokens from input and labels + if ignored_tokens is not None and len(ignored_tokens) > 0: + erased_input = helper.create_tmp_variable(dtype="int64") + erased_label = helper.create_tmp_variable(dtype="int64") + + helper.append_op( + type="sequence_erase", + inputs={"X": [input]}, + outputs={"Out": [erased_input]}, + attrs={"tokens": ignored_tokens}) + input = erased_input + + helper.append_op( + type="sequence_erase", + inputs={"X": [label]}, + outputs={"Out": [erase_label]}, + attrs={"tokens": ignored_tokens}) + label = erased_label + + # edit distance op + edit_distance_out = helper.create_tmp_variable(dtype="int64") + sequence_num = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="edit_distance", + inputs={"Hyps": [input], + "Refs": [label]}, + outputs={"Out": [edit_distance_out], + "SequenceNum": [sequence_num]}, + attrs={"normalized": normalized}) + + return edit_distance_out, sequence_num + + +def ctc_greedy_decoder(input, blank, name=None): + """ + This op is used to decode sequences by greedy policy by below steps: + 1. Get the indexes of max value for each row in input. a.k.a. numpy.argmax(input, axis=0). + 2. For each sequence in result of step1, merge repeated tokens between two blanks and delete all blanks. + + A simple example as below: + + .. code-block:: text + + Given: + + input.data = [[0.6, 0.1, 0.3, 0.1], + [0.3, 0.2, 0.4, 0.1], + [0.1, 0.5, 0.1, 0.3], + [0.5, 0.1, 0.3, 0.1], + + [0.5, 0.1, 0.3, 0.1], + [0.2, 0.2, 0.2, 0.4], + [0.2, 0.2, 0.1, 0.5], + [0.5, 0.1, 0.3, 0.1]] + + input.lod = [[0, 4, 8]] + + Then: + + output.data = [[2], + [1], + [3]] + + output.lod = [[0, 2, 3]] + + Args: + + input(Variable): (LoDTensor), the probabilities of variable-length sequences, which is a 2-D Tensor with LoD information. It's shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' length and num_classes is the true number of classes. (not including the blank label). + + blank(int): the blank label index of Connectionist Temporal Classification (CTC) loss, which is in thehalf-opened interval [0, num_classes + 1). + + Returns: + Variable: CTC greedy decode result. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[8], dtype='float32') + + cost = fluid.layers.ctc_greedy_decoder(input=x, blank=0) + """ + helper = LayerHelper("ctc_greedy_decoder", **locals()) + # top 1 op + topk_out = helper.create_tmp_variable(dtype=input.dtype) + topk_indices = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": 1}) + + # ctc align op + ctc_out = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="ctc_align", + inputs={"Input": [topk_indices]}, + outputs={"Output": [ctc_out]}, + attrs={"merge_repeated": True, + "blank": blank}) + return ctc_out + + def warpctc(input, label, blank=0, norm_by_times=False, **kwargs): """ An operator integrating the open source Warp-CTC library @@ -1890,7 +2031,7 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs): Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). norm_by_times: (bool, default: false), whether to normalize - the gradients by the number of time-step,which is also the + the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if warpctc layer was follewed by a mean_op. diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index 11cb85a151d1a4e213bcc52592d2f860f69b457f..bebdc5cba36fc96d31162d0d7d43e52064ca8e2d 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -61,6 +61,7 @@ class TestEditDistanceOp(OpTest): num_strs = len(x1_lod) - 1 distance = np.zeros((num_strs, 1)).astype("float32") + sequence_num = np.array(2).astype("int64") for i in range(0, num_strs): distance[i] = Levenshtein( hyp=x1[x1_lod[i]:x1_lod[i + 1]], @@ -70,7 +71,7 @@ class TestEditDistanceOp(OpTest): distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} - self.outputs = {'Out': distance} + self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): self.check_output() @@ -89,6 +90,7 @@ class TestEditDistanceOpNormalized(OpTest): num_strs = len(x1_lod) - 1 distance = np.zeros((num_strs, 1)).astype("float32") + sequence_num = np.array(3).astype("int64") for i in range(0, num_strs): distance[i] = Levenshtein( hyp=x1[x1_lod[i]:x1_lod[i + 1]], @@ -98,7 +100,7 @@ class TestEditDistanceOpNormalized(OpTest): distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} - self.outputs = {'Out': distance} + self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): self.check_output()