Commit 97649bf9 authored by zchen0211

fix codes in scatter

Parent 305a94e6
......@@ -24,8 +24,18 @@ class ScatterOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
-    framework::DDim output_dims(ctx.Input<Tensor>("Ref")->dims());
-    ctx.Output<Tensor>("Out")->Resize(output_dims);
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Index")->dims().size(), 1,
+                      "Update Index should be 1-D.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Ref")->dims().size(),
+                      ctx.Input<Tensor>("Updates")->dims().size(),
+                      "Reference and Updates should have the same shape size");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Updates")->dims()[0],
+                      ctx.Input<Tensor>("Index")->dims()[0],
+                      "Updates and Index should have same batch-size.");
+    framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims());
+    for (int i = 1; i < data_dim.size(); ++i)
+      PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Ref")->dims()[i]);
+    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Ref")->dims());
}
};
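The added checks amount to a simple shape contract between Ref, Index, and Updates. A minimal NumPy sketch of the same contract (the function name check_scatter_shapes is illustrative, not part of the operator):

    import numpy

    def check_scatter_shapes(ref, index, updates):
        # Index is a 1-D tensor of row positions into Ref.
        assert index.ndim == 1, "Update Index should be 1-D."
        # Ref and Updates have the same rank.
        assert ref.ndim == updates.ndim, "Reference and Updates should have the same shape size"
        # One row of Updates per entry in Index.
        assert updates.shape[0] == index.shape[0], "Updates and Index should have same batch-size."
        # Trailing dimensions of Updates match those of Ref.
        assert updates.shape[1:] == ref.shape[1:]
        # Out has the same shape as Ref.
        return ref.shape

    check_scatter_shapes(numpy.zeros((10, 20)),
                         numpy.array([1, 4, 7]),
                         numpy.zeros((3, 20)))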
......@@ -35,13 +45,13 @@ class ScatterGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto Updates_grad = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto Updates = ctx.Input<Tensor>("Updates");
-    auto Ref_grad = ctx.Output<Tensor>(framework::GradVarName("Ref"));
-    auto Ref = ctx.Input<Tensor>("Ref");
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *Ref = ctx.Input<Tensor>("Ref");
-    Ref_grad->Resize(Ref->dims());
-    Updates_grad->Resize(Updates->dims());
+    dRef->Resize(Ref->dims());
+    dUpdates->Resize(Updates->dims());
}
};
......
......@@ -46,13 +46,13 @@ class ScatterGradientOpKernel : public framework::OpKernel {
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index");
-    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
    // In place gradient: dRef = dO
-    dRef->ShareDataWith<T>(*dO);
+    dRef->ShareDataWith<T>(*dOut);
    dUpdates->mutable_data<T>(ctx.GetPlace());
    // Gradient by Gather: dUpdates += dO[Index]
-    Gather<T>(ctx.GetPlace(), dO, Index, dUpdates);
+    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
}
};
......
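In NumPy terms the gradient kernel above does exactly what its comments say: dRef passes dOut through unchanged, and dUpdates gathers the rows of dOut selected by Index. A minimal sketch (the function name scatter_grad is illustrative):

    import numpy

    def scatter_grad(d_out, index):
        # In place gradient: dRef = dOut (the kernel shares the buffer,
        # here we copy for clarity).
        d_ref = numpy.copy(d_out)
        # Gradient by Gather: dUpdates takes dOut[Index].
        d_updates = d_out[index]
        return d_ref, d_updates

    d_ref, d_updates = scatter_grad(numpy.ones((10, 20)),
                                    numpy.array([1, 4, 7]))
    assert d_updates.shape == (3, 20)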
......@@ -82,6 +82,11 @@ def get_numeric_gradient(op,
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
+    def copy_tensor():
+        for var_name in input_values:
+            tensor_ = local_scope.find_var(var_name).get_tensor()
+            tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
    # get the input tensor whose numeric gradient we want to compute.
tensor_to_check = local_scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
......@@ -92,9 +97,7 @@ def get_numeric_gradient(op,
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
if in_place:
-            for var_name in input_values:
-                tensor_ = local_scope.find_var(var_name).get_tensor()
-                tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+            copy_tensor()
        # get one input element by its index i.
origin = tensor_to_check.get_float_element(i)
......@@ -105,9 +108,7 @@ def get_numeric_gradient(op,
        # add delta to this element, run the op, and get the sum of the result tensor.
if in_place:
-            for var_name in input_values:
-                tensor_ = local_scope.find_var(var_name).get_tensor()
-                tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+            copy_tensor()
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
......
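The change above only factors the input-restoring loop into copy_tensor(); the finite-difference logic around it is untouched. For reference, a standalone sketch of that central-difference scheme, assuming f runs the operator and returns the summed output (the name numeric_gradient and its signature are illustrative, not the checker's API):

    import numpy

    def numeric_gradient(f, x, delta=1e-4):
        # Perturb one element at a time and take (y_pos - y_neg) / (2 * delta).
        x = numpy.array(x, dtype=numpy.float64)
        grad = numpy.zeros_like(x)
        flat_x, flat_grad = x.reshape(-1), grad.reshape(-1)
        for i in range(flat_x.size):
            origin = flat_x[i]
            flat_x[i] = origin + delta
            y_pos = f(x)
            flat_x[i] = origin - delta
            y_neg = f(x)
            flat_x[i] = origin  # restore, as copy_tensor() does for in-place ops
            flat_grad[i] = (y_pos - y_neg) / (2.0 * delta)
        return grad

    # Example: d/dx of sum(x ** 2) is 2 * x.
    g = numeric_gradient(lambda x: (x ** 2).sum(), numpy.array([1.0, 2.0, 3.0]))
    assert numpy.allclose(g, [2.0, 4.0, 6.0], atol=1e-3)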
......@@ -30,7 +30,6 @@ class TestScatterGradOp(GradientChecker):
output_np = numpy.copy(ref_np)
output_np[index_np] += updates_np
inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
# check gradient
self.check_grad(
op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)
......
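in_place=True tells the checker to restore the inputs via copy_tensor() before every run, which matters when an operator writes into a buffer shared with one of its inputs. With dOut set to ones, the expected gradients are easy to verify against the NumPy reference used in the test (the shapes below are chosen arbitrarily for illustration):

    import numpy

    ref = numpy.random.rand(10, 20).astype("float32")
    index = numpy.array([1, 4, 7], dtype="int32")
    updates = numpy.random.rand(3, 20).astype("float32")

    # Forward reference, as in the test: Out is Ref with Updates added at Index.
    out = numpy.copy(ref)
    out[index] += updates

    # Analytic backward from ScatterGradientOpKernel: dRef = dOut, dUpdates = dOut[Index].
    d_out = numpy.ones_like(out)
    d_ref = d_out
    d_updates = d_out[index]
    assert d_ref.shape == ref.shape and d_updates.shape == updates.shape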