Commit 97649bf9 authored by zchen0211

fix codes in scatter

Parent 305a94e6
@@ -24,8 +24,18 @@ class ScatterOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    framework::DDim output_dims(ctx.Input<Tensor>("Ref")->dims());
-    ctx.Output<Tensor>("Out")->Resize(output_dims);
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Index")->dims().size(), 1,
+                      "Update Index should be 1-D.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Ref")->dims().size(),
+                      ctx.Input<Tensor>("Updates")->dims().size(),
+                      "Reference and Updates should have the same shape size");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Updates")->dims()[0],
+                      ctx.Input<Tensor>("Index")->dims()[0],
+                      "Updates and Index should have same batch-size.");
+    framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims());
+    for (int i = 1; i < data_dim.size(); ++i)
+      PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Ref")->dims()[i]);
+    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Ref")->dims());
   }
 };
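The new `PADDLE_ENFORCE_EQ` checks pin down the expected shape relationship between `Ref`, `Index`, and `Updates`: a 1-D index, matching rank for `Ref` and `Updates`, and a shared batch dimension between `Updates` and `Index`, with `Out` resized to `Ref`'s shape. A minimal NumPy sketch of shapes that satisfy these checks (illustrative values only; the accumulate-by-index semantics mirror the reference computation in the Python test further down):

```python
import numpy as np

ref = np.random.rand(10, 4).astype("float32")      # Ref:     [N, D]
index = np.array([1, 3], dtype="int32")            # Index:   [M], must be 1-D
updates = np.random.rand(2, 4).astype("float32")   # Updates: [M, D], same rank as Ref

out = np.copy(ref)      # Out takes Ref's shape
out[index] += updates   # rows named by Index receive the updates
assert out.shape == ref.shape
```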
@@ -35,13 +45,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto Updates_grad = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto Updates = ctx.Input<Tensor>("Updates");
-    auto Ref_grad = ctx.Output<Tensor>(framework::GradVarName("Ref"));
-    auto Ref = ctx.Input<Tensor>("Ref");
-    Ref_grad->Resize(Ref->dims());
-    Updates_grad->Resize(Updates->dims());
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *Ref = ctx.Input<Tensor>("Ref");
+    dRef->Resize(Ref->dims());
+    dUpdates->Resize(Updates->dims());
   }
 };
......
@@ -46,13 +46,13 @@ class ScatterGradientOpKernel : public framework::OpKernel {
     auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Index = ctx.Input<Tensor>("Index");
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     // In place gradient: dRef = dO
-    dRef->ShareDataWith<T>(*dO);
+    dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    Gather<T>(ctx.GetPlace(), dO, Index, dUpdates);
+    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
   }
 };
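In the backward kernel, `dRef` aliases `dOut` via `ShareDataWith`, while `dUpdates` is filled by gathering the rows of `dOut` selected by `Index`. A rough NumPy analogue of that data flow (not the kernel itself; function and variable names are illustrative):

```python
import numpy as np

def scatter_grad(d_out, index):
    # In-place gradient: dRef shares dOut's data (here: the same array object).
    d_ref = d_out
    # Gradient by Gather: dUpdates takes the rows of dOut named by Index.
    d_updates = d_out[index]
    return d_ref, d_updates

d_out = np.ones((10, 4), dtype="float32")
index = np.array([1, 3], dtype="int32")
d_ref, d_updates = scatter_grad(d_out, index)
assert d_ref is d_out and d_updates.shape == (2, 4)
```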
......
@@ -82,6 +82,11 @@ def get_numeric_gradient(op,
     def product(dim):
         return reduce(lambda a, b: a * b, dim, 1)

+    def copy_tensor():
+        for var_name in input_values:
+            tensor_ = local_scope.find_var(var_name).get_tensor()
+            tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+
     # get the input tensor that we want to get its numeric gradient.
     tensor_to_check = local_scope.find_var(input_to_check).get_tensor()
     tensor_size = product(tensor_to_check.get_dims())
@@ -92,9 +97,7 @@ def get_numeric_gradient(op,
     # we use a for loop to compute the gradient of every element.
     for i in xrange(tensor_size):
         if in_place:
-            for var_name in input_values:
-                tensor_ = local_scope.find_var(var_name).get_tensor()
-                tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+            copy_tensor()

         # get one input element through its index i.
         origin = tensor_to_check.get_float_element(i)
@@ -105,9 +108,7 @@ def get_numeric_gradient(op,
         # plus delta to this element, run op and get the sum of the result tensor.
         if in_place:
-            for var_name in input_values:
-                tensor_ = local_scope.find_var(var_name).get_tensor()
-                tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+            copy_tensor()
         x_neg = origin - delta
         tensor_to_check.set_float_element(i, x_neg)
         y_neg = get_output()
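`copy_tensor()` re-seeds every input from the cached `input_values` before each perturbed run, which is what `in_place=True` relies on: an operator that writes into its inputs would otherwise corrupt the baseline between finite-difference probes. A simplified sketch of the central-difference step the checker performs, with `f` playing the role of `get_output()` (run the op and sum the result); the helper name is hypothetical, not the framework API:

```python
import numpy as np

def numeric_gradient_at(f, x, i, delta=1e-4):
    # Start each run from a fresh copy of the input, as copy_tensor() does,
    # so an operator that overwrites its inputs cannot corrupt the baseline.
    x_pos = np.copy(x)
    x_pos.flat[i] += delta
    y_pos = f(x_pos).sum()

    x_neg = np.copy(x)
    x_neg.flat[i] -= delta
    y_neg = f(x_neg).sum()
    return (y_pos - y_neg) / (2.0 * delta)

# e.g. the derivative of sum(2 * x) with respect to any element is 2.0
assert abs(numeric_gradient_at(lambda a: 2.0 * a, np.ones(3), 0) - 2.0) < 1e-3
```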
......
@@ -30,7 +30,6 @@ class TestScatterGradOp(GradientChecker):
         output_np = numpy.copy(ref_np)
         output_np[index_np] += updates_np
         inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
-        # check gradient
         self.check_grad(
             op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)
......