Commit 2cd3fa3e authored by ShenLiang, committed by Yi Liu

add scatter_nd op and scatter_nd_add op (#19571)

* add scatter_nd op, test=document_preview test=develop

* fixed the document, test=document_preview test=develop

* modify the notes, test=document_preview test=develop

* remove the ShareDataWith, test=develop
Parent 364c4442
@@ -196,6 +196,8 @@ paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale',
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.scatter_nd_add (ArgSpec(args=['ref', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'c2fa5ee7484b52b95a28abf1d8827cd0'))
paddle.fluid.layers.scatter_nd (ArgSpec(args=['index', 'updates', 'shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '14b5449ce42f8ff4ac4ce79b41c86cc5'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))
paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', '042af0b8abea96b40c22f6e70d99e042'))
paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None), ('document', 'e714b4aa7993dfe9c1a38886875dbaac'))
......
@@ -38,8 +38,9 @@ class GatherNdOp : public framework::OperatorWithKernel {
auto index_dims = ctx->GetInputDim("Index");
auto index_dims_size = index_dims.size();
PADDLE_ENFORCE_LE(index_dims[index_dims_size - 1], x_dims_size,
"Input(Index).shape[-1] <= Input(X).rank");
PADDLE_ENFORCE_LE(
index_dims[index_dims_size - 1], x_dims_size,
"Input(Index).shape[-1] should be no greater than Input(X).rank");
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
"The rank of Input(Index) should be greater than 1");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/scatter_nd_add_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
namespace paddle {
namespace operators {
class ScatterNdAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of ScatterNdAddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
"Input(Index) of ScatterNdAddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
"Input(Updates) of ScatterNdAddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of ScatterNdAddOp should not be null.");
auto ref_dims = ctx->GetInputDim("X");
auto ref_dims_size = ref_dims.size();
auto index_dims = ctx->GetInputDim("Index");
auto index_dims_size = index_dims.size();
auto updates_dims = ctx->GetInputDim("Updates");
auto updates_dims_size = updates_dims.size();
PADDLE_ENFORCE_LE(
index_dims[index_dims_size - 1], ref_dims_size,
"Input(Index).shape[-1] should be no greater than Input(X).rank");
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
"The rank of Input(Index) should be greater than 1");
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std::vector<int64_t> r_updates_dims;
for (int64_t i = 0; i < index_dims_size - 1; ++i) {
r_updates_dims.emplace_back(index_dims[i]);
}
for (int64_t i = index_dims[index_dims_size - 1]; i < ref_dims_size; ++i) {
r_updates_dims.emplace_back(ref_dims[i]);
}
PADDLE_ENFORCE_EQ(r_updates_dims.size(), updates_dims_size,
"Updates has wrong shape");
for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ(r_updates_dims[i], updates_dims[i],
"Updates has wrong shape");
}
ctx->SetOutputDim("Out", ref_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->type(),
ctx.Input<Tensor>("Updates")->type(),
"Ref and Updates must have same type");
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
class ScatterNdAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("Updates"))) {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
class ScatterNdAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The source input of scatter_nd_add op");
AddInput("Index",
"The index input of scatter_nd_add op where X will be updated");
AddInput("Updates", "The updated value of scatter_nd_add op");
AddOutput("Out", "The output of scatter_nd_add op");
AddComment(R"DOC(
Scatter_nd_add Operator.
Output is obtained by applying sparse addition to a single value or slice in a Variable.
Given:
* Case 1:
ref = [0, 1, 2, 3, 4, 5]
index = [[1], [2], [3], [1]]
updates = [9, 10, 11, 12]
we get:
output = [0, 22, 12, 14, 4, 5]
* Case 2:
ref = [[65, 17], [-14, -25]]
index = [[], []]
updates = [[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]
ref.shape = (2, 2)
index.shape = (2, 0)
updates.shape = (2, 2, 2)
we get:
output = [[67, 19], [-16, -27]]
)DOC");
}
};
class ScatterNdAddGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("scatter_nd_add_grad");
op->SetInput("Index", Input("Index"));
op->SetInput("Updates", Input("Updates"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ScatterNdAddGradNoNeedBufferVarsInference,
"Updates");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(scatter_nd_add, ops::ScatterNdAddOp, ops::ScatterNdAddOpMaker,
ops::ScatterNdAddGradDescMaker);
REGISTER_OPERATOR(scatter_nd_add_grad, ops::ScatterNdAddGradOp,
ops::ScatterNdAddGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(scatter_nd_add, ops::ScatterNdAddOpKernel<float>,
ops::ScatterNdAddOpKernel<double>,
ops::ScatterNdAddOpKernel<int64_t>,
ops::ScatterNdAddOpKernel<int>,
ops::ScatterNdAddOpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(scatter_nd_add_grad,
ops::ScatterNdAddGradientOpKernel<float>,
ops::ScatterNdAddGradientOpKernel<double>,
ops::ScatterNdAddGradientOpKernel<int64_t>,
ops::ScatterNdAddGradientOpKernel<int>,
ops::ScatterNdAddGradientOpKernel<uint8_t>);
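The InferShape checks above encode the shape rule from the inline comment: updates.shape = index.shape[:-1] + ref.shape[index.shape[-1]:]. Here is a minimal Python sketch of that rule under the same constraints (the helper name `expected_updates_shape` is illustrative, not part of this patch):

```python
def expected_updates_shape(ref_shape, index_shape):
    """Shape Updates must have: index.shape[:-1] + ref.shape[index.shape[-1]:]."""
    assert len(index_shape) >= 2, \
        "The rank of Input(Index) should be greater than 1"
    assert index_shape[-1] <= len(ref_shape), \
        "Input(Index).shape[-1] should be no greater than Input(X).rank"
    return list(index_shape[:-1]) + list(ref_shape[index_shape[-1]:])

# Case 2 from the DOC block above: ref (2, 2), index (2, 0) -> updates (2, 2, 2)
assert expected_updates_shape((2, 2), (2, 0)) == [2, 2, 2]
```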
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/operators/scatter.cu.h"
#include "paddle/fluid/operators/scatter_nd_add_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
framework::TensorCopySync(*X, ctx.GetPlace(), Out);
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterNdAdd<DeviceContext, T, int32_t>(ctx, *Updates, *Ids, Out);
} else {
GPUScatterNdAdd<DeviceContext, T, int64_t>(ctx, *Updates, *Ids, Out);
}
}
};
template <typename DeviceContext, typename T>
class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Index");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather
const auto &index_type = Ids->type();
if (index_type == framework::proto::VarType::INT32) {
GPUGatherNd<DeviceContext, T, int32_t>(ctx, *dOut, *Ids, dUpdates);
} else {
GPUGatherNd<DeviceContext, T, int64_t>(ctx, *dOut, *Ids, dUpdates);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(scatter_nd_add,
ops::ScatterNdAddOpCUDAKernel<CUDA, float>,
ops::ScatterNdAddOpCUDAKernel<CUDA, double>,
ops::ScatterNdAddOpCUDAKernel<CUDA, int64_t>,
ops::ScatterNdAddOpCUDAKernel<CUDA, int>,
ops::ScatterNdAddOpCUDAKernel<CUDA, plat::float16>);
REGISTER_OP_CUDA_KERNEL(scatter_nd_add_grad,
ops::ScatterNdAddGradOpCUDAKernel<CUDA, float>,
ops::ScatterNdAddGradOpCUDAKernel<CUDA, double>,
ops::ScatterNdAddGradOpCUDAKernel<CUDA, int64_t>,
ops::ScatterNdAddGradOpCUDAKernel<CUDA, int>,
ops::ScatterNdAddGradOpCUDAKernel<CUDA, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ScatterNdAddOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
// In place output: Out = X
framework::TensorCopySync(*X, ctx.GetPlace(), Out);
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
} else {
ScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
}
}
};
template <typename T>
class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Index");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
const auto &index_type = Ids->type();
if (index_type == framework::proto::VarType::INT32) {
CPUGatherNd<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
} else {
CPUGatherNd<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
}
};
} // namespace operators
} // namespace paddle
@@ -28,7 +28,7 @@ from ..framework import Variable, OpProtoHolder, in_dygraph_mode
from ..dygraph import base
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign, fill_constant
from .tensor import concat, assign, fill_constant, zeros
from . import utils
from .. import unique_name
from functools import reduce
@@ -124,6 +124,8 @@ __all__ = [
'gather',
'gather_nd',
'scatter',
'scatter_nd_add',
'scatter_nd',
'sequence_scatter',
'random_crop',
'mean_iou',
@@ -8686,6 +8688,127 @@ def scatter(input, index, updates, name=None, overwrite=True):
return out
def scatter_nd_add(ref, index, updates, name=None):
"""
**Scatter_nd_add Layer**
Output is obtained by applying sparse addition to a single value
or slice in a Variable. :attr:`ref` is a Tensor with rank :math:`R`
and :attr:`index` is a Tensor with rank :math:`K` . Thus, :attr:`index`
has shape :math:`[i_0, i_1, ..., i_{K-2}, Q]` where :math:`Q \leq R` . :attr:`updates`
is a Tensor with rank :math:`K - 1 + R - Q` and its
shape is :math:`index.shape[:-1] + ref.shape[index.shape[-1]:]` .
For each :math:`[i_0, i_1, ..., i_{K-2}]` in :attr:`index` , the corresponding
:attr:`updates` slice is added to the :attr:`ref` slice that is selected by the
last dimension of :attr:`index` .
.. code-block:: text
Given:
* Case 1:
ref = [0, 1, 2, 3, 4, 5]
index = [[1], [2], [3], [1]]
updates = [9, 10, 11, 12]
we get:
output = [0, 22, 12, 14, 4, 5]
* Case 2:
ref = [[65, 17], [-14, -25]]
index = [[], []]
updates = [[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]
ref.shape = (2, 2)
index.shape = (2, 0)
updates.shape = (2, 2, 2)
we get:
output = [[67, 19], [-16, -27]]
Args:
ref (Variable): The ref input.
index (Variable): The index input with rank > 1 and index.shape[-1] <= ref.rank.
Its dtype should be int32 or int64 as it is used as indices.
updates (Variable): The updated value of scatter_nd_add op, and it must have the same type
as ref. It must have the shape index.shape[:-1] + ref.shape[index.shape[-1]:]
name (str|None): The output variable name. Default None.
Returns:
output (Variable): The output is a tensor with the same shape and type as ref.
Examples:
.. code-block:: python
import paddle.fluid as fluid
ref = fluid.layers.data(name='ref', shape=[3, 5, 9, 10], dtype='float32', append_batch_size=False)
index = fluid.layers.data(name='index', shape=[3, 2], dtype='int32', append_batch_size=False)
updates = fluid.layers.data(name='update', shape=[3, 9, 10], dtype='float32', append_batch_size=False)
output = fluid.layers.scatter_nd_add(ref, index, updates)
"""
if ref.dtype != updates.dtype:
raise ValueError("ref and updates must have same data type.")
helper = LayerHelper('scatter_nd_add', **locals())
dtype = helper.input_dtype()
if name is None:
output = helper.create_variable_for_type_inference(dtype)
else:
output = helper.create_variable(
name=name, dtype=dtype, persistable=False)
helper.append_op(
type="scatter_nd_add",
inputs={"X": ref,
"Index": index,
"Updates": updates},
outputs={"Out": output})
return output
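As a cross-check of the cases in the docstring above, here is a minimal NumPy sketch of the accumulation semantics; `scatter_nd_add_np` is a hypothetical reference helper, not the operator's kernel. The key point is that repeated indices accumulate rather than overwrite:

```python
import numpy as np

def scatter_nd_add_np(ref, index, updates):
    """Hypothetical NumPy reference: out[idx] += update slice, duplicates accumulate."""
    out = ref.copy()
    num = int(np.prod(index.shape[:-1]))          # number of update slices
    flat_index = index.reshape((num, index.shape[-1]))
    flat_updates = updates.reshape((num,) + updates.shape[index.ndim - 1:])
    for idx, upd in zip(flat_index, flat_updates):
        out[tuple(idx)] += upd                    # repeated indices add up
    return out

# Case 1 from the docstring: the duplicated index [1] receives both 9 and 12
ref = np.array([0, 1, 2, 3, 4, 5], dtype=np.float32)
index = np.array([[1], [2], [3], [1]], dtype=np.int32)
updates = np.array([9, 10, 11, 12], dtype=np.float32)
print(scatter_nd_add_np(ref, index, updates))     # [ 0. 22. 12. 14.  4.  5.]
```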
def scatter_nd(index, updates, shape, name=None):
"""
**Scatter_nd Layer**
Output is obtained by scattering the :attr:`updates` in a new tensor according
to :attr:`index` . This op is similar to :code:`scatter_nd_add`, except the
tensor of :attr:`shape` is zero-initialized. Correspondingly, :code:`scatter_nd(index, updates, shape)`
is equivalent to :code:`scatter_nd_add(fluid.layers.zeros(shape, updates.dtype), index, updates)` .
If :attr:`index` has repeated elements, the corresponding updates are accumulated.
Because of floating-point rounding, different orders of the repeated elements in
:attr:`index` may produce different results. See :code:`scatter_nd_add` for the
detailed calculation method. This op is the inverse of the :code:`gather_nd` op.
Args:
index (Variable): The index input with rank > 1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indices.
updates (Variable): The updated value of scatter_nd op.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
shape(tuple|list): Shape of output tensor.
name (str|None): The output variable name. Default None.
Returns:
output (Variable): The output is a tensor with the same type as :attr:`updates` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
index = fluid.layers.data(name='index', shape=[3, 2], dtype='int64', append_batch_size=False)
updates = fluid.layers.data(name='update', shape=[3, 9, 10], dtype='float32', append_batch_size=False)
shape = [3, 5, 9, 10]
output = fluid.layers.scatter_nd(index, updates, shape)
"""
return scatter_nd_add(zeros(shape, updates.dtype), index, updates, name)
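In the same NumPy terms, the one-line implementation above corresponds to this identity (a sketch reusing the hypothetical `scatter_nd_add_np` helper from the previous example):

```python
import numpy as np

def scatter_nd_np(index, updates, shape):
    # scatter_nd(index, updates, shape) == scatter_nd_add(zeros(shape), index, updates)
    return scatter_nd_add_np(np.zeros(shape, dtype=updates.dtype), index, updates)
```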
def sequence_scatter(input, index, updates, name=None):
"""
**Sequence Scatter Layer**
......
@@ -158,7 +158,7 @@ class TestGatherNdOpRaise(OpTest):
output = fluid.layers.gather_nd(x, index)
except Exception as e:
t = \
"Input(Index).shape[-1] <= Input(X).rank"
"Input(Index).shape[-1] should be no greater than Input(X).rank"
if t in str(e):
raise IndexError
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def numpy_scatter_nd(ref, index, updates, fun):
ref_shape = ref.shape
index_shape = index.shape
end_size = index_shape[-1]
remain_numl = 1
for i in range(len(index_shape) - 1):
remain_numl *= index_shape[i]
slice_size = 1
for i in range(end_size, len(ref_shape)):
slice_size *= ref_shape[i]
flat_index = index.reshape([remain_numl] + list(index_shape[-1:]))
flat_updates = updates.reshape((remain_numl, slice_size))
flat_output = ref.reshape(list(ref_shape[:end_size]) + [slice_size])
for i_up, i_out in enumerate(flat_index):
i_out = tuple(i_out)
flat_output[i_out] = fun(flat_output[i_out], flat_updates[i_up])
return flat_output.reshape(ref.shape)
def numpy_scatter_nd_add(ref, index, updates):
return numpy_scatter_nd(ref, index, updates, lambda x, y: x + y)
def judge_update_shape(ref, index):
ref_shape = ref.shape
index_shape = index.shape
update_shape = []
for i in range(len(index_shape) - 1):
update_shape.append(index_shape[i])
for i in range(index_shape[-1], len(ref_shape), 1):
update_shape.append(ref_shape[i])
return update_shape
class TestScatterNdAddSimpleOp(OpTest):
"""
A simple example
"""
def setUp(self):
self.op_type = "scatter_nd_add"
ref_np = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]).astype("float32")
index_np = np.array([[1], [2], [3], [5], [1]]).astype("int32")
updates_np = np.array([9, 10, 11, 12, 13]).astype("float32")
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
#expect_np = [ 0. 23. 12. 14. 4. 17. 6. 7. 8.]
self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
class TestScatterNdAddWithEmptyIndex(OpTest):
"""
Index has empty element
"""
def setUp(self):
self.op_type = "scatter_nd_add"
ref_np = np.array([[65, 17], [-14, -25]]).astype("float32")
index_np = np.array([[], []]).astype("int32")
updates_np = np.array([[[-1, -2], [1, 2]],
[[3, 4], [-3, -4]]]).astype("float32")
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
#expect_np = [[67, 19], [-16, -27]]
self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
class TestScatterNdAddWithHighRankSame(OpTest):
"""
Both Index and X have high rank, and Rank(Index) = Rank(X)
"""
def setUp(self):
self.op_type = "scatter_nd_add"
shape = (10, 9, 8, 1, 15)
ref_np = np.random.rand(*shape).astype("float32")
index_np = np.vstack(
[np.random.randint(
0, s, size=150) for s in shape]).T.astype("int32")
update_shape = judge_update_shape(ref_np, index_np)
updates_np = np.random.rand(*update_shape).astype("float32")
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
class TestScatterNdAddWithHighRankDiff(OpTest):
"""
Both Index and X have high rank, and Rank(Index) < Rank(X)
"""
def setUp(self):
self.op_type = "scatter_nd_add"
shape = (10, 9, 8, 1, 15)
ref_np = np.random.rand(*shape).astype("double")
index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
index_np = index.reshape([10, 5, 10, 5]).astype("int64")
update_shape = judge_update_shape(ref_np, index_np)
updates_np = np.random.rand(*update_shape).astype("double")
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
#Test Python API
class TestScatterNdOpAPI(OpTest):
"""
test scatter_nd_add api and scatter_nd api
"""
def testcase1(self):
ref1 = fluid.layers.data(
name='ref1',
shape=[10, 9, 8, 1, 3],
dtype='float32',
append_batch_size=False)
index1 = fluid.layers.data(
name='index1',
shape=[5, 5, 8, 5],
dtype='int32',
append_batch_size=False)
updates1 = fluid.layers.data(
name='update1',
shape=[5, 5, 8],
dtype='float32',
append_batch_size=False)
output1 = fluid.layers.scatter_nd_add(ref1, index1, updates1)
def testcase2(self):
ref2 = fluid.layers.data(
name='ref2',
shape=[10, 9, 8, 1, 3],
dtype='double',
append_batch_size=False)
index2 = fluid.layers.data(
name='index2',
shape=[5, 8, 5],
dtype='int32',
append_batch_size=False)
updates2 = fluid.layers.data(
name='update2',
shape=[5, 8],
dtype='double',
append_batch_size=False)
output2 = fluid.layers.scatter_nd_add(
ref2, index2, updates2, name="scatter_nd_add")
def testcase3(self):
shape3 = [10, 9, 8, 1, 3]
index3 = fluid.layers.data(
name='index3',
shape=[5, 5, 8, 5],
dtype='int32',
append_batch_size=False)
updates3 = fluid.layers.data(
name='update3',
shape=[5, 5, 8],
dtype='float32',
append_batch_size=False)
output3 = fluid.layers.scatter_nd(index3, updates3, shape3)
def testcase4(self):
shape4 = [10, 9, 8, 1, 3]
index4 = fluid.layers.data(
name='index4',
shape=[5, 5, 8, 5],
dtype='int32',
append_batch_size=False)
updates4 = fluid.layers.data(
name='update4',
shape=[5, 5, 8],
dtype='double',
append_batch_size=False)
output4 = fluid.layers.scatter_nd(
index4, updates4, shape4, name='scatter_nd')
#Test Raise Error
class TestScatterNdOpRaise(OpTest):
def test_check_raise(self):
def check_raise_is_test():
try:
ref5 = fluid.layers.data(
name='ref5', shape=[3, 4, 5], dtype='float32')
index5 = fluid.layers.data(
name='index5', shape=[2, 10], dtype='int32')
updates5 = fluid.layers.data(
name='updates5', shape=[2, 10], dtype='float32')
output5 = fluid.layers.scatter_nd_add(ref5, index5, updates5)
except Exception as e:
t = \
"Input(Index).shape[-1] should be no greater than Input(X).rank"
if t in str(e):
raise IndexError
self.assertRaises(IndexError, check_raise_is_test)
def test_check_raise2(self):
with self.assertRaises(ValueError):
ref6 = fluid.layers.data(
name='ref6',
shape=[10, 9, 8, 1, 3],
dtype='double',
append_batch_size=False)
index6 = fluid.layers.data(
name='index6',
shape=[5, 8, 5],
dtype='int32',
append_batch_size=False)
updates6 = fluid.layers.data(
name='update6',
shape=[5, 8],
dtype='float32',
append_batch_size=False)
output6 = fluid.layers.scatter_nd_add(ref6, index6, updates6)
def test_check_raise3(self):
def check_raise_is_test():
try:
shape = [3, 4, 5]
index7 = fluid.layers.data(
name='index7', shape=[2, 1], dtype='int32')
updates7 = fluid.layers.data(
name='updates7', shape=[2, 4, 5, 20], dtype='float32')
output7 = fluid.layers.scatter_nd(index7, updates7, shape)
except Exception as e:
t = \
"Updates has wrong shape"
if t in str(e):
raise ValueError
self.assertRaises(ValueError, check_raise_is_test)
if __name__ == "__main__":
unittest.main()