提交 8f6c8780 编写于 作者: D dangqingqing

Replace functor by function.

上级 70285cce
......@@ -21,19 +21,18 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct clipping_log {
__host__ __device__ T operator()(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
__host__ __device__ T clipping_log(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
T v = log(x);
if (v == INFINITY) {
return kApproInf;
}
};
if (v == -INFINITY) {
return -kApproInf;
}
return v;
}
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
......@@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -clipping_log<T>()(X[i * D + label[i]]);
Y[i] = -clipping_log(X[i * D + label[i]]);
}
}
......
......@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
T tolerable_value(const T x) {
inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");
......
......@@ -65,7 +65,7 @@ class OpTestMeta(type):
expect = self.outputs[out_name]
self.assertTrue(
numpy.allclose(
actual, expect, atol=1e-04),
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
obj.test_all = test_all
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册