From a2dc9614edfff8ab2a602e1ed605ffdc4155373a Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 26 Jul 2017 18:10:19 +0800 Subject: [PATCH] Add fill_zeros_like op --- paddle/operators/fill_zeros_like_op.cc | 58 ++++++++++++++++++++++++++ paddle/operators/fill_zeros_like_op.cu | 6 +++ paddle/operators/fill_zeros_like_op.h | 34 +++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 paddle/operators/fill_zeros_like_op.cc create mode 100644 paddle/operators/fill_zeros_like_op.cu create mode 100644 paddle/operators/fill_zeros_like_op.h diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc new file mode 100644 index 00000000000..3df3a2cfab6 --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class FillZerosLike : public framework::OperatorWithKernel { +protected: + void InferShape( + const std::vector &inputs, + const std::vector &outputs) const override { + PADDLE_ENFORCE(inputs.size() == 1, + "Input size of FillZerosLike must be one."); + PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one."); + PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, + "Outputs of FillZerosLike must all be set."); + outputs[0]->Resize(inputs[0]->dims()); + } +}; + +class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FillZerosLikeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Src", "The input of fill-zeros-like op."); + AddOutput("Dst", "The varibale will be filled up with zeros."); + AddComment(R"DOC( +Fill up a vriable with zeros. + +The output will have the same size with input. +)DOC") + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(fill_zeros_like, + paddle::operators::FillZerosLikeOp, + paddle::operators::FillZerosLikeOpMaker); +EGISTER_OP_CPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernal); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu new file mode 100644 index 00000000000..55ad58f4f17 --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cu @@ -0,0 +1,6 @@ +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_zeros_like_op.h" + +REGISTER_OP_GPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h new file mode 100644 index 00000000000..ca44a201f7d --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class FillZerosLikeKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& context) const override { + auto* output = context.Output(0)->GetMutable(); + output->mutable_data(context.GetPlace()); + framework::EigenVector::Flatten(*output).setZero(); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab