提交 8f6c8780 编写于 作者: D dangqingqing

Replace functor by function.

上级 70285cce
...@@ -21,19 +21,18 @@ namespace operators { ...@@ -21,19 +21,18 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
struct clipping_log { __host__ __device__ T clipping_log(const T x) {
__host__ __device__ T operator()(const T x) { PADDLE_ASSERT(std::is_floating_point<T>::value);
PADDLE_ASSERT(std::is_floating_point<T>::value); const T kApproInf = 1e20;
const T kApproInf = 1e20; T v = log(x);
if (x == INFINITY) { if (v == INFINITY) {
return kApproInf; return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
} }
}; if (v == -INFINITY) {
return -kApproInf;
}
return v;
}
template <typename T> template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
...@@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, ...@@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D); PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -clipping_log<T>()(X[i * D + label[i]]); Y[i] = -clipping_log(X[i * D + label[i]]);
} }
} }
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
T tolerable_value(const T x) { inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value, static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, " "tolerable_value works only on float, "
"double and double double."); "double and double double.");
......
...@@ -65,7 +65,7 @@ class OpTestMeta(type): ...@@ -65,7 +65,7 @@ class OpTestMeta(type):
expect = self.outputs[out_name] expect = self.outputs[out_name]
self.assertTrue( self.assertTrue(
numpy.allclose( numpy.allclose(
actual, expect, atol=1e-04), actual, expect, atol=1e-05),
"output name: " + out_name + "has diff") "output name: " + out_name + "has diff")
obj.test_all = test_all obj.test_all = test_all
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册