diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 68304c9fc8b8fa13cb1f99b82517abc87c71496c..2cdf323c53a8ba729ec74c1eacb9fa3ef272f44a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -45,6 +45,7 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward sgd_op + gather_op add_op mul_op rowwise_add_op diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index de119e9e062bffa5e95929671c69001047772cf3..4539a1903eb430eb0d76a787adb32984342a468d 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_CPU_ONLY_OP(gather); namespace paddle { namespace framework { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e43df6173791bc54b3faffc034867f7d..ba1362e8bf38ef4735ffeea29bea12f6eff99982 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,6 +43,7 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +op_library(gather_op SRCS gather_op.cc gather_op.cu) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index d6e6990394e46ba06c4bacfe33ca522f3ff1413a..92fb51ec17709bc6f8abb2f516a9240fb5dc3a77 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" @@ -25,13 +26,13 @@ namespace operators { // Implementation of CPU copy template -void CPUGather(const T* params, const int* indices, const int slice_size, +void CPUGather(const T* src, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; - memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); + memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); } } @@ -55,7 +56,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, int index_size = index->dims()[0]; auto src_dims = src->dims(); - paddle::framework::DDim output_dims(src_dims); + framework::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..123bed296c462c30bddd3bfbd530098fdbfe4856 --- /dev/null +++ b/paddle/operators/gather_op.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather_op.h" +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + int batch_size = ctx.Input("Index")->dims()[0]; + PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); + framework::DDim output_dims(ctx.Input("X")->dims()); + output_dims[0] = batch_size; + ctx.Output("Out")->Resize(output_dims); + } +}; + +class GatherGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X = ctx.Input("X"); + + X_grad->Resize(X->dims()); + } +}; + +class GatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The source input of gather op"); + AddInput("Index", "The index input of gather op"); + AddOutput("Out", "The output of add op"); + AddComment(R"DOC( +Gather Operator by selecting from the first axis, + +Out = X[Index] +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, + ops::GatherGradOp); +REGISTER_OP_CPU_KERNEL(gather, + ops::GatherOpKernel); +REGISTER_OP_CPU_KERNEL( + gather_grad, + ops::GatherGradientOpKernel); diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3f04a7b3f8142106917975cd1e0413fa1633a298 --- /dev/null +++ b/paddle/operators/gather_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gather_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gather, + ops::GatherOpKernel); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..381854f301870beadb72d9e9b4eb17ff199960fb --- /dev/null +++ b/paddle/operators/gather_op.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "gather.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *X = ctx.Input("X"); + auto *Index = ctx.Input("Index"); + auto *Y = ctx.Output("Out"); + + Y->mutable_data(ctx.GetPlace()); + Gather(ctx.GetPlace(), X, Index, Y); + } +}; + +template +class GatherGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *Index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + ScatterUpdate(ctx.GetPlace(), dO, Index, dX); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index b07a65f4d1fed12d82c638ee59f9de72379cfcbe..8a2b7c54d3ef481712bef1e1a39fb336b23eb1b2 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) +py_test(test_gather_op SRCS test_gather_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e86898304252d08be718e40fed46c5e921596af7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -0,0 +1,34 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestGatherOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "gather" + xnp = numpy.random.random((10, 20)).astype("float32") + self.inputs = { + 'X': xnp, + 'Index': numpy.array([1, 3, 5]).astype("int32") + } + self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]} + + +class TestGatherGradOp(GradientChecker): + def test_gather_grad(self): + print 'creating op' + op = create_op("gather") + print 'creating op done' + xnp = numpy.random.random((10, 20)).astype("float32") + inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")} + print 'correct before check gradient' + self.check_grad(op, inputs, set("X"), "Out") + + +if __name__ == "__main__": + unittest.main()