From af1eb31afc92ae3ac59869a6a5b0e890e009c44b Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 11 Aug 2017 11:55:56 -0700 Subject: [PATCH] add as an operator --- paddle/operators/CMakeLists.txt | 2 ++ paddle/operators/gather_op.cc | 64 +++++++++++++++++++++++++++++++++ paddle/operators/gather_op.h | 52 +++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 paddle/operators/gather_op.cc create mode 100644 paddle/operators/gather_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e4..5ac898a8d3 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,6 +43,8 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +cc_library(gather_op SRCS gather_op.cc DEPS op_registry) +# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc new file mode 100644 index 0000000000..1008a57a87 --- /dev/null +++ b/paddle/operators/gather_op.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather_op.h" +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherOp : public framework::OperatorWithKernel { + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 2, ""); + PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); + int batch_size = ctx.Input(1)->dims()[0]; + PADDLE_ENFORCE(batch_size > 0); + } +}; + +class GatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The source input of gather op"); + AddInput("Index", "The index input of gather op"); + AddOutput("Y", "The output of add op"); + AddComment(R"DOC( +Gather Operator by selecting from the first axis, + +Y = X[Index] +)DOC"); + } +}; + +class GatherGradOp : public framework::OperatorWithKernel { + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // ctx.Output("X" + framework::kGradVarSuffix) + // ->Resize(ctx.Input("X")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker); +REGISTER_OP_CPU_KERNEL(gather, + ops::GatherOpKernel); +REGISTER_GRADIENT_OP(gather, gather_grad, ops::GatherGradOp); +REGISTER_OP_CPU_KERNEL( + gather_grad, + ops::GatherGradientOpKernel); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h new file mode 100644 index 0000000000..13e4c9b058 --- /dev/null +++ b/paddle/operators/gather_op.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "gather.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto X = ctx.Input("X"); + auto Index = ctx.Input("Index"); + auto Y = ctx.Output("Y"); + + Y->mutable_data(ctx.GetPlace()); + Gather(ctx.GetPlace(), X, Index, Y); + } +}; + +template +class GatherGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto Index = ctx.Input("Index"); + auto dX = ctx.Output(framework::GradVarName("X")); + auto dY = ctx.Input(framework::GradVarName("Y")); + + ScatterUpdate(ctx.GetPlace(), dY, Index, dX); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab