diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 6849e39cb7d7e13ea37d7e5a5dd0e84ae6edbe61..ba1362e8bf38ef4735ffeea29bea12f6eff99982 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -44,8 +44,6 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) op_library(gather_op SRCS gather_op.cc gather_op.cu) -# DEPS op_registry) -# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index 3f299ea1a6d63e3dd68bd9e5b637af5ac12bd8f0..edac29f6db03a5cb975d27ca86fffe0707d1fd82 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -27,13 +27,13 @@ namespace operators { // Implementation of CPU copy template -void CPUGather(const T* params, const int* indices, const int slice_size, +void CPUGather(const T* src, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; - memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); + memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); } } @@ -57,7 +57,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, int index_size = index->dims()[0]; auto src_dims = src->dims(); - paddle::framework::DDim output_dims(src_dims); + framework::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 499def05a7f62d1c5f236e1f7dc5ab8f9c7b5bc3..123bed296c462c30bddd3bfbd530098fdbfe4856 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -26,9 +26,9 @@ class GatherOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { int batch_size = ctx.Input("Index")->dims()[0]; PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); - paddle::framework::DDim output_dims(ctx.Input("X")->dims()); + framework::DDim output_dims(ctx.Input("X")->dims()); output_dims[0] = batch_size; - ctx.Output("Y")->Resize(output_dims); + ctx.Output("Out")->Resize(output_dims); } }; @@ -51,11 +51,11 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The source input of gather op"); AddInput("Index", "The index input of gather op"); - AddOutput("Y", "The output of add op"); + AddOutput("Out", "The output of add op"); AddComment(R"DOC( Gather Operator by selecting from the first axis, -Y = X[Index] +Out = X[Index] )DOC"); } }; diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index 13e4c9b058ada12f59ca585c81840a07b9bef78e..381854f301870beadb72d9e9b4eb17ff199960fb 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -26,10 +26,10 @@ using Tensor = framework::Tensor; template class GatherOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto X = ctx.Input("X"); - auto Index = ctx.Input("Index"); - auto Y = ctx.Output("Y"); + void Compute(const framework::ExecutionContext &ctx) const override { + auto *X = ctx.Input("X"); + auto *Index = ctx.Input("Index"); + auto *Y = ctx.Output("Out"); Y->mutable_data(ctx.GetPlace()); Gather(ctx.GetPlace(), X, Index, Y); @@ -39,12 +39,13 @@ class GatherOpKernel : public framework::OpKernel { template class GatherGradientOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto Index = ctx.Input("Index"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); + void Compute(const framework::ExecutionContext &ctx) const override { + auto *Index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); - ScatterUpdate(ctx.GetPlace(), dY, Index, dX); + dX->mutable_data(ctx.GetPlace()); + ScatterUpdate(ctx.GetPlace(), dO, Index, dX); } }; diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py index 049054d07b63c06686d65682e3e6373ac43b8518..e86898304252d08be718e40fed46c5e921596af7 100644 --- a/python/paddle/v2/framework/tests/test_gather_op.py +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -1,11 +1,10 @@ import unittest - +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op import numpy import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator -from op_test_util import OpTestMeta - class TestGatherOp(unittest.TestCase): __metaclass__ = OpTestMeta @@ -17,7 +16,18 @@ class TestGatherOp(unittest.TestCase): 'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32") } - self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]} + self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]} + + +class TestGatherGradOp(GradientChecker): + def test_gather_grad(self): + print 'creating op' + op = create_op("gather") + print 'creating op done' + xnp = numpy.random.random((10, 20)).astype("float32") + inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")} + print 'correct before check gradient' + self.check_grad(op, inputs, set("X"), "Out") if __name__ == "__main__":