From 8f6c8780a52b3e0a6df85f6d9e3e98366a381692 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Sat, 19 Aug 2017 17:08:04 +0800 Subject: [PATCH] Replace functor by function. --- paddle/operators/cross_entropy_op.cu | 25 +++++++++---------- paddle/operators/cross_entropy_op.h | 2 +- .../paddle/v2/framework/tests/op_test_util.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 5f5d269267..d999bfce58 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -21,19 +21,18 @@ namespace operators { using Tensor = framework::Tensor; template -struct clipping_log { - __host__ __device__ T operator()(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - if (x == INFINITY) { - return kApproInf; - } - if (x == -INFINITY) { - return -kApproInf; - } - return x; +__host__ __device__ T clipping_log(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + T v = log(x); + if (v == INFINITY) { + return kApproInf; } -}; + if (v == -INFINITY) { + return -kApproInf; + } + return v; +} template __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, @@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -clipping_log()(X[i * D + label[i]]); + Y[i] = -clipping_log(X[i * D + label[i]]); } } diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index e95f5e1167..eb4d1348de 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -T tolerable_value(const T x) { +inline T tolerable_value(const T x) { static_assert(std::is_floating_point::value, "tolerable_value works only on float, " "double and double double."); diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index ae23108dfa..3bc05a0fec 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -65,7 +65,7 @@ class OpTestMeta(type): expect = self.outputs[out_name] self.assertTrue( numpy.allclose( - actual, expect, atol=1e-04), + actual, expect, atol=1e-05), "output name: " + out_name + "has diff") obj.test_all = test_all -- GitLab