提交 67b48d7f 编写于 作者: Z zhoukunsheng 提交者: Tao Luo

add size op (#17412)

上级 6e0df310
...@@ -222,6 +222,7 @@ paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaul ...@@ -222,6 +222,7 @@ paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaul
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '3ca6a761570d86e303e473afba99bb49')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '3ca6a761570d86e303e473afba99bb49'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3')) paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))
paddle.fluid.layers.logical_and (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d6777f61831c54bea3a0029e2118448')) paddle.fluid.layers.logical_and (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d6777f61831c54bea3a0029e2118448'))
paddle.fluid.layers.logical_or (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '4d51a5a453755e0eb8c5ff6910a00dca')) paddle.fluid.layers.logical_or (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '4d51a5a453755e0eb8c5ff6910a00dca'))
paddle.fluid.layers.logical_xor (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1840f54c5bd5338bdf854980d47bf771')) paddle.fluid.layers.logical_xor (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1840f54c5bd5338bdf854980d47bf771'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/size_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class SizeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of Size op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output (Out) of Size op should not be null.");
ctx->SetOutputDim("Out", {1});
}
};
class SizeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor.");
AddOutput("Out",
"The returned tensor, the data type "
"is int64_t, will be on the same device with the input Tensor.");
AddComment(R"DOC(
Size Operator.
Return the number of elements in the input.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int32_t>,
ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/size_op.h"
REGISTER_OP_CUDA_KERNEL(size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int32_t>,
paddle::operators::SizeKernel<float>,
paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SizeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
out_data[0] = in_t->numel();
}
};
} // namespace operators
} // namespace paddle
...@@ -165,6 +165,7 @@ __all__ = [ ...@@ -165,6 +165,7 @@ __all__ = [
'slice', 'slice',
'shape', 'shape',
'rank', 'rank',
'size',
'logical_and', 'logical_and',
'logical_or', 'logical_or',
'logical_xor', 'logical_xor',
...@@ -9891,6 +9892,35 @@ def rank(input): ...@@ -9891,6 +9892,35 @@ def rank(input):
return out return out
def size(input):
"""
**Size Layer**
Returns the number of elements for a tensor, which is a int64 Tensor with shape [1].
Args:
input (Variable): The input variable.
Returns:
Variable: The number of elements for the input variable.
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
input = layers.data(
name="input", shape=[3, 100], dtype="float32", append_batch_size=False)
rank = layers.size(input) # 300
"""
helper = LayerHelper('size', **locals())
out = helper.create_variable_for_type_inference(dtype='int64')
helper.append_op(type='size', inputs={'Input': input}, outputs={'Out': out})
return out
def _elementwise_op(helper): def _elementwise_op(helper):
op_type = helper.layer_type op_type = helper.layer_type
x = helper.kwargs.get('x', None) x = helper.kwargs.get('x', None)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestSizeOp(OpTest):
def setUp(self):
self.op_type = "size"
self.shape = []
self.config()
input = np.zeros(self.shape, dtype='bool')
self.inputs = {'Input': input}
self.outputs = {'Out': np.array([np.size(input)], dtype='int64')}
def config(self):
pass
def test_check_output(self):
self.check_output()
class TestRank1Tensor(TestSizeOp):
def config(self):
self.shape = [2]
class TestRank2Tensor(TestSizeOp):
def config(self):
self.shape = [2, 3]
class TestRank3Tensor(TestSizeOp):
def config(self):
self.shape = [2, 3, 100]
class TestLargeTensor(TestSizeOp):
def config(self):
self.shape = [2**10]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册