From c5e28dd1a0b4cb6f8ba74bc16760dc6cf32ad50e Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Thu, 24 Aug 2017 02:58:41 +0000 Subject: [PATCH] scatter check in --- paddle/operators/CMakeLists.txt | 1 + paddle/operators/scatter_op.cc | 76 +++++++++++++++++++ paddle/operators/scatter_op.cu | 20 +++++ paddle/operators/scatter_op.h | 60 +++++++++++++++ paddle/pybind/CMakeLists.txt | 1 + paddle/pybind/pybind.cc | 1 + .../paddle/v2/framework/tests/CMakeLists.txt | 1 + .../v2/framework/tests/test_gather_op.py | 3 - .../v2/framework/tests/test_scatter_op.py | 38 ++++++++++ 9 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 paddle/operators/scatter_op.cc create mode 100644 paddle/operators/scatter_op.cu create mode 100644 paddle/operators/scatter_op.h create mode 100644 python/paddle/v2/framework/tests/test_scatter_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f466dbc79a2..f0fd12f1b52 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -47,6 +47,7 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor) op_library(gather_op SRCS gather_op.cc gather_op.cu) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) +op_library(scatter_op SRCS scatter_op.cc scatter_op.cu) cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc new file mode 100644 index 00000000000..cf01ef62799 --- /dev/null +++ b/paddle/operators/scatter_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/scatter_op.h"
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+class ScatterOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    framework::DDim output_dims(ctx.Input<Tensor>("Ref")->dims());
+    ctx.Output<Tensor>("Out")->Resize(output_dims);
+  }
+};
+
+class ScatterGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto Updates_grad = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto Updates = ctx.Input<Tensor>("Updates");
+    auto Ref_grad = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto Ref = ctx.Input<Tensor>("Ref");
+
+    Ref_grad->Resize(Ref->dims());
+    Updates_grad->Resize(Updates->dims());
+  }
+};
+
+class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ScatterOpMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Ref", "The source input of scatter op");
+    AddInput("Index",
+             "The index input of scatter op where Ref will be updated");
+    AddInput("Updates", "The updated value of scatter op");
+    AddOutput("Out", "The output of scatter op");
+    AddComment(R"DOC(
+Scatter Operator by selecting from the first axis,
+
+Out = Ref
+Out[Index] = Ref[Index] + Updates
+)DOC");
+  }
+};
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
+            ops::ScatterGradOp);
+REGISTER_OP_CPU_KERNEL(scatter,
+                       ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    scatter_grad,
+    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
new file mode 100644
index 00000000000..e6a6fa57d93
--- /dev/null
+++ b/paddle/operators/scatter_op.cu
@@ -0,0 +1,20 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/scatter_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(scatter,
+                       ops::ScatterOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h
new file mode 100644
index 00000000000..c2db3ae37cc
--- /dev/null
+++ b/paddle/operators/scatter_op.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "gather.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "scatter.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class ScatterOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *Ref = ctx.Input<Tensor>("Ref");
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+    // In place output: Out = Ref, Out[Index] += Updates
+    Out->ShareDataWith<T>(*Ref);
+    // Apply ScatterUpdate: Out[index] += Updates[:]
+    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
+  }
+};
+
+template <typename Place, typename T>
+class ScatterGradientOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    // In place gradient: dRef = dO
+    dRef->ShareDataWith<T>(*dO);
+    dUpdates->mutable_data<T>(ctx.GetPlace());
+    // Gradient by Gather: dUpdates += dO[Index]
+    Gather<T>(ctx.GetPlace(), dO, Index, dUpdates);
+  }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index abb9c248eee..37e186a408f 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -4,6 +4,7 @@ cc_library(paddle_pybind SHARED
     DEPS pybind python backward
         sgd_op
         gather_op
+        scatter_op
         add_op
         mul_op
         rowwise_add_op
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 8fa8be2cef5..3bc150ccb7a 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -47,6 +47,7 @@ USE_OP(scale);
 USE_OP_ITSELF(identity);
 USE_OP(minus);
USE_CPU_ONLY_OP(gather); +USE_CPU_ONLY_OP(scatter); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index fb4686889a6..661ebd89648 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -14,6 +14,7 @@ py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) py_test(test_gather_op SRCS test_gather_op.py) +py_test(test_scatter_op SRCS test_scatter_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py index e8689830425..e3de3fd0a1d 100644 --- a/python/paddle/v2/framework/tests/test_gather_op.py +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -21,12 +21,9 @@ class TestGatherOp(unittest.TestCase): class TestGatherGradOp(GradientChecker): def test_gather_grad(self): - print 'creating op' op = create_op("gather") - print 'creating op done' xnp = numpy.random.random((10, 20)).astype("float32") inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")} - print 'correct before check gradient' self.check_grad(op, inputs, set("X"), "Out") diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/framework/tests/test_scatter_op.py new file mode 100644 index 00000000000..e7696844d5d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_scatter_op.py @@ -0,0 +1,38 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestScatterOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 
"scatter" + ref_np = numpy.ones((3, 3)).astype("float32") + index_np = numpy.array([1, 2]).astype("int32") + updates_np = numpy.random.random((2, 3)).astype("float32") + output_np = numpy.copy(ref_np) + output_np[index_np] += updates_np + self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} + self.outputs = {'Out': output_np} + + +class TestScatterGradOp(GradientChecker): + def test_scatter_grad(self): + op = create_op("scatter") + # test data setup + ref_np = numpy.ones((3, 10)).astype("float32") + index_np = numpy.array([1, 2]).astype("int32") + updates_np = numpy.random.random((2, 10)).astype("float32") + output_np = numpy.copy(ref_np) + output_np[index_np] += updates_np + inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} + # check gradient + self.check_grad(op, inputs, set(["Updates", "Ref"]), "Out") + + +if __name__ == "__main__": + unittest.main() -- GitLab