From 0328ffd3ab7d58da388a784bf3035844323dd78a Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 26 Oct 2018 17:21:22 +0800 Subject: [PATCH] add fake init op --- paddle/fluid/operators/fake_init_op.cc | 84 ++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 paddle/fluid/operators/fake_init_op.cc diff --git a/paddle/fluid/operators/fake_init_op.cc b/paddle/fluid/operators/fake_init_op.cc new file mode 100644 index 0000000000..2b3a541156 --- /dev/null +++ b/paddle/fluid/operators/fake_init_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +class FakeInitInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FakeInitOp should not be null."); + auto &shape = ctx->Attrs().Get>("shape"); + ctx->SetOutputDim("Out", framework::make_ddim(shape)); + } +}; + +class FakeInitOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + framework::Tensor *tensor = nullptr; + + auto &out_var = *scope.FindVar(Output("Out")); + + if (out_var.IsType()) { + tensor = out_var.GetMutable(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else if (out_var.IsType()) { + tensor = out_var.GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else { + PADDLE_THROW( + "fake init op's output only" + "supports SelectedRows and LoDTensor"); + } + } +}; + +class FakeInitOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; + +class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddAttr>("shape", "(vector) The shape of the output"); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC( +FakeInitBatchSizeLike Operator. + +Init an op but not alloc tensor for it, it is used for distributed lookup table. + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fake_init, ops::FakeInitOp, ops::FakeInitInferShape, + ops::FakeInitOpMaker, paddle::framework::EmptyGradOpMaker, + ops::FakeInitOpVarTypeInference); -- GitLab