From 186581d2cc598621df2d1f3941160d1f1051bf06 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 16 Nov 2017 20:18:31 +0800 Subject: [PATCH] add is empty op (#5639) --- paddle/operators/is_empty_op.cc | 67 +++++++++++++++++++ .../v2/framework/tests/test_is_empty_op.py | 43 ++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 paddle/operators/is_empty_op.cc create mode 100644 python/paddle/v2/framework/tests/test_is_empty_op.py diff --git a/paddle/operators/is_empty_op.cc b/paddle/operators/is_empty_op.cc new file mode 100644 index 000000000..54fecf44e --- /dev/null +++ b/paddle/operators/is_empty_op.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +constexpr char kInput[] = "X"; +constexpr char kOutput[] = "Out"; + +class IsEmptyOp : public framework::OperatorBase { + public: + IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + // get input + auto *var = scope.FindVar(Input(kInput)); + PADDLE_ENFORCE_NOT_NULL(var); + auto &tensor = var->Get(); + // get output + auto *out = scope.FindVar(Output(kOutput)); + PADDLE_ENFORCE_NOT_NULL(out); + auto *out_tensor = out->GetMutable(); + + out_tensor->Resize({1}); + out_tensor->mutable_data(platform::CPUPlace())[0] = + framework::product(tensor.dims()) == 0; + } +}; + +class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + IsEmptyOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput(kInput, "(Tensor) Tensor which is to be checked."); + AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not."); + AddComment(R"DOC( +IsEmpty Operator which checks whether a tensor is empty. + +It will just return product(tensor.ddims()) > 0; + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp, + paddle::operators::IsEmptyOpProtoMaker); diff --git a/python/paddle/v2/framework/tests/test_is_empty_op.py b/python/paddle/v2/framework/tests/test_is_empty_op.py new file mode 100644 index 000000000..129d1c194 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_is_empty_op.py @@ -0,0 +1,43 @@ +import unittest +import numpy as np +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core + + +def create_tensor(scope, name, np_data): + tensor = scope.var(name).get_tensor() + tensor.set_dims(np_data.shape) + tensor.set(np_data, core.CPUPlace()) + return tensor + + +class TestIsEmptyOp(unittest.TestCase): + def setUp(self): + self.scope = core.Scope() + # create input variables + np_data0 = np.array([0, 1, 2]) + create_tensor(self.scope, "X0", np_data0) + + np_data1 = np.array([1]) + t = create_tensor(self.scope, "X1", np_data1) + t.set_dims([0]) + + # create output variables + self.scope.var("out") + + def test_no_empty(self): + self.one_case("X0", False) + + def test_empty(self): + self.one_case("X1", True) + + def one_case(self, input, target): + op = Operator(type="is_empty", X=input, Out="out") + ctx = core.DeviceContext.create(core.CPUPlace()) + op.run(self.scope, ctx) + out = self.scope.var("out").get_tensor() + self.assertEqual(np.array(out)[0], target) + + +if __name__ == "__main__": + unittest.main() -- GitLab