From 7f8c3f82145dd02cf7d136f27de42a6f0a56024b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 17 Aug 2017 18:02:20 +0800 Subject: [PATCH] Add MeanOp's Gradient Test And Fix Mean Op Gradient --- paddle/operators/mean_op.h | 3 ++- python/paddle/v2/framework/tests/test_mean_op.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index fcb703e63bd..9848af280b6 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = (T)framework::product(IG->dims()); + Eigen::DSizes bcast(ig_size); EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = - EigenScalar::From(*OG) / ig_size; + (EigenVector::From(*OG) / ig_size).broadcast(bcast); } }; diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py index b5d52b90567..f32b3160d65 100644 --- a/python/paddle/v2/framework/tests/test_mean_op.py +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -1,5 +1,6 @@ import unittest from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op import numpy as np @@ -12,5 +13,12 @@ class TestMeanOp(unittest.TestCase): self.outputs = {'Out': np.mean(self.inputs['X'])} +class MeanGradOpTest(GradientChecker): + def test_normal(self): + op = create_op("mean") + inputs = {"X": np.random.random((10, 10)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") + + if __name__ == '__main__': unittest.main() -- GitLab