From 623848afa1f0bb3a69c7e49c4fa0f763a252669d Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 5 Oct 2017 12:11:56 -0700 Subject: [PATCH] add feed operator --- paddle/framework/scope.cc | 16 ++++++++++ paddle/framework/scope.h | 2 ++ paddle/operators/activation_op.cu | 18 +++++------ paddle/operators/feed_op.cc | 52 +++++++++++++++++++++++++++++++ paddle/operators/feed_op.cu | 18 +++++++++++ paddle/operators/feed_op.h | 40 ++++++++++++++++++++++++ 6 files changed, 137 insertions(+), 9 deletions(-) create mode 100644 paddle/operators/feed_op.cc create mode 100644 paddle/operators/feed_op.cu create mode 100644 paddle/operators/feed_op.h diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 080b4ac621c..b04120abf27 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" +#include // for unique_ptr +#include // for call_once #include "paddle/string/printf.h" namespace paddle { @@ -62,5 +64,19 @@ void Scope::DropKids() { kids_.clear(); } +std::once_flag feed_variable_flag; + +template +std::unique_ptr make_unique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +framework::Scope* GetScope() { + static std::unique_ptr g_scope = + make_unique(); + std::call_once(feed_variable_flag, [&]() { g_scope->NewVar("feed_value"); }); + return g_scope.get(); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 7047f0d55e9..96f3ae875b2 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -73,5 +73,7 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; +framework::Scope* GetScope(); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 93e9f1c694b..44a6aaf9cb6 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/activation_op.h" diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc new file mode 100644 index 00000000000..805c3600be6 --- /dev/null +++ b/paddle/operators/feed_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace paddle { +namespace operators { + +class FeedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + typedef std::vector FeedInputs; + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); + int col = ctx->Attrs().Get("col"); + framework::Variable* g_feed_variable = + framework::GetScope()->FindVar("feed_value"); + FeedInputs tensors = g_feed_variable->Get(); + auto in_dim = tensors[col].dims(); + ctx->SetOutputDim("Y", in_dim); + // need to handle LodTensor later + } +}; + +class FeedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("col", "The col in Global Feed Variable"); + AddOutput("Out", "The output of dropout op."); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker); +REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu new file mode 100644 index 00000000000..7b6a2ac91e7 --- /dev/null +++ b/paddle/operators/feed_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h new file mode 100644 index 00000000000..57781e205f8 --- /dev/null +++ b/paddle/operators/feed_op.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FeedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + typedef std::vector FeedInputs; + Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + framework::Variable* g_feed_variable = + framework::GetScope()->FindVar("feed_value"); + int col = ctx.template Attr("col"); + FeedInputs tensors = g_feed_variable->Get(); + out->CopyFrom(tensors[col], ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab