From cef27dab47b430ce4034cfcfedf0c6bc95266f51 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Thu, 17 Aug 2017 19:14:27 -0700 Subject: [PATCH] "add fixl" --- paddle/operators/rowwise_add_op.cc | 1 + python/paddle/v2/framework/tests/test_rowwise_add_op.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 82e5df591d0..f07dd8f602c 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -63,6 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto dims0 = ctx.Input("X")->dims(); auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1") ctx.Output(framework::GradVarName("X"))->Resize(dims0); ctx.Output(framework::GradVarName("b"))->Resize(dims1); } diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index 8118d2d7418..29d72e85009 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -21,12 +21,10 @@ class RowwiseAddGradOpTest(GradientChecker): op = create_op("rowwise_add") inputs = { "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), - "b": np.random.uniform(0.1, 1, [10, 1]).astype("float32") + "b": np.random.uniform(0.1, 1, [10]).astype("float32") } self.check_grad(op, inputs, set(["X", "b"]), "Out") -#TODO(dzh): rowwise_grad check - if __name__ == '__main__': unittest.main() -- GitLab