提交 1cbffbc4 编写于 作者: S sneaxiy

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into make_flag_adding_easier

......@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
"The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups));
PADDLE_ENFORCE_EQ(
channel_num % groups, 0,
platform::errors::InvalidArgument(
"Expected number of channels in input to be divisible by "
"num_groups, but got input channel is %d and num_groups is %d",
channel_num, groups));
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(
......
......@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
......@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]);
......
......@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
......@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
......@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) {
// TODO(gaoxiang)Because AVX/AVX2/AVX512 can not directly used
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
......@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int group_size = (C - 1) / groups + 1;
const int group_size = C / groups;
d_x->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
......
......@@ -21,12 +21,12 @@ namespace operators {
template <typename DeviceContext, typename T>
class IndexSelectNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* index = ctx.Input<Tensor>("Index");
auto dim = ctx.Attr<int>("dim");
auto *out = ctx.Output<Tensor>("Out");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto stream =
......@@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel<T> {
}
};
// todo: add class 'IndexSelectGradNPUKernel' here.
template <typename DeviceContext, typename T>
class IndexSelectGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* index = ctx.Input<Tensor>("Index");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
auto x_dims = x_grad->dims();
auto out_dims = out_grad->dims();
int dim = ctx.Attr<int>("dim");
if (dim < 0) {
dim += out_dims.size();
}
Tensor casted_index;
if (index->type() != framework::proto::VarType::INT32) {
casted_index.mutable_data<int32_t>(index->dims(), ctx.GetPlace());
const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index},
{{"dst_type", ACL_INT32}});
cast_runner.Run(stream);
} else {
casted_index.ShareDataWith(*index);
}
if (dim == 0) {
x_grad->mutable_data<T>(ctx.GetPlace());
const auto& zeros_runner = NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad});
zeros_runner.Run(stream);
NpuOpRunner runner;
runner.SetType("UnsortedSegmentSum")
.AddInput(*out_grad)
.AddInput(casted_index)
.AddInput(std::vector<int64_t>{x_dims[dim]})
.AddOutput(*x_grad);
runner.Run(stream);
} else {
Tensor transed_out_grad;
std::vector<int> in_trans_perm;
in_trans_perm.push_back(dim);
for (int i = 0; i < out_dims.size(); ++i) {
if (i == dim) continue;
in_trans_perm.push_back(i);
}
framework::DDim transed_out_dims(out_dims);
for (size_t i = 0; i < in_trans_perm.size(); ++i) {
transed_out_dims[i] = out_dims[in_trans_perm[i]];
}
transed_out_grad.mutable_data<T>(transed_out_dims, ctx.GetPlace());
framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}};
const auto& in_trans_runner = NpuOpRunner(
"TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr);
in_trans_runner.Run(stream);
Tensor sum_out;
framework::DDim sum_dims(x_dims);
sum_dims[0] = x_dims[dim];
auto idx = 1;
for (int i = 0; i < x_dims.size(); ++i) {
if (i == dim) continue;
sum_dims[idx++] = x_dims[i];
}
sum_out.mutable_data<T>(sum_dims, ctx.GetPlace());
const auto& zeros_runner = NpuOpRunner("ZerosLike", {sum_out}, {sum_out});
zeros_runner.Run(stream);
NpuOpRunner runner;
runner.SetType("UnsortedSegmentSum")
.AddInput(transed_out_grad)
.AddInput(casted_index)
.AddInput(std::vector<int64_t>{x_dims[dim]})
.AddOutput(sum_out);
runner.Run(stream);
std::vector<int> out_trans_perm;
for (int i = 1; i < 1 + dim; ++i) {
out_trans_perm.push_back(i);
}
out_trans_perm.push_back(0);
for (int i = 1 + dim; i < x_dims.size(); ++i) {
out_trans_perm.push_back(i);
}
framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}};
x_grad->mutable_data<T>(ctx.GetPlace());
const auto& out_trans_runner =
NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr);
out_trans_runner.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL(
ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);
// todo: register npu index_select_grad kernel here.
REGISTER_OP_NPU_KERNEL(
index_select_grad,
ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);
......@@ -142,32 +142,103 @@ class GradientClipHelper(object):
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_ids):
def sync_global_norm(self, block, ring_ids, mp_rank):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
# FIXME(wangxi): mp should prune duplicated param_grads
is_clip_grad_by_global_norm = False
for idx, op in list(enumerate(block.ops)):
if not self._is_gradient_clip_op(op):
continue
if op.type == 'sum':
is_clip_grad_by_global_norm = True
break
if not is_clip_grad_by_global_norm:
# TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
return
removed_op_idx = set()
removed_tmp_var = set()
for idx, op in list(enumerate(block.ops)):
if not self._is_gradient_clip_op(op):
continue
if op.type == 'sum':
break
for input_name in op.input_arg_names:
input_var = block.var(input_name)
# NOTE: when mp_degree > 1, some vars will be split into each mp rank.
# However, there still some vars such as Scale, Bias are not split.
# Those not be split vars should only be counted once during grad clip
# by global norm. Those vars either doesn't have is_distributed attr
# or the is_distributed attr has been set as False.
# Therefore, we prune those duplicated vars for grad clip.
if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
and input_var.is_distributed)):
removed_op_idx.add(idx)
for output_name in op.output_arg_names:
removed_tmp_var.add(output_name)
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if idx in removed_op_idx:
block._remove_op(idx, sync=False)
if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
for ring_id in ring_ids:
if ring_id == -1: continue
for var_name in removed_tmp_var:
block._remove_var(var_name, sync=False)
idx = idx + 1
block._insert_op_without_sync(
idx,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
return
for idx, op in list(enumerate(block.ops)):
if not self._is_gradient_clip_op(op):
continue
if op.type == 'sum':
# If mp_rank == 0, no extra handles, just allreduce
# If mp_rank >= 1, some extra handles is needed
sum_rst_var = block.var(op.output_arg_names[0])
if mp_rank >= 1:
reserved_vars = []
for input_name in op.input_arg_names:
if input_name not in removed_tmp_var:
reserved_vars.append(input_name)
if len(reserved_vars) > 0:
op.desc.set_input("X", reserved_vars)
else:
# If all input of sum op should be removed, then remove the sum op.
# And set the output's value of sum to 0.
namescope = op.attr("op_namescope")
block._remove_op(idx, sync=False)
fill_constant_op = block._insert_op_without_sync(
idx,
type='fill_constant',
inputs={},
outputs={'Out': sum_rst_var},
attrs={
'shape': sum_rst_var.shape,
'dtype': sum_rst_var.dtype,
'value': 0.0,
OP_ROLE_KEY: OpRole.Optimize
})
fill_constant_op._set_attr('op_namescope', namescope)
self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
break
@staticmethod
def _insert_allreduce(block, ring_ids, idx, var):
for ring_id in ring_ids:
if ring_id == -1:
continue
idx = idx + 1
block._insert_op_without_sync(
idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
......@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
rings = [self.mp_ring_id, self.pp_ring_id]
......@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id])
main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)
def _insert_loss_grad_scale_op(self):
main_block = self._main_program.global_block()
......
......@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
persistable=source_var.persistable)
else:
dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient
self._clone_var_attr(dest_var, source_var)
# When use with sharding, allreduce_sum and allreduce_max
# used for global gradient clip and amp will be added by sharding.
op_idx += 1
......@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
persistable=ref_var.persistable,
is_data=ref_var.is_data,
need_check_feed=ref_var.desc.need_check_feed())
new_var.stop_gradient = ref_var.stop_gradient
self._clone_var_attr(new_var, ref_var)
return new_var
def _clone_var_attr(self, dest, src):
dest.stop_gradient = src.stop_gradient
if hasattr(src, 'is_distributed'):
dest.is_distributed = src.is_distributed
def _strip_grad_suffix(self, name):
"""
Strip the grad suffix from the given variable name
......@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
persistable=True,
stop_gradient=False)
real_param = main_block.var(param)
if hasattr(real_param, 'is_distributed'):
merged_grad_var.is_distributed = real_param.is_distributed
tmp_size = self._get_var_size(real_grad)
# two strategies for splitting the grad
# 1. the current segment's size reach the user defined grad_size_in_MB
......
......@@ -35,7 +35,10 @@ class TestNPUIndexSelect(OpTest):
x_np = np.random.random(self.x_shape).astype(self.x_type)
index_np = np.random.randint(
low=0, high=self.x_shape[self.dim], size=self.index_size)
low=0,
high=self.x_shape[self.dim],
size=self.index_size,
dtype=self.index_type)
# compute real output as baseline.
outer_loop = np.prod(self.x_shape[:self.dim])
......@@ -56,18 +59,14 @@ class TestNPUIndexSelect(OpTest):
self.attrs = {'dim': self.dim}
self.outputs = {'Out': out}
# todo: comment second line when index_select grad npu op is ready.
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def test_check_output(self):
self.check_output_with_place(self.place)
# todo: replace first line with second line when index_select grad npu op is ready.
def test_check_grad(self):
pass
#self.check_grad_with_place(self.place, ['X'], 'Out')
self.check_grad_with_place(self.place, ['X'], 'Out')
def config(self):
self.x_shape = (100, 4, 5)
......@@ -86,6 +85,24 @@ class TestNPUIndexSelectCase2(TestNPUIndexSelect):
self.index_size = 10
class TestNPUIndexSelectCase3(TestNPUIndexSelect):
def config(self):
self.dim = 0
self.x_type = np.float32
self.index_type = np.int32
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestNPUIndexSelectCase4(TestNPUIndexSelect):
def config(self):
self.dim = -1
self.x_type = np.float32
self.index_type = np.int32
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestNPUIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
......
......@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init'
])
self.assertEqual(main_prog_op_types, [
'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
'momentum', 'momentum'
])
# pp + mp, partial send recv
self.assertIn('partial_recv', main_prog_op_types)
self.assertIn('partial_allgather', main_prog_op_types)
......
......@@ -15,8 +15,11 @@
from __future__ import print_function
import unittest
import numpy as np
import sys
import numpy as np
import paddle
from op_test import OpTest
......@@ -198,5 +201,62 @@ class TestSegmentMean2(TestSegmentMean):
self.attrs = {'pooltype': "MEAN"}
class API_SegmentOpsTest(unittest.TestCase):
def test_static(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
y = paddle.static.data(name='y', shape=[3], dtype='int32')
res_sum = paddle.incubate.segment_sum(x, y)
res_mean = paddle.incubate.segment_mean(x, y)
res_max = paddle.incubate.segment_max(x, y)
res_min = paddle.incubate.segment_min(x, y)
exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
data2 = np.array([0, 0, 1], dtype="int32")
np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
ret = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res_sum, res_mean, res_max, res_min])
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
def test_dygraph(self):
device = paddle.CPUPlace()
with paddle.fluid.dygraph.guard(device):
x = paddle.to_tensor(
[[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
y = paddle.to_tensor([0, 0, 1], dtype="int32")
res_sum = paddle.incubate.segment_sum(x, y)
res_mean = paddle.incubate.segment_mean(x, y)
res_max = paddle.incubate.segment_max(x, y)
res_min = paddle.incubate.segment_min(x, y)
np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
ret = [res_sum, res_mean, res_max, res_min]
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res.numpy(), atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,18 @@ from .checkpoint import auto_checkpoint # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401
from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
from .operators import softmax_mask_fuse # noqa: F401
from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
__all__ = [ # noqa
'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse'
__all__ = [
'LookAhead',
'ModelAverage',
'softmax_mask_fuse_upper_triangle',
'softmax_mask_fuse',
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .math import segment_sum
from .math import segment_mean
from .math import segment_max
from .math import segment_min
__all__ = [
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'segment_sum',
'segment_mean',
'segment_max',
'segment_min',
]
import paddle
from paddle.fluid.layer_helper import LayerHelper, in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops
def segment_sum(data, segment_ids, name=None):
"""
Segment Sum Operator.
This operator sums the elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\sum_{j} data_{j}$
where sum is over j such that `segment_ids[j] == i`.
Args:
data (Tensor): A tensor, available data type float32, float64.
segment_ids (Tensor): A 1-D tensor, which have the same size
with the first dimension of input data.
Available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_sum(data, segment_ids)
#Outputs: [[4., 4., 4.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_sum", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "SUM"})
return out
def segment_mean(data, segment_ids, name=None):
"""
Segment mean Operator.
Ihis operator calculate the mean value of input `data` which
with the same index in `segment_ids`.
It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$
where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number
of all index 'segment_ids[j] == i'.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_mean(data, segment_ids)
#Outputs: [[2., 2., 2.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_mean", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MEAN"})
return out
def segment_min(data, segment_ids, name=None):
"""
Segment min operator.
This operator calculate the minimum elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\min_{j} data_{j}$
where min is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_min(data, segment_ids)
#Outputs: [[1., 2., 1.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_min", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MIN"})
return out
def segment_max(data, segment_ids, name=None):
"""
Segment max operator.
This operator calculate the maximum elements of input `data` which with
the same index in `segment_ids`.
It computes a tensor such that $out_i = \\min_{j} data_{j}$
where max is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
Returns:
output (Tensor): the reduced result.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
out = paddle.incubate.segment_max(data, segment_ids)
#Outputs: [[3., 2., 3.], [4., 5., 6.]]
"""
if in_dygraph_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
helper = LayerHelper("segment_max", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MAX"})
return out
......@@ -162,6 +162,7 @@ packages=['paddle',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
'paddle.incubate.operators',
'paddle.incubate.tensor',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册