From 4d2adab772e3c0789e9696533da61ee3583363d1 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Tue, 15 Aug 2017 23:54:16 +0000 Subject: [PATCH] gather op added with python unittest --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/pybind.cc | 1 + paddle/operators/CMakeLists.txt | 3 +- paddle/operators/gather_op.cc | 43 +++++++++++-------- .../paddle/v2/framework/tests/CMakeLists.txt | 1 + 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9e306c865..30313780a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -47,6 +47,7 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward sgd_op + gather_op add_op mul_op rowwise_add_op diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index fe0c87bc5..90311e0dc 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_CPU_ONLY_OP(gather); namespace paddle { namespace framework { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 5ac898a8d..6849e39cb 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,7 +43,8 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) -cc_library(gather_op SRCS gather_op.cc DEPS op_registry) +op_library(gather_op SRCS gather_op.cc gather_op.cu) +# DEPS op_registry) # cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 05ba52ce0..2e08ba8dc 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -19,17 +19,33 @@ namespace paddle { namespace operators { class GatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 2, ""); - PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), - "Inputs of GatherOp must all be set"); - int batch_size = ctx.Input(1)->dims()[0]; + // PADDLE_ENFORCE(ctx.InputSize() == 2, ""); + // PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); + // PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), + // "Inputs of GatherOp must all be set"); + int batch_size = ctx.Input("Index")->dims()[0]; PADDLE_ENFORCE(batch_size > 0); paddle::framework::DDim output_dims(ctx.Input(0)->dims()); output_dims[0] = batch_size; - ctx.Output(0)->Resize(output_dims); + ctx.Output("Y")->Resize(output_dims); + } +}; + +class GatherGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X = ctx.Input("X"); + + X_grad->Resize(X->dims()); } }; @@ -47,25 +63,14 @@ Y = X[Index] )DOC"); } }; - -class GatherGradOp : public framework::OperatorWithKernel { - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); - auto X = ctx.Input("X"); - - X_grad->Resize(X->dims()); - } -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker); +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, + ops::GatherGradOp); REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel); -REGISTER_GRADIENT_OP(gather, gather_grad, ops::GatherGradOp); REGISTER_OP_CPU_KERNEL( gather_grad, ops::GatherGradientOpKernel); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 96fad9b42..1032743a1 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) +py_test(test_gather_op SRCS test_gather_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) -- GitLab