提交 aa3de357 编写于 作者: Y Yu Yang 提交者: GitHub

Polish unit test for xe, generate probablities (#5096)

* Cross Entropy Wrong

* Fix XE

* Polish gradient check for xe

* Fix compile
上级 efc2464f
...@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`. ...@@ -162,6 +162,8 @@ or not. But the output only shares the LoD with input `X`.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
cross_entropy_grad, ops::CrossEntropyGradientOp); cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>); REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
ops::CrossEntropyOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad, REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<float>); ops::CrossEntropyGradientOpKernel<float>,
ops::CrossEntropyGradientOpKernel<double>);
...@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -108,6 +108,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>); REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad, REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>); ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
...@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> { ...@@ -54,6 +54,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
}; };
template class CrossEntropyFunctor<platform::CPUPlace, float>; template class CrossEntropyFunctor<platform::CPUPlace, float>;
template class CrossEntropyFunctor<platform::CPUPlace, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -39,11 +39,36 @@ __device__ __forceinline__ T sum_single_warp(T val) { ...@@ -39,11 +39,36 @@ __device__ __forceinline__ T sum_single_warp(T val) {
return val; return val;
} }
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ T* GetPointer() { return NULL; }
};
template <>
struct SharedMemory<float> {
__device__ float* GetPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ double* GetPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T> template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) { const int class_num) {
int tid = threadIdx.x; int tid = threadIdx.x;
extern __shared__ T d_sum[]; SharedMemory<T> d_sum_shared;
T* d_sum = d_sum_shared.GetPointer();
d_sum[tid] = 0; d_sum[tid] = 0;
int cur_idx = tid; int cur_idx = tid;
...@@ -102,6 +127,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> { ...@@ -102,6 +127,7 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
}; };
template class CrossEntropyFunctor<platform::GPUPlace, float>; template class CrossEntropyFunctor<platform::GPUPlace, float>;
template class CrossEntropyFunctor<platform::GPUPlace, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -8,6 +8,15 @@ from paddle.v2.framework.executor import Executor ...@@ -8,6 +8,15 @@ from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder from paddle.v2.framework.framework import Program, OpProtoHolder
def randomize_probability(batch_size, class_num, dtype='float32'):
prob = np.random.uniform(
0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
prob_sum = prob.sum(axis=1)
for i in xrange(len(prob)):
prob[i] /= prob_sum[i]
return prob
def grad_var_name(var_name): def grad_var_name(var_name):
return var_name + "@GRAD" return var_name + "@GRAD"
...@@ -233,7 +242,7 @@ def append_input_output(block, op_proto, np_list, is_input): ...@@ -233,7 +242,7 @@ def append_input_output(block, op_proto, np_list, is_input):
if (var_name not in np_list) and var_proto.dispensable: if (var_name not in np_list) and var_proto.dispensable:
continue continue
assert (var_name in np_list) or (var_proto.dispensable), \ assert (var_name in np_list) or (var_proto.dispensable), \
"Missing {} as input".format(var_name) "Missing {} as input".format(var_name)
if var_proto.duplicable: if var_proto.duplicable:
assert isinstance(np_list[var_name], list), \ assert isinstance(np_list[var_name], list), \
"Duplicable {} should be set as list".format(var_name) "Duplicable {} should be set as list".format(var_name)
...@@ -379,9 +388,9 @@ class OpTest(unittest.TestCase): ...@@ -379,9 +388,9 @@ class OpTest(unittest.TestCase):
def err_msg(): def err_msg():
offset = np.argmax(diff_mat > max_relative_error) offset = np.argmax(diff_mat > max_relative_error)
return ("%s Variable %s max gradient diff %f over limit %f, " return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d") % ( "the first error element is %d, %f, %f") % (
msg_prefix, name, max_diff, max_relative_error, msg_prefix, name, max_diff, max_relative_error,
offset) offset, a.flatten()[offset], b.flatten()[offset])
self.assertLessEqual(max_diff, max_relative_error, err_msg()) self.assertLessEqual(max_diff, max_relative_error, err_msg())
...@@ -389,6 +398,7 @@ class OpTest(unittest.TestCase): ...@@ -389,6 +398,7 @@ class OpTest(unittest.TestCase):
inputs_to_check, inputs_to_check,
output_names, output_names,
no_grad_set=None, no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None):
...@@ -411,6 +421,7 @@ class OpTest(unittest.TestCase): ...@@ -411,6 +421,7 @@ class OpTest(unittest.TestCase):
self.inputs, self.inputs,
input_to_check, input_to_check,
output_names, output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check in_place=in_place) for input_to_check in inputs_to_check
] ]
grad_names = [ grad_names = [
......
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, randomize_probability
class TestCrossEntropyOp1(OpTest): class TestCrossEntropyOp1(OpTest):
...@@ -12,12 +12,12 @@ class TestCrossEntropyOp1(OpTest): ...@@ -12,12 +12,12 @@ class TestCrossEntropyOp1(OpTest):
batch_size = 30 batch_size = 30
class_num = 10 class_num = 10
X = np.random.uniform(0.1, 1.0, X = randomize_probability(batch_size, class_num, dtype='float64')
[batch_size, class_num]).astype("float32")
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix( cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32") dtype="float64")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
...@@ -27,7 +27,7 @@ class TestCrossEntropyOp1(OpTest): ...@@ -27,7 +27,7 @@ class TestCrossEntropyOp1(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Y") self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
class TestCrossEntropyOp2(OpTest): class TestCrossEntropyOp2(OpTest):
...@@ -39,8 +39,7 @@ class TestCrossEntropyOp2(OpTest): ...@@ -39,8 +39,7 @@ class TestCrossEntropyOp2(OpTest):
batch_size = 5 batch_size = 5
class_num = 37 class_num = 37
X = np.random.uniform(0.1, 1.0, X = randomize_probability(batch_size, class_num)
[batch_size, class_num]).astype("float32")
label = np.random.uniform(0.1, 1.0, label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True) label /= label.sum(axis=1, keepdims=True)
...@@ -55,7 +54,8 @@ class TestCrossEntropyOp2(OpTest): ...@@ -55,7 +54,8 @@ class TestCrossEntropyOp2(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05) self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
class TestCrossEntropyOp3(OpTest): class TestCrossEntropyOp3(OpTest):
...@@ -67,8 +67,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -67,8 +67,7 @@ class TestCrossEntropyOp3(OpTest):
batch_size = 5 batch_size = 5
class_num = 17 class_num = 17
X = np.random.uniform(0.1, 1.0, X = randomize_probability(batch_size, class_num)
[batch_size, class_num]).astype("float32")
label_index = np.random.randint( label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32") 0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape) label = np.zeros(X.shape)
...@@ -88,7 +87,8 @@ class TestCrossEntropyOp3(OpTest): ...@@ -88,7 +87,8 @@ class TestCrossEntropyOp3(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05) self.check_grad(
["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册