提交 2c1adb06 编写于 作者: Y Yibing Liu

Rename ctc_edit_distance_op to edit_distance_op

上级 36ec3e90
......@@ -12,12 +12,12 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_edit_distance_op.h"
#include "paddle/operators/edit_distance_op.h"
namespace paddle {
namespace operators {
class CTCEditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -29,17 +29,16 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::DataType::FP32,
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
}
};
class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCEditDistanceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1",
"(2-D tensor with shape [M x 1]) The indices for "
......@@ -54,10 +53,10 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false);
AddOutput("Out",
"(2-D tensor with shape [1 x 1]) "
"The output distance of CTCEditDistance operator.");
"The output distance of EditDistance operator.");
AddComment(R"DOC(
CTCEditDistance operator computes the edit distance of two sequences, one named
EditDistance operator computes the edit distance of two sequences, one named
hypothesis with length M and another named reference with length N.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
......@@ -80,8 +79,7 @@ reference string N.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
ops::CTCEditDistanceOpMaker);
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_edit_distance,
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, float>);
edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);
......@@ -65,7 +65,7 @@ __global__ void SetOutput(T* out, const T* dist, const int M, const int N,
}
template <typename Place, typename T>
class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
class EditDistanceGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");
......@@ -110,8 +110,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
int z_n = slice < n + 1 ? 0 : slice - n;
int size = slice - (z_m + z_n) + 1; // number of elments in the same
// anti-diagonal line to update
int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index
// the start index at which computes from
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m,
n, start);
......@@ -126,6 +126,6 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
ctc_edit_distance,
ops::CTCEditDistanceGPUKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
edit_distance,
ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>);
......@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class CTCEditDistanceKernel : public framework::OpKernel<T> {
class EditDistanceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");
......
......@@ -36,7 +36,7 @@ def Levenshtein(hyp, ref):
class TestCTCEditDistanceOp(OpTest):
def setUp(self):
self.op_type = "ctc_edit_distance"
self.op_type = "edit_distance"
normalized = True
x1 = np.array([0, 12, 3, 5]).astype("int32")
x2 = np.array([0, 12, 4, 7, 8]).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册