From 1e676f684d58cfce90f194f85c422306543543da Mon Sep 17 00:00:00 2001
From: liaogang <liaogang@baidu.com>
Date: Tue, 1 Aug 2017 16:10:52 +0800
Subject: [PATCH] Add mean op unit test in python

---
 paddle/operators/mean_op.cu                      |  5 +++--
 paddle/operators/mean_op.h                       |  4 ++--
 paddle/pybind/CMakeLists.txt                     | 11 +++++++++--
 paddle/pybind/pybind.cc                          |  1 +
 python/paddle/v2/framework/tests/CMakeLists.txt  |  1 +
 python/paddle/v2/framework/tests/test_mean_op.py | 16 ++++++++++++++++
 6 files changed, 32 insertions(+), 6 deletions(-)
 create mode 100644 python/paddle/v2/framework/tests/test_mean_op.py

diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu
index 4dbb566b1..740157cbc 100644
--- a/paddle/operators/mean_op.cu
+++ b/paddle/operators/mean_op.cu
@@ -1,4 +1,5 @@
-#include "paddle/framework/op_registry.h"
+#define EIGEN_USE_GPU
+
 #include "paddle/operators/mean_op.h"
 
-REGISTER_OP_GPU_KERNEL(mean, ops::AddKernel<ops::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h
index 21fa57964..483b3eb60 100644
--- a/paddle/operators/mean_op.h
+++ b/paddle/operators/mean_op.h
@@ -26,8 +26,8 @@ public:
     auto output = context.Output(0)->GetMutable<Tensor>();
 
     output->mutable_data<T>(context.GetPlace());
-    EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
+
+    EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
         EigenVector<T>::Flatten(input).mean();
   }
 };
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 7d0e68a8f..845589dcb 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,2 +1,9 @@
-cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-        add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
+cc_library(paddle_pybind SHARED
+    SRCS pybind.cc
+    DEPS pybind python
+	fc_op
+	sgd_op
+	add_op
+	mean_op
+	cross_entropy_op
+	recurrent_network_op)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 08a8bd0d8..4fa481bed 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -33,6 +33,7 @@ USE_OP(onehot_cross_entropy);
 USE_OP_WITHOUT_KERNEL(fc);
 USE_OP(sgd);
 USE_OP(mul);
+USE_OP(mean);
 USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index cdaaa6067..540636a0e 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -10,6 +10,7 @@ add_python_test(test_framework
     test_sgd_op.py
     test_cross_entropy_op.py
     test_mul_op.py
+    test_mean_op.py
     test_sigmoid_op.py
     test_softmax_op.py
     test_rowwise_add_op.py
diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py
new file mode 100644
index 000000000..78fff1eef
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_mean_op.py
@@ -0,0 +1,16 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy as np
+
+
+class TestMeanOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "mean"
+        self.X = np.random.random((32, 784)).astype("float32")
+        self.Out = np.mean(self.X)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab