From 116687a8ee8dab5938f8783428b4b5f416a443f5 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 29 Nov 2017 09:07:15 +0000 Subject: [PATCH] clean up code in ctc_edit_distance_op --- paddle/operators/ctc_edit_distance_op.cc | 10 +++- paddle/operators/ctc_edit_distance_op.cu | 59 ++++++++++--------- paddle/operators/ctc_edit_distance_op.h | 16 ++--- .../fluid/tests/test_ctc_edit_distance_op.py | 8 +-- 4 files changed, 49 insertions(+), 44 deletions(-) diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/ctc_edit_distance_op.cc index fae5cfc117..d2f4ce67c2 100644 --- a/paddle/operators/ctc_edit_distance_op.cc +++ b/paddle/operators/ctc_edit_distance_op.cc @@ -27,6 +27,13 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); ctx->SetOutputDim("Out", {1}); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(framework::DataType::FP32, + ctx.device_context()); + } }; class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { @@ -70,5 +77,4 @@ REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, ops::CTCEditDistanceOpMaker); REGISTER_OP_CPU_KERNEL( ctc_edit_distance, - ops::CTCEditDistanceKernel, - ops::CTCEditDistanceKernel); + ops::CTCEditDistanceKernel); diff --git a/paddle/operators/ctc_edit_distance_op.cu b/paddle/operators/ctc_edit_distance_op.cu index 872268296e..22871acc4e 100644 --- a/paddle/operators/ctc_edit_distance_op.cu +++ b/paddle/operators/ctc_edit_distance_op.cu @@ -39,7 +39,7 @@ __global__ void FillFirstColumn(T* dist, const int M, const int N) { } template -__global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M, +__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, const int N, const int start) { int idx = blockDim.x * blockIdx.x + threadIdx.x; int offset = N; @@ -55,6 +55,15 @@ __global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M, } } +template +__global__ void SetOutput(T* out, const T* dist, const int M, const int N, + bool normalized) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx == 0) { + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; + } +} + template class CTCEditDistanceGPUKernel : public framework::OpKernel { public: @@ -64,7 +73,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { auto* x1_t = ctx.Input("X1"); auto* x2_t = ctx.Input("X2"); - out_t->mutable_data(ctx.GetPlace()); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); auto normalized = ctx.Attr("normalized"); auto stream = reinterpret_cast( @@ -73,49 +83,41 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { auto m = x1_t->numel(); auto n = x2_t->numel(); - T distance = 0; - if (m == 0) { - distance = n; - } else if (n == 0) { - distance = m; + T distance = 0.0; + if (m == 0 || n == 0) { + distance = std::max(m, n); + if (normalized) { + distance = distance / n; + } + memory::Copy(boost::get(ctx.GetPlace()), out, platform::CPUPlace(), + &distance, sizeof(T), stream); } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); - // compute the elements of distance matrix in the anti-diagonal diretion - for (size_t slice = 2; slice < m + n + 1; ++slice) { + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { int z_m = slice < m + 1 ? 0 : slice - m; int z_n = slice < n + 1 ? 0 : slice - n; - // number of elments in the same anti-diagonal line - int size = slice - (z_m + z_n) + 1; - int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, n, start); } - - Place gpu_place = boost::get(ctx.GetPlace()); - memory::Copy(platform::CPUPlace(), &distance, gpu_place, - dist + m * (n + 1) + n, sizeof(T), stream); - } - - if (normalized) { - distance = distance / n; + SetOutput<<<1, 1, 0, stream>>>(out, dist, m, n, normalized); } - auto out = out_t->data(); - Place gpu_place = boost::get(ctx.GetPlace()); - float dist_f = distance; - memory::Copy(gpu_place, out, platform::CPUPlace(), &dist_f, sizeof(float), - stream); } }; @@ -126,5 +128,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( ctc_edit_distance, - ops::CTCEditDistanceGPUKernel, - ops::CTCEditDistanceGPUKernel); + ops::CTCEditDistanceGPUKernel); diff --git a/paddle/operators/ctc_edit_distance_op.h b/paddle/operators/ctc_edit_distance_op.h index a52960f1ef..08f29cf24a 100644 --- a/paddle/operators/ctc_edit_distance_op.h +++ b/paddle/operators/ctc_edit_distance_op.h @@ -35,7 +35,7 @@ class CTCEditDistanceKernel : public framework::OpKernel { auto m = x1_t->numel(); auto n = x2_t->numel(); - float distance = 0.0; + T distance = 0.0; if (m == 0) { distance = n; } else if (n == 0) { @@ -45,16 +45,16 @@ class CTCEditDistanceKernel : public framework::OpKernel { dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); - for (size_t i = 0; i < m + 1; ++i) { + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); + for (int64_t i = 0; i < m + 1; ++i) { dist[i * (n + 1)] = i; } - for (size_t j = 0; j < n + 1; ++j) { + for (int64_t j = 0; j < n + 1; ++j) { dist[j] = j; } - for (size_t i = 1; i < m + 1; ++i) { - for (size_t j = 1; j < n + 1; ++j) { + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; int dels = dist[(i - 1) * (n + 1) + j] + 1; int ins = dist[i * (n + 1) + (j - 1)] + 1; @@ -68,7 +68,7 @@ class CTCEditDistanceKernel : public framework::OpKernel { if (normalized) { distance = distance / n; } - auto out = out_t->data(); + auto out = out_t->data(); out[0] = distance; } }; diff --git a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py index 6694a6ee29..62c233b34f 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py @@ -37,11 +37,9 @@ def Levenshtein(hyp, ref): class TestCTCEditDistanceOp(OpTest): def setUp(self): self.op_type = "ctc_edit_distance" - normalized = False - #x1 = np.array([0, 12, 3, 5]).astype("int64") - #x2 = np.array([0, 12, 4, 7, 8]).astype("int64") - x1 = np.array([0, 12, 5]).astype("int64") - x2 = np.array([0, 12, 4]).astype("int64") + normalized = True + x1 = np.array([0, 12, 3, 5]).astype("int32") + x2 = np.array([0, 12, 4, 7, 8]).astype("int32") distance = Levenshtein(hyp=x1, ref=x2) if normalized is True: -- GitLab