From 1e676f684d58cfce90f194f85c422306543543da Mon Sep 17 00:00:00 2001 From: liaogang Date: Tue, 1 Aug 2017 16:10:52 +0800 Subject: [PATCH] Add mean op unit test in python --- paddle/operators/mean_op.cu | 5 +++-- paddle/operators/mean_op.h | 4 ++-- paddle/pybind/CMakeLists.txt | 11 +++++++++-- paddle/pybind/pybind.cc | 1 + python/paddle/v2/framework/tests/CMakeLists.txt | 1 + python/paddle/v2/framework/tests/test_mean_op.py | 16 ++++++++++++++++ 6 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_mean_op.py diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu index 4dbb566b1d7..740157cbc57 100644 --- a/paddle/operators/mean_op.cu +++ b/paddle/operators/mean_op.cu @@ -1,4 +1,5 @@ -#include "paddle/framework/op_registry.h" +#define EIGEN_USE_GPU + #include "paddle/operators/mean_op.h" -REGISTER_OP_GPU_KERNEL(mean, ops::AddKernel); +REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 21fa5796438..483b3eb6015 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -26,8 +26,8 @@ public: auto output = context.Output(0)->GetMutable(); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = + + EigenScalar::From(*output).device(*(context.GetEigenDevice())) = EigenVector::Flatten(input).mean(); } }; diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 7d0e68a8f30..845589dcb19 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,9 @@ -cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op fc_op sgd_op cross_entropy_op recurrent_network_op) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_network_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 08a8bd0d8b6..4fa481bedf5 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -33,6 +33,7 @@ USE_OP(onehot_cross_entropy); USE_OP_WITHOUT_KERNEL(fc); USE_OP(sgd); USE_OP(mul); +USE_OP(mean); USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index cdaaa606749..540636a0e81 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -10,6 +10,7 @@ add_python_test(test_framework test_sgd_op.py test_cross_entropy_op.py test_mul_op.py + test_mean_op.py test_sigmoid_op.py test_softmax_op.py test_rowwise_add_op.py diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py new file mode 100644 index 00000000000..78fff1eeff9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -0,0 +1,16 @@ +import unittest +from op_test_util import OpTestMeta +import numpy as np + + +class TestMeanOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "mean" + self.X = np.random.random((32, 784)).astype("float32") + self.Out = np.mean(self.X) + + +if __name__ == '__main__': + unittest.main() -- GitLab