提交 0250e54c 编写于 作者: Y Yibing Liu

Enable batch input in edit_distance_op

上级 2e49faca
...@@ -22,10 +22,18 @@ class EditDistanceOp : public framework::OperatorWithKernel { ...@@ -22,10 +22,18 @@ class EditDistanceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Hyp"), "Input(Hyp) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
ctx->SetOutputDim("Out", {1}); auto hyp_dims = ctx->GetInputDim("Hyps");
auto ref_dims = ctx->GetInputDim("Refs");
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1.");
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
} }
protected: protected:
...@@ -40,24 +48,23 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -40,24 +48,23 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyp", AddInput("Hyps",
"(2-D tensor with shape [M x 1]) The indices for " "(2-D LoDTensor, 2nd dim. equal to 1) "
"hypothesis string"); "The indices for hypothesis strings.");
AddInput("Ref", AddInput("Refs",
"(2-D tensor with shape [N x 1]) The indices " "(2-D LoDTensor, 2nd dim. equal to 1) "
"for reference string."); "The indices for reference strings.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) Indicated whether " "(bool, default false) Indicated whether to normalize "
"normalize the Output(Out) by the length of reference " "the edit distance by the length of reference string.")
"string (Ref).")
.SetDefault(false); .SetDefault(false);
AddOutput("Out", AddOutput("Out",
"(2-D tensor with shape [1 x 1]) " "(2-D Tensor with shape [`batch_size` x 1]) "
"The output distance of EditDistance operator."); "The output edit distances of EditDistance operator.");
AddComment(R"DOC( AddComment(R"DOC(
EditDistance operator computes the edit distance of two sequences, one named EditDistance operator computes the edit distances between a batch of hypothesis
hypothesis with length M and another named reference with length N. strings and their references.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into another. are by counting the minimum number of operations to transform one string into another.
...@@ -68,8 +75,14 @@ insertion: ...@@ -68,8 +75,14 @@ insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting" "kitten" -> "sitten" -> "sittin" -> "sitting"
If Attr(normalized) is true, the edit distance will be divided by the length of Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
reference string N. number denoted by `batch_size`, and the separation is specified by the LoD information.
And the `batch_size` reference strings are arranged in order in the same way in the
LoDTensor Input(Refs).
Output(Out) contains the `batch_size` results and each stands for the edit distance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of reference string.
)DOC"); )DOC");
} }
}; };
......
...@@ -70,53 +70,71 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -70,53 +70,71 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
auto* x1_t = ctx.Input<framework::Tensor>("Hyp"); auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
auto* x2_t = ctx.Input<framework::Tensor>("Ref"); auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
out_t->mutable_data<T>(ctx.GetPlace());
auto out = out_t->data<T>();
auto normalized = ctx.Attr<bool>("normalized"); auto normalized = ctx.Attr<bool>("normalized");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
.stream(); .stream();
auto m = x1_t->numel(); auto hyp_lod = x1_t->lod()[0];
auto n = x2_t->numel(); auto ref_lod = x2_t->lod()[0];
T distance = 0.0; PADDLE_ENFORCE(
if (m == 0 || n == 0) { hyp_lod.size() == ref_lod.size(),
distance = std::max(m, n); "Input(Hyps) and Input(Refs) must have the same batch size.");
if (normalized) { for (size_t i = 1; i < ref_lod.size(); ++i) {
distance = distance / n; PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
} "Reference string %d is empty.", i);
memory::Copy(boost::get<Place>(ctx.GetPlace()), out, platform::CPUPlace(), }
&distance, sizeof(T), stream);
} else { auto num_strs = hyp_lod.size() - 1;
framework::Tensor dist_t; out_t->Resize({static_cast<int64_t>(num_strs), 1});
dist_t.Resize({m + 1, n + 1}); out_t->mutable_data<T>(ctx.GetPlace());
dist_t.mutable_data<T>(ctx.GetPlace()); auto out = out_t->data<T>();
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>(); std::vector<T> distance(num_strs, 0.0);
auto x2 = x2_t->data<int>(); for (size_t num = 0; num < num_strs; num++) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); if (m == 0 || n == 0) {
distance[num] = std::max(m, n);
FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS, if (normalized) {
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); PADDLE_ENFORCE(n > 0,
// Compute the elements of distance matrix in the anti-diagonal diretion "The reference string (#%d) cannot be empty "
for (int64_t slice = 2; slice < m + n + 1; ++slice) { "when Attr(normalized) is enabled.",
int z_m = slice < m + 1 ? 0 : slice - m; n);
int z_n = slice < n + 1 ? 0 : slice - n; distance[num] = distance[num] / n;
int size = slice - (z_m + z_n) + 1; // number of elments in the same }
// anti-diagonal line to update memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
// the start index at which computes from platform::CPUPlace(), &distance[num], sizeof(T), stream);
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; } else {
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, framework::Tensor dist_t;
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, dist_t.Resize({m + 1, n + 1});
n, start); dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>() + hyp_lod[num];
auto x2 = x2_t->data<int>() + ref_lod[num];
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
// Compute the elements of distance matrix in the anti-diagonal direction
for (int64_t slice = 2; slice < m + n + 1; ++slice) {
int z_m = slice < m + 1 ? 0 : slice - m;
int z_n = slice < n + 1 ? 0 : slice - n;
int size = slice - (z_m + z_n) + 1; // number of elments in the same
// anti-diagonal line to update
// the start index at which computes from
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2,
m, n, start);
}
SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized);
} }
SetOutput<T><<<1, 1, 0, stream>>>(out, dist, m, n, normalized);
} }
} }
}; };
......
...@@ -26,50 +26,69 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -26,50 +26,69 @@ class EditDistanceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
auto* x1_t = ctx.Input<framework::Tensor>("Hyp"); auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
auto* x2_t = ctx.Input<framework::Tensor>("Ref"); auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
auto normalized = ctx.Attr<bool>("normalized");
auto hyp_lod = x1_t->lod()[0];
auto ref_lod = x2_t->lod()[0];
PADDLE_ENFORCE(
hyp_lod.size() == ref_lod.size(),
"Input(Hyps) and Input(Refs) must have the same batch size.");
for (size_t i = 1; i < ref_lod.size(); ++i) {
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
"Reference string %d is empty.", i);
}
auto num_strs = hyp_lod.size() - 1;
out_t->Resize({static_cast<int64_t>(num_strs), 1});
out_t->mutable_data<float>(ctx.GetPlace()); out_t->mutable_data<float>(ctx.GetPlace());
auto out = out_t->data<T>();
auto normalized = ctx.Attr<bool>("normalized"); std::vector<T> distance(num_strs, 0.0);
for (size_t num = 0; num < num_strs; ++num) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
auto m = x1_t->numel(); if (m == 0) {
auto n = x2_t->numel(); distance[num] = n;
T distance = 0.0; } else if (n == 0) {
if (m == 0) { distance[num] = m;
distance = n; } else {
} else if (n == 0) { framework::Tensor dist_t;
distance = m; dist_t.Resize({m + 1, n + 1});
} else { dist_t.mutable_data<T>(ctx.GetPlace());
framework::Tensor dist_t; auto dist = dist_t.data<T>();
dist_t.Resize({m + 1, n + 1}); auto x1 = x1_t->data<int>() + hyp_lod[num];
dist_t.mutable_data<T>(ctx.GetPlace()); auto x2 = x2_t->data<int>() + ref_lod[num];
auto dist = dist_t.data<T>(); for (int64_t i = 0; i < m + 1; ++i) {
auto x1 = x1_t->data<int>(); dist[i * (n + 1)] = i;
auto x2 = x2_t->data<int>(); }
for (int64_t i = 0; i < m + 1; ++i) { for (int64_t j = 0; j < n + 1; ++j) {
dist[i * (n + 1)] = i; dist[j] = j;
} }
for (int64_t j = 0; j < n + 1; ++j) { for (int64_t i = 1; i < m + 1; ++i) {
dist[j] = j; for (int64_t j = 1; j < n + 1; ++j) {
} int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
for (int64_t i = 1; i < m + 1; ++i) { int dels = dist[(i - 1) * (n + 1) + j] + 1;
for (int64_t j = 1; j < n + 1; ++j) { int ins = dist[i * (n + 1) + (j - 1)] + 1;
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
int dels = dist[(i - 1) * (n + 1) + j] + 1; dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
int ins = dist[i * (n + 1) + (j - 1)] + 1; }
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
} }
distance[num] = dist[m * (n + 1) + n];
} }
distance = dist[m * (n + 1) + n];
}
if (normalized) { if (normalized) {
distance = distance / n; PADDLE_ENFORCE(n > 0,
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.",
n);
distance[num] = distance[num] / n;
}
out[num] = distance[num];
} }
auto out = out_t->data<T>();
out[0] = distance;
} }
}; };
......
...@@ -18,7 +18,7 @@ def Levenshtein(hyp, ref): ...@@ -18,7 +18,7 @@ def Levenshtein(hyp, ref):
if n == 0: if n == 0:
return m return m
dist = np.zeros((m + 1, n + 1)) dist = np.zeros((m + 1, n + 1)).astype("float32")
for i in range(0, m + 1): for i in range(0, m + 1):
dist[i][0] = i dist[i][0] = i
for j in range(0, n + 1): for j in range(0, n + 1):
...@@ -35,17 +35,55 @@ def Levenshtein(hyp, ref): ...@@ -35,17 +35,55 @@ def Levenshtein(hyp, ref):
class TestCTCEditDistanceOp(OpTest): class TestCTCEditDistanceOp(OpTest):
def setUp(self):
self.op_type = "edit_distance"
normalized = False
x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32")
x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32")
x1 = np.transpose(x1)
x2 = np.transpose(x2)
x1_lod = [0, 1, 5]
x2_lod = [0, 3, 4]
num_strs = len(x1_lod) - 1
distance = np.zeros((num_strs, 1)).astype("float32")
for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
ref=x2[x2_lod[i]:x2_lod[i + 1]])
if normalized is True:
len_ref = x2_lod[i + 1] - x2_lod[i]
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized}
self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
self.outputs = {'Out': distance}
def test_check_output(self):
self.check_output()
class TestCTCEditDistanceOpNormalized(OpTest):
def setUp(self): def setUp(self):
self.op_type = "edit_distance" self.op_type = "edit_distance"
normalized = True normalized = True
x1 = np.array([0, 12, 3, 5]).astype("int32") x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32")
x2 = np.array([0, 12, 4, 7, 8]).astype("int32") x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32")
x1 = np.transpose(x1)
x2 = np.transpose(x2)
x1_lod = [0, 1, 3, 6]
x2_lod = [0, 2, 3, 5]
distance = Levenshtein(hyp=x1, ref=x2) num_strs = len(x1_lod) - 1
if normalized is True: distance = np.zeros((num_strs, 1)).astype("float32")
distance = distance / len(x2) for i in range(0, num_strs):
distance[i] = Levenshtein(
hyp=x1[x1_lod[i]:x1_lod[i + 1]],
ref=x2[x2_lod[i]:x2_lod[i + 1]])
if normalized is True:
len_ref = x2_lod[i + 1] - x2_lod[i]
distance[i] = distance[i] / len_ref
self.attrs = {'normalized': normalized} self.attrs = {'normalized': normalized}
self.inputs = {'Hyp': x1, 'Ref': x2} self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
self.outputs = {'Out': distance} self.outputs = {'Out': distance}
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册