From 97649bf9b251707803b2665dedf1ef8f929d8c88 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 25 Aug 2017 22:08:24 +0000 Subject: [PATCH] fix codes in scatter --- paddle/operators/scatter_op.cc | 26 +++++++++++++------ paddle/operators/scatter_op.h | 6 ++--- .../v2/framework/tests/gradient_checker.py | 13 +++++----- .../v2/framework/tests/test_scatter_op.py | 1 - 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index cf01ef62799..f901edefa22 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -24,8 +24,18 @@ class ScatterOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - framework::DDim output_dims(ctx.Input("Ref")->dims()); - ctx.Output("Out")->Resize(output_dims); + PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1, + "Update Index should be 1-D."); + PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(), + ctx.Input("Updates")->dims().size(), + "Reference and Updates should have the same shape size"); + PADDLE_ENFORCE_EQ(ctx.Input("Updates")->dims()[0], + ctx.Input("Index")->dims()[0], + "Updates and Index should have same batch-size."); + framework::DDim data_dim(ctx.Input("Updates")->dims()); + for (int i = 1; i < data_dim.size(); ++i) + PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input("Updates")->dims()[i]); + ctx.Output("Out")->Resize(ctx.Input("Ref")->dims()); } }; @@ -35,13 +45,13 @@ class ScatterGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto Updates_grad = ctx.Output(framework::GradVarName("Updates")); - auto Updates = ctx.Input("Updates"); - auto Ref_grad = ctx.Output(framework::GradVarName("Ref")); - auto Ref = ctx.Input("Ref"); + auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); + auto *Updates = ctx.Input("Updates"); + auto *dRef = ctx.Output(framework::GradVarName("Ref")); + auto *Ref = ctx.Input("Ref"); - Ref_grad->Resize(Ref->dims()); - Updates_grad->Resize(Updates->dims()); + dRef->Resize(Ref->dims()); + dUpdates->Resize(Updates->dims()); } }; diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index c2db3ae37cc..e9595638a86 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -46,13 +46,13 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *dRef = ctx.Output(framework::GradVarName("Ref")); auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); auto *Index = ctx.Input("Index"); - auto *dO = ctx.Input(framework::GradVarName("Out")); + auto *dOut = ctx.Input(framework::GradVarName("Out")); // In place gradient: dRef = dO - dRef->ShareDataWith(*dO); + dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Index] - Gather(ctx.GetPlace(), dO, Index, dUpdates); + Gather(ctx.GetPlace(), dOut, Index, dUpdates); } }; diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index ac37671c77b..abe0b5391ab 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -82,6 +82,11 @@ def get_numeric_gradient(op, def product(dim): return reduce(lambda a, b: a * b, dim, 1) + def copy_tensor(): + for var_name in input_values: + tensor_ = local_scope.find_var(var_name).get_tensor() + tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) + # get the input tensor that we want to get it's numeric gradient. tensor_to_check = local_scope.find_var(input_to_check).get_tensor() tensor_size = product(tensor_to_check.get_dims()) @@ -92,9 +97,7 @@ def get_numeric_gradient(op, # we use a for loop to compute the gradient of every element. for i in xrange(tensor_size): if in_place: - for var_name in input_values: - tensor_ = local_scope.find_var(var_name).get_tensor() - tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) + copy_tensor() # get one input element throw it's index i. origin = tensor_to_check.get_float_element(i) @@ -105,9 +108,7 @@ def get_numeric_gradient(op, # plus delta to this element, run op and get the sum of the result tensor. if in_place: - for var_name in input_values: - tensor_ = local_scope.find_var(var_name).get_tensor() - tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) + copy_tensor() x_neg = origin - delta tensor_to_check.set_float_element(i, x_neg) y_neg = get_output() diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/framework/tests/test_scatter_op.py index 861fe6cf89e..c1f94448893 100644 --- a/python/paddle/v2/framework/tests/test_scatter_op.py +++ b/python/paddle/v2/framework/tests/test_scatter_op.py @@ -30,7 +30,6 @@ class TestScatterGradOp(GradientChecker): output_np = numpy.copy(ref_np) output_np[index_np] += updates_np inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} - # check gradient self.check_grad( op, inputs, set(["Updates", "Ref"]), "Out", in_place=True) -- GitLab