From 52d2ebdaef66f980c8ecb4878d41da6b44467115 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 8 Aug 2017 16:40:21 +0800 Subject: [PATCH] "test gaussian random in python side" --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 2 ++ .../tests/test_gaussian_random_op.py | 33 +++++++++++++++++++ .../v2/framework/tests/test_random_op.py | 29 ---------------- 5 files changed, 37 insertions(+), 29 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_gaussian_random_op.py delete mode 100644 python/paddle/v2/framework/tests/test_random_op.py diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1db042c6f..9b96a5918 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -43,4 +43,5 @@ cc_library(paddle_pybind SHARED add_op mean_op cross_entropy_op + gaussian_random_op recurrent_op) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index cbb86c419..85548e3e9 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -41,6 +41,7 @@ USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); USE_OP_WITHOUT_KERNEL(recurrent_op); +USE_OP(gaussian_random); namespace paddle { namespace framework { template diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 7eec37678..5a8998411 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -21,3 +21,5 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_op_creation_methods SRCS test_op_creation_methods.py) + +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py new file mode 100644 index 000000000..020e69fe1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -0,0 +1,33 @@ +import unittest +import paddle.v2.framework.core as core +import paddle.v2.framework.op as Operator +import numpy + + +class GaussianRandomTest(unittest.TestCase): + def test_cpu(self): + self.test_gaussian_random(place=core.CPUPlace()) + + def test_gpu(self): + self.test_gaussian_random(place=core.GPUPlace(0)) + + def test_gaussian_random(self, place): + scope = core.Scope() + scope.new_var("Out").get_tensor() + op = Operator( + "gaussian_random", + Out="Out", + dims=[1000, 784], + mean=.0, + std=1., + seed=0) + op.infer_shape(scope) + context = core.DeviceContext.create(place) + op.run(scope, context) + tensor = numpy.array(scope.find_var("Out").get_tensor()) + self.assertAlmostEqual(numpy.mean(tensor), .0, places=3) + self.assertAlmostEqual(numpy.std(tensor), 1., places=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_random_op.py b/python/paddle/v2/framework/tests/test_random_op.py deleted file mode 100644 index d3474880d..000000000 --- a/python/paddle/v2/framework/tests/test_random_op.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest -import paddle.v2.framework.create_op_creation_methods as creation -import paddle.v2.framework.core as core -from op_test_util import OpTestMeta -import numpy - - -class TestRandomOp(unittest.TestCase): - def test_random(self): - scope = core.Scope(None) - # Out = scope.create_var("Out") - op = creation.op_creations.gaussian_random( - shape=[1000, 1000], mean=5.0, std=1.0, Out="Out") - for out in op.outputs(): - if scope.get_var(out) is None: - scope.create_var(out).get_tensor() - - tensor = scope.get_var("Out").get_tensor() - op.infer_shape(scope) - self.assertEqual([1000, 1000], tensor.shape()) - ctx = core.DeviceContext.cpu_context() - op.run(scope, ctx) - tensor_array = numpy.array(tensor) - self.assertAlmostEqual(numpy.mean(tensor_array), 5.0, places=3) - self.assertAlmostEqual(numpy.std(tensor_array), 1.0, places=3) - - -if __name__ == '__main__': - unittest.main() -- GitLab