未验证 提交 9ff6715f 编写于 作者: Q Qingsheng Li 提交者: GitHub

Enhanced is_empty_op for our seq2seq model (#10704)

* Added kernel to is_empty_op

* Added python API

* Updated code as required

* Updated unittests
上级 5828101c
......@@ -12,45 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/is_empty_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
constexpr char kInput[] = "X";
constexpr char kOutput[] = "Out";
class IsEmptyOp : public framework::OperatorBase {
class IsEmptyOp : public framework::OperatorWithKernel {
public:
IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
using framework::OperatorWithKernel::OperatorWithKernel;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get input
auto *var = scope.FindVar(Input(kInput));
PADDLE_ENFORCE_NOT_NULL(var);
auto &tensor = var->Get<framework::LoDTensor>();
// get output
auto *out = scope.FindVar(Output(kOutput));
PADDLE_ENFORCE_NOT_NULL(out);
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of IsEmptyOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of IsEmptyOp should not be null.");
ctx->SetOutputDim("Out", {1});
}
out_tensor->Resize({1});
out_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
framework::product(tensor.dims()) == 0;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
platform::CPUPlace());
return kt;
}
};
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
class IsEmptyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
AddInput("X", "(LoDTensor) Tensor which is to be checked.");
AddOutput("Out",
"(LoDTensor) a boolean Tensor that indicate empty or not.");
AddComment(R"DOC(
IsEmpty Operator which checks whether a tensor is empty.
......@@ -62,5 +58,12 @@ It will just return product(tensor.ddims()) > 0;
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp,
paddle::operators::IsEmptyOpProtoMaker);
namespace ops = paddle::operators;
REGISTER_OPERATOR(is_empty, ops::IsEmptyOp, ops::IsEmptyOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
is_empty, ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class IsEmptyOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// get input
auto* input_tensor = context.Input<framework::LoDTensor>("X");
// get output
auto* output_tensor = context.Output<framework::LoDTensor>("Out");
output_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
framework::product(input_tensor->dims()) == 0;
}
};
} // namespace operators
} // namespace paddle
......@@ -49,6 +49,7 @@ __all__ = [
'reorder_lod_tensor_by_rank',
'ParallelDo',
'Print',
'is_empty',
]
......@@ -1562,3 +1563,40 @@ def reorder_lod_tensor_by_rank(x, rank_table):
'RankTable': [rank_table]},
outputs={'Out': [out]})
return out
def is_empty(x, cond=None, **ignored):
"""
**Is Empty**
This layer returns the truth value of whether the variable is empty.
Args:
x(Variable): Operand of *is_empty*
cond(Variable|None): Optional output variable to store the result
of *is_empty*
Returns:
Variable: The tensor variable storing the output of *is_empty*.
Raises:
TypeError: If input cond is not a variable, or cond's dtype is
not bool
Examples:
.. code-block:: python
less = fluid.layers.is_empty(x=input)
"""
helper = LayerHelper("is_empty", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
cond.stop_gradient = True
elif not isinstance(cond, Variable):
raise TypeError("cond takes a variable")
elif cond.dtype != 'bool':
raise TypeError("The data type of cond must be bool")
helper.append_op(
type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]})
return cond
......@@ -14,42 +14,24 @@
import unittest
import numpy as np
from paddle.fluid.op import Operator
import paddle.fluid.core as core
from op_test import OpTest
def create_tensor(scope, name, np_data):
tensor = scope.var(name).get_tensor()
tensor.set_dims(np_data.shape)
tensor.set(np_data, core.CPUPlace())
return tensor
class TestIsEmptyOp(unittest.TestCase):
class TestEmpty(OpTest):
def setUp(self):
self.scope = core.Scope()
# create input variables
np_data0 = np.array([0, 1, 2])
create_tensor(self.scope, "X0", np_data0)
np_data1 = np.array([1])
t = create_tensor(self.scope, "X1", np_data1)
t.set_dims([0])
self.op_type = "is_empty"
self.inputs = {'X': np.array([1, 2, 3])}
self.outputs = {'Out': np.array([False])}
# create output variables
self.scope.var("out")
def test_check_output(self):
self.check_output()
def test_no_empty(self):
self.one_case("X0", False)
def test_empty(self):
self.one_case("X1", True)
def one_case(self, input, target):
op = Operator(type="is_empty", X=input, Out="out")
op.run(self.scope, core.CPUPlace())
out = self.scope.var("out").get_tensor()
self.assertEqual(np.array(out)[0], target)
class TestNotEmpty(TestEmpty):
def setUp(self):
self.op_type = "is_empty"
self.inputs = {'X': np.array([])}
self.outputs = {'Out': np.array([True])}
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册