未验证 提交 4b683887 编写于 作者: Z Zhong Hui 提交者: GitHub

Add segment apis to paddle.incubate (#35759)

上级 f218330e
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np
import sys import sys
import numpy as np
import paddle
from op_test import OpTest from op_test import OpTest
...@@ -198,5 +201,62 @@ class TestSegmentMean2(TestSegmentMean): ...@@ -198,5 +201,62 @@ class TestSegmentMean2(TestSegmentMean):
self.attrs = {'pooltype': "MEAN"} self.attrs = {'pooltype': "MEAN"}
class API_SegmentOpsTest(unittest.TestCase):
def test_static(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
y = paddle.static.data(name='y', shape=[3], dtype='int32')
res_sum = paddle.incubate.segment_sum(x, y)
res_mean = paddle.incubate.segment_mean(x, y)
res_max = paddle.incubate.segment_max(x, y)
res_min = paddle.incubate.segment_min(x, y)
exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
data2 = np.array([0, 0, 1], dtype="int32")
np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
ret = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res_sum, res_mean, res_max, res_min])
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
def test_dygraph(self):
device = paddle.CPUPlace()
with paddle.fluid.dygraph.guard(device):
x = paddle.to_tensor(
[[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
y = paddle.to_tensor([0, 0, 1], dtype="int32")
res_sum = paddle.incubate.segment_sum(x, y)
res_mean = paddle.incubate.segment_mean(x, y)
res_max = paddle.incubate.segment_max(x, y)
res_min = paddle.incubate.segment_min(x, y)
np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
ret = [res_sum, res_mean, res_max, res_min]
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res.numpy(), atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -18,7 +18,18 @@ from .checkpoint import auto_checkpoint # noqa: F401 ...@@ -18,7 +18,18 @@ from .checkpoint import auto_checkpoint # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401
from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
from .operators import softmax_mask_fuse # noqa: F401 from .operators import softmax_mask_fuse # noqa: F401
from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
__all__ = [ # noqa __all__ = [
'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse' 'LookAhead',
'ModelAverage',
'softmax_mask_fuse_upper_triangle',
'softmax_mask_fuse',
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
] ]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .math import segment_sum
from .math import segment_mean
from .math import segment_max
from .math import segment_min
__all__ = [
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
]
import paddle
from paddle.fluid.layer_helper import LayerHelper, in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops
def segment_sum(data, segment_ids, name=None):
"""
Segment Sum Operator.
This operator sums the elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\sum_{j} data_{j}$
where sum is over j such that `segment_ids[j] == i`.
Args:
data (Tensor): A tensor, available data type float32, float64.
segment_ids (Tensor): A 1-D tensor, which have the same size
with the first dimension of input data.
Available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_sum(data, segment_ids)
#Outputs: [[4., 4., 4.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_sum", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "SUM"})
return out
def segment_mean(data, segment_ids, name=None):
"""
Segment mean Operator.
Ihis operator calculate the mean value of input `data` which
with the same index in `segment_ids`.
It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$
where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number
of all index 'segment_ids[j] == i'.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_mean(data, segment_ids)
#Outputs: [[2., 2., 2.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_mean", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MEAN"})
return out
def segment_min(data, segment_ids, name=None):
"""
Segment min operator.
This operator calculate the minimum elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\min_{j} data_{j}$
where min is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_min(data, segment_ids)
#Outputs: [[1., 2., 1.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_min", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MIN"})
return out
def segment_max(data, segment_ids, name=None):
"""
Segment max operator.
This operator calculate the maximum elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\min_{j} data_{j}$
where max is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_max(data, segment_ids)
#Outputs: [[3., 2., 3.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_max", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MAX"})
return out
...@@ -162,6 +162,7 @@ packages=['paddle', ...@@ -162,6 +162,7 @@ packages=['paddle',
'paddle.incubate.optimizer', 'paddle.incubate.optimizer',
'paddle.incubate.checkpoint', 'paddle.incubate.checkpoint',
'paddle.incubate.operators', 'paddle.incubate.operators',
'paddle.incubate.tensor',
'paddle.distributed.fleet', 'paddle.distributed.fleet',
'paddle.distributed.fleet.base', 'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.elastic',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册