diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index d3f3ad92442cafdd8d4cdc396d89721863d069c2..29b73951bbddd9bfd73c932d7801797590de5e8e 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -12,45 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/is_empty_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { -constexpr char kInput[] = "X"; -constexpr char kOutput[] = "Out"; - -class IsEmptyOp : public framework::OperatorBase { +class IsEmptyOp : public framework::OperatorWithKernel { public: - IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + using framework::OperatorWithKernel::OperatorWithKernel; - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - // get input - auto *var = scope.FindVar(Input(kInput)); - PADDLE_ENFORCE_NOT_NULL(var); - auto &tensor = var->Get(); - // get output - auto *out = scope.FindVar(Output(kOutput)); - PADDLE_ENFORCE_NOT_NULL(out); - auto *out_tensor = out->GetMutable(); + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of IsEmptyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of IsEmptyOp should not be null."); + ctx->SetOutputDim("Out", {1}); + } - out_tensor->Resize({1}); - out_tensor->mutable_data(platform::CPUPlace())[0] = - framework::product(tensor.dims()) == 0; + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + return kt; } }; -class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker { +class IsEmptyOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput(kInput, "(Tensor) Tensor which is to be checked."); - AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not."); + AddInput("X", "(LoDTensor) Tensor which is to be checked."); + AddOutput("Out", + "(LoDTensor) a boolean Tensor that indicate empty or not."); AddComment(R"DOC( IsEmpty Operator which checks whether a tensor is empty. @@ -62,5 +58,12 @@ It will just return product(tensor.ddims()) > 0; } // namespace operators } // namespace paddle -REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp, - paddle::operators::IsEmptyOpProtoMaker); +namespace ops = paddle::operators; + +REGISTER_OPERATOR(is_empty, ops::IsEmptyOp, ops::IsEmptyOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + is_empty, ops::IsEmptyOpKernel, + ops::IsEmptyOpKernel, + ops::IsEmptyOpKernel, + ops::IsEmptyOpKernel); diff --git a/paddle/fluid/operators/is_empty_op.h b/paddle/fluid/operators/is_empty_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3e3af22fa8d842b6a1e67418446f1a40949e046b --- /dev/null +++ b/paddle/fluid/operators/is_empty_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class IsEmptyOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // get input + auto* input_tensor = context.Input("X"); + // get output + auto* output_tensor = context.Output("Out"); + + output_tensor->mutable_data(platform::CPUPlace())[0] = + framework::product(input_tensor->dims()) == 0; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 4b707973e27391a6bdcba138934f62a255e04bb2..dee41448081cbfcd8224ce2abbf3ba7b7b97eb7c 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -49,6 +49,7 @@ __all__ = [ 'reorder_lod_tensor_by_rank', 'ParallelDo', 'Print', + 'is_empty', ] @@ -1562,3 +1563,40 @@ def reorder_lod_tensor_by_rank(x, rank_table): 'RankTable': [rank_table]}, outputs={'Out': [out]}) return out + + +def is_empty(x, cond=None, **ignored): + """ + **Is Empty** + + This layer returns the truth value of whether the variable is empty. + + Args: + x(Variable): Operand of *is_empty* + cond(Variable|None): Optional output variable to store the result + of *is_empty* + + Returns: + Variable: The tensor variable storing the output of *is_empty*. + + Raises: + TypeError: If input cond is not a variable, or cond's dtype is + not bool + + Examples: + .. code-block:: python + + less = fluid.layers.is_empty(x=input) + """ + helper = LayerHelper("is_empty", **locals()) + if cond is None: + cond = helper.create_tmp_variable(dtype='bool') + cond.stop_gradient = True + elif not isinstance(cond, Variable): + raise TypeError("cond takes a variable") + elif cond.dtype != 'bool': + raise TypeError("The data type of cond must be bool") + + helper.append_op( + type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]}) + return cond diff --git a/python/paddle/fluid/tests/unittests/test_is_empty_op.py b/python/paddle/fluid/tests/unittests/test_is_empty_op.py index 4d11cf226be2ba4ffbe015198fed3191f1e02f72..11121d9b65351eab639b7618fac0e54714cf4680 100644 --- a/python/paddle/fluid/tests/unittests/test_is_empty_op.py +++ b/python/paddle/fluid/tests/unittests/test_is_empty_op.py @@ -14,42 +14,24 @@ import unittest import numpy as np -from paddle.fluid.op import Operator -import paddle.fluid.core as core +from op_test import OpTest -def create_tensor(scope, name, np_data): - tensor = scope.var(name).get_tensor() - tensor.set_dims(np_data.shape) - tensor.set(np_data, core.CPUPlace()) - return tensor - - -class TestIsEmptyOp(unittest.TestCase): +class TestEmpty(OpTest): def setUp(self): - self.scope = core.Scope() - # create input variables - np_data0 = np.array([0, 1, 2]) - create_tensor(self.scope, "X0", np_data0) - - np_data1 = np.array([1]) - t = create_tensor(self.scope, "X1", np_data1) - t.set_dims([0]) + self.op_type = "is_empty" + self.inputs = {'X': np.array([1, 2, 3])} + self.outputs = {'Out': np.array([False])} - # create output variables - self.scope.var("out") + def test_check_output(self): + self.check_output() - def test_no_empty(self): - self.one_case("X0", False) - def test_empty(self): - self.one_case("X1", True) - - def one_case(self, input, target): - op = Operator(type="is_empty", X=input, Out="out") - op.run(self.scope, core.CPUPlace()) - out = self.scope.var("out").get_tensor() - self.assertEqual(np.array(out)[0], target) +class TestNotEmpty(TestEmpty): + def setUp(self): + self.op_type = "is_empty" + self.inputs = {'X': np.array([])} + self.outputs = {'Out': np.array([True])} if __name__ == "__main__":