Commit 1cbffbc4

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into make_flag_adding_easier

Authored on Sep 16, 2021 by sneaxiy
Parents: ca0136a6, a4eadd15
Showing 14 changed files with 595 additions and 44 deletions (+595 / -44).
Changed files:
  paddle/fluid/operators/group_norm_op.cc                                            +6    -0
  paddle/fluid/operators/group_norm_op.cu                                            +3    -2
  paddle/fluid/operators/group_norm_op.h                                             +4    -4
  paddle/fluid/operators/index_select_op_npu.cc                                      +107  -6
  python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py   +90   -19
  python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py              +1    -2
  python/paddle/fluid/optimizer.py                                                   +9    -2
  python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py                +23   -6
  python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py          +27   -0
  python/paddle/fluid/tests/unittests/test_segment_ops.py                            +61   -1
  python/paddle/incubate/__init__.py                                                 +13   -2
  python/paddle/incubate/tensor/__init__.py (new file)                               +25   -0
  python/paddle/incubate/tensor/math.py (new file)                                   +225  -0
  python/setup.py.in                                                                 +1    -0
paddle/fluid/operators/group_norm_op.cc (+6 -0)

@@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
             "The Attr(groups) of Op(group_norm) must be "
             "greater than or equal to 1. But received: groups is [%s].",
             groups));
+    PADDLE_ENFORCE_EQ(
+        channel_num % groups, 0,
+        platform::errors::InvalidArgument(
+            "Expected number of channels in input to be divisible by "
+            "num_groups, but got input channel is %d and num_groups is %d",
+            channel_num, groups));

     if (ctx->HasInput("Scale")) {
       PADDLE_ENFORCE_EQ(
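For context, the group_norm changes in this commit work together: the operator now requires that `groups` divides the channel count exactly, which lets the kernels compute the group size as `C / groups` instead of `(C - 1) / groups + 1`. A minimal Python sketch of that constraint (the helper name is hypothetical, used only for illustration):

    # Hedged sketch mirroring the C++ check added above; not part of the diff itself.
    def group_norm_group_size(channel_num: int, groups: int) -> int:
        # The new PADDLE_ENFORCE_EQ requires channels to be divisible by groups.
        if channel_num % groups != 0:
            raise ValueError(
                "Expected number of channels in input to be divisible by "
                "num_groups, but got input channel is %d and num_groups is %d"
                % (channel_num, groups))
        # With divisibility guaranteed, (C - 1) // groups + 1 equals C // groups,
        # which is why the kernels below can switch to the simpler expression.
        return channel_num // groups

    assert group_norm_group_size(32, 8) == 4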
paddle/fluid/operators/group_norm_op.cu (+3 -2)

@@ -144,7 +144,8 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
     const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
                                                      : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                                      : x_dims[x_dims.size() - 2]);

@@ -314,7 +315,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
     const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
                                                      : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                                      : x_dims[x_dims.size() - 2]);
paddle/fluid/operators/group_norm_op.h (+4 -4)

@@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
                                                      : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     y->mutable_data<T>(ctx.GetPlace());
     mean->mutable_data<T>(ctx.GetPlace());

@@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     int imid;
     for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M) {
-      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
+      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
       // in template class/function, before we complete high
       // performance cpu vector extension, temporarily unrolling
       // loop to get high precision and performance

@@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel<T> {
     int imid;
     for (imid = 0; imid < imsize - (imsize % M);
          imid += M, iter_x_data += M * C) {
-      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
+      // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
       // in template class/function, before we complete high
       // performance cpu vector extension, temporarily unrolling
       // loop to get high precision and performance

@@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
     const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
                                                      : x_dims[x_dims.size() - 1]);
-    const int group_size = (C - 1) / groups + 1;
+    const int group_size = C / groups;
     d_x->mutable_data<T>(ctx.GetPlace());
     math::SetConstant<DeviceContext, T> set_zero;
paddle/fluid/operators/index_select_op_npu.cc (+107 -6)

@@ -21,12 +21,12 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class IndexSelectNPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* index = ctx.Input<Tensor>("Index");
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* index = ctx.Input<Tensor>("Index");
     auto dim = ctx.Attr<int>("dim");
-    auto* out = ctx.Output<Tensor>("Out");
+    auto* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());

     auto stream =

@@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel<T> {
   }
 };

-// todo: add class 'IndexSelectGradNPUKernel' here.
+template <typename DeviceContext, typename T>
+class IndexSelectGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* index = ctx.Input<Tensor>("Index");
+    auto* out_grad = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    auto x_dims = x_grad->dims();
+    auto out_dims = out_grad->dims();
+
+    int dim = ctx.Attr<int>("dim");
+    if (dim < 0) {
+      dim += out_dims.size();
+    }
+
+    Tensor casted_index;
+    if (index->type() != framework::proto::VarType::INT32) {
+      casted_index.mutable_data<int32_t>(index->dims(), ctx.GetPlace());
+      const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index},
+                                            {{"dst_type", ACL_INT32}});
+      cast_runner.Run(stream);
+    } else {
+      casted_index.ShareDataWith(*index);
+    }
+
+    if (dim == 0) {
+      x_grad->mutable_data<T>(ctx.GetPlace());
+      const auto& zeros_runner = NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad});
+      zeros_runner.Run(stream);
+
+      NpuOpRunner runner;
+      runner.SetType("UnsortedSegmentSum")
+          .AddInput(*out_grad)
+          .AddInput(casted_index)
+          .AddInput(std::vector<int64_t>{x_dims[dim]})
+          .AddOutput(*x_grad);
+      runner.Run(stream);
+    } else {
+      Tensor transed_out_grad;
+      std::vector<int> in_trans_perm;
+      in_trans_perm.push_back(dim);
+      for (int i = 0; i < out_dims.size(); ++i) {
+        if (i == dim) continue;
+        in_trans_perm.push_back(i);
+      }
+      framework::DDim transed_out_dims(out_dims);
+      for (size_t i = 0; i < in_trans_perm.size(); ++i) {
+        transed_out_dims[i] = out_dims[in_trans_perm[i]];
+      }
+      transed_out_grad.mutable_data<T>(transed_out_dims, ctx.GetPlace());
+      framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}};
+      const auto& in_trans_runner = NpuOpRunner(
+          "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr);
+      in_trans_runner.Run(stream);
+
+      Tensor sum_out;
+      framework::DDim sum_dims(x_dims);
+      sum_dims[0] = x_dims[dim];
+      auto idx = 1;
+      for (int i = 0; i < x_dims.size(); ++i) {
+        if (i == dim) continue;
+        sum_dims[idx++] = x_dims[i];
+      }
+      sum_out.mutable_data<T>(sum_dims, ctx.GetPlace());
+      const auto& zeros_runner = NpuOpRunner("ZerosLike", {sum_out}, {sum_out});
+      zeros_runner.Run(stream);
+
+      NpuOpRunner runner;
+      runner.SetType("UnsortedSegmentSum")
+          .AddInput(transed_out_grad)
+          .AddInput(casted_index)
+          .AddInput(std::vector<int64_t>{x_dims[dim]})
+          .AddOutput(sum_out);
+      runner.Run(stream);
+
+      std::vector<int> out_trans_perm;
+      for (int i = 1; i < 1 + dim; ++i) {
+        out_trans_perm.push_back(i);
+      }
+      out_trans_perm.push_back(0);
+      for (int i = 1 + dim; i < x_dims.size(); ++i) {
+        out_trans_perm.push_back(i);
+      }
+      framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}};
+      x_grad->mutable_data<T>(ctx.GetPlace());
+      const auto& out_trans_runner =
+          NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr);
+      out_trans_runner.Run(stream);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

@@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL(
     ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);
-// todo: register npu index_select_grad kernel here.
+REGISTER_OP_NPU_KERNEL(
+    index_select_grad,
+    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);
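As a rough reference for what the new IndexSelectGradNPUKernel computes: the gradient of index_select scatters each row of out_grad back to the row of X it was gathered from, summing duplicates, which is why the kernel is built from UnsortedSegmentSum (plus TransposeD when dim != 0). A NumPy sketch of the dim == 0 case, for illustration only (not the kernel itself):

    import numpy as np

    def index_select_grad_dim0(out_grad, index, x_rows):
        # Accumulate each out_grad row into the x row it was gathered from;
        # rows of x that were never selected receive zero gradient.
        x_grad = np.zeros((x_rows,) + out_grad.shape[1:], dtype=out_grad.dtype)
        np.add.at(x_grad, index, out_grad)
        return x_grad

    out_grad = np.ones((3, 2), dtype=np.float32)
    index = np.array([0, 0, 2])
    print(index_select_grad_dim0(out_grad, index, x_rows=4))
    # Row 0 gets 2.0 (selected twice), row 2 gets 1.0, rows 1 and 3 stay 0.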
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py (+90 -19)

@@ -142,32 +142,103 @@ class GradientClipHelper(object):
             return

     # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_ids):
+    def sync_global_norm(self, block, ring_ids, mp_rank):
         """
         prune gradient_clip related ops for params that not belong to cur shard
         prune: square, reduce_sum, elementwise_mul
         keep: sum, sqrt, elementwise_max, elementwise_div
         """
-        # FIXME(wangxi): mp should prune duplicated param_grads
+        is_clip_grad_by_global_norm = False
         for idx, op in list(enumerate(block.ops)):
             if not self._is_gradient_clip_op(op):
                 continue
-            if op.type == "sum":
-                sum_res = op.desc.output_arg_names()[0]
-                for ring_id in ring_ids:
-                    if ring_id == -1:
-                        continue
-
-                    idx = idx + 1
-                    block._insert_op_without_sync(
-                        idx,
-                        type='c_allreduce_sum',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'ring_id': ring_id,
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'use_calc_stream': True,
-                            OP_ROLE_KEY: OpRole.Optimize,
-                        })
-                return
+            if op.type == 'sum':
+                is_clip_grad_by_global_norm = True
+                break
+        if not is_clip_grad_by_global_norm:
+            # TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
+            return
+
+        removed_op_idx = set()
+        removed_tmp_var = set()
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                break
+            for input_name in op.input_arg_names:
+                input_var = block.var(input_name)
+                # NOTE: when mp_degree > 1, some vars will be split into each mp rank.
+                # However, there still some vars such as Scale, Bias are not split.
+                # Those not be split vars should only be counted once during grad clip
+                # by global norm. Those vars either doesn't have is_distributed attr
+                # or the is_distributed attr has been set as False.
+                # Therefore, we prune those duplicated vars for grad clip.
+                if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
+                                          and input_var.is_distributed)):
+                    removed_op_idx.add(idx)
+                    for output_name in op.output_arg_names:
+                        removed_tmp_var.add(output_name)
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if idx in removed_op_idx:
+                block._remove_op(idx, sync=False)
+
+        for var_name in removed_tmp_var:
+            block._remove_var(var_name, sync=False)
+
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                # If mp_rank == 0, no extra handles, just allreduce
+                # If mp_rank >= 1, some extra handles is needed
+                sum_rst_var = block.var(op.output_arg_names[0])
+                if mp_rank >= 1:
+                    reserved_vars = []
+                    for input_name in op.input_arg_names:
+                        if input_name not in removed_tmp_var:
+                            reserved_vars.append(input_name)
+
+                    if len(reserved_vars) > 0:
+                        op.desc.set_input("X", reserved_vars)
+                    else:
+                        # If all input of sum op should be removed, then remove the sum op.
+                        # And set the output's value of sum to 0.
+                        namescope = op.attr("op_namescope")
+                        block._remove_op(idx, sync=False)
+                        fill_constant_op = block._insert_op_without_sync(
+                            idx,
+                            type='fill_constant',
+                            inputs={},
+                            outputs={'Out': sum_rst_var},
+                            attrs={
+                                'shape': sum_rst_var.shape,
+                                'dtype': sum_rst_var.dtype,
+                                'value': 0.0,
+                                OP_ROLE_KEY: OpRole.Optimize
+                            })
+                        fill_constant_op._set_attr('op_namescope', namescope)
+                self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+                break
+
+    @staticmethod
+    def _insert_allreduce(block, ring_ids, idx, var):
+        for ring_id in ring_ids:
+            if ring_id == -1:
+                continue
+
+            idx = idx + 1
+            block._insert_op_without_sync(
+                idx,
+                type='c_allreduce_sum',
+                inputs={'X': var},
+                outputs={'Out': var},
+                attrs={
+                    'ring_id': ring_id,
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize,
+                })
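The core of the new sync_global_norm logic is its pruning criterion: on model-parallel ranks other than 0, clip-related ops whose input vars are not marked is_distributed are dropped, so variables replicated across mp ranks (for example Scale and Bias) contribute to the global norm only once before the c_allreduce_sum. A standalone sketch of that predicate (the helper and the _Var class are hypothetical, assuming vars expose an optional is_distributed attribute as in the diff):

    def should_prune_for_global_norm(input_var, mp_rank):
        # Rank 0 keeps every contribution; other mp ranks drop contributions
        # from variables that are replicated (not split) across the mp group.
        is_distributed = getattr(input_var, 'is_distributed', False)
        return mp_rank >= 1 and not is_distributed

    class _Var:
        def __init__(self, is_distributed=None):
            if is_distributed is not None:
                self.is_distributed = is_distributed

    assert should_prune_for_global_norm(_Var(), mp_rank=1) is True
    assert should_prune_for_global_norm(_Var(True), mp_rank=1) is False
    assert should_prune_for_global_norm(_Var(), mp_rank=0) is False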
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py (+1 -2)

@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
-
         # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
         rings = [self.mp_ring_id, self.pp_ring_id]

@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             gradientclip_helper = GradientClipHelper(None)
             gradientclip_helper.sync_global_norm(
-                main_block, [self.mp_ring_id, self.pp_ring_id])
+                main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)

     def _insert_loss_grad_scale_op(self):
         main_block = self._main_program.global_block()
python/paddle/fluid/optimizer.py (+9 -2)

@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
                 persistable=source_var.persistable)
         else:
             dest_var = block._clone_variable(source_var, False)

-        dest_var.stop_gradient = source_var.stop_gradient
+        self._clone_var_attr(dest_var, source_var)
         # When use with sharding, allreduce_sum and allreduce_max
         # used for global gradient clip and amp will be added by sharding.
         op_idx += 1

@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
             persistable=ref_var.persistable,
             is_data=ref_var.is_data,
             need_check_feed=ref_var.desc.need_check_feed())
-        new_var.stop_gradient = ref_var.stop_gradient
+        self._clone_var_attr(new_var, ref_var)
         return new_var

+    def _clone_var_attr(self, dest, src):
+        dest.stop_gradient = src.stop_gradient
+        if hasattr(src, 'is_distributed'):
+            dest.is_distributed = src.is_distributed
+
     def _strip_grad_suffix(self, name):
         """
         Strip the grad suffix from the given variable name

@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
             persistable=True,
             stop_gradient=False)
         real_param = main_block.var(param)
+        if hasattr(real_param, 'is_distributed'):
+            merged_grad_var.is_distributed = real_param.is_distributed
         tmp_size = self._get_var_size(real_grad)
         # two strategies for splitting the grad
         # 1. the current segment's size reach the user defined grad_size_in_MB
python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py (+23 -6)

@@ -35,7 +35,10 @@ class TestNPUIndexSelect(OpTest):
         x_np = np.random.random(self.x_shape).astype(self.x_type)
         index_np = np.random.randint(
-            low=0, high=self.x_shape[self.dim], size=self.index_size)
+            low=0,
+            high=self.x_shape[self.dim],
+            size=self.index_size,
+            dtype=self.index_type)

+        # compute real output as baseline.
         outer_loop = np.prod(self.x_shape[:self.dim])

@@ -56,18 +59,14 @@
         self.attrs = {'dim': self.dim}
         self.outputs = {'Out': out}

-    # todo: comment second line when index_select grad npu op is ready.
     def set_npu(self):
         self.__class__.use_npu = True
-        self.__class__.no_need_check_grad = True

     def test_check_output(self):
         self.check_output_with_place(self.place)

-    # todo: replace first line with second line when index_select grad npu op is ready.
     def test_check_grad(self):
-        pass
-        #self.check_grad_with_place(self.place, ['X'], 'Out')
+        self.check_grad_with_place(self.place, ['X'], 'Out')

     def config(self):
         self.x_shape = (100, 4, 5)

@@ -86,6 +85,24 @@ class TestNPUIndexSelectCase2(TestNPUIndexSelect):
         self.index_size = 10


+class TestNPUIndexSelectCase3(TestNPUIndexSelect):
+    def config(self):
+        self.dim = 0
+        self.x_type = np.float32
+        self.index_type = np.int32
+        self.x_shape = (10, 10, 4, 10)
+        self.index_size = 10
+
+
+class TestNPUIndexSelectCase4(TestNPUIndexSelect):
+    def config(self):
+        self.dim = -1
+        self.x_type = np.float32
+        self.index_type = np.int32
+        self.x_shape = (10, 10, 4, 10)
+        self.index_size = 10
+
+
 class TestNPUIndexSelectAPI(unittest.TestCase):
     def input_data(self):
         self.data_x = np.array([[1.0, 2.0, 3.0, 4.0],
                                 [5.0, 6.0, 7.0, 8.0],
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py (+27 -0)

@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_gen_nccl_id', 'c_comm_init'
         ])

+        self.assertEqual(main_prog_op_types, [
+            'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'momentum', 'momentum'
+        ])
+
         # pp + mp, partial send recv
         self.assertIn('partial_recv', main_prog_op_types)
         self.assertIn('partial_allgather', main_prog_op_types)
python/paddle/fluid/tests/unittests/test_segment_ops.py (+61 -1)

@@ -15,8 +15,11 @@
 from __future__ import print_function

 import unittest
-import numpy as np
+import sys
+
+import numpy as np
+import paddle

 from op_test import OpTest

@@ -198,5 +201,62 @@ class TestSegmentMean2(TestSegmentMean):
         self.attrs = {'pooltype': "MEAN"}


+class API_SegmentOpsTest(unittest.TestCase):
+    def test_static(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
+            y = paddle.static.data(name='y', shape=[3], dtype='int32')
+
+            res_sum = paddle.incubate.segment_sum(x, y)
+            res_mean = paddle.incubate.segment_mean(x, y)
+            res_max = paddle.incubate.segment_max(x, y)
+            res_min = paddle.incubate.segment_min(x, y)
+
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
+            data2 = np.array([0, 0, 1], dtype="int32")
+
+            np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
+            np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
+            np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
+            np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
+
+            ret = exe.run(feed={'x': data1,
+                                'y': data2},
+                          fetch_list=[res_sum, res_mean, res_max, res_min])
+
+        for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
+            self.assertTrue(
+                np.allclose(
+                    np_res, ret_res, atol=1e-6),
+                "two value is\
+                {}\n{}, check diff!".format(np_res, ret_res))
+
+    def test_dygraph(self):
+        device = paddle.CPUPlace()
+        with paddle.fluid.dygraph.guard(device):
+            x = paddle.to_tensor(
+                [[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
+            y = paddle.to_tensor([0, 0, 1], dtype="int32")
+            res_sum = paddle.incubate.segment_sum(x, y)
+            res_mean = paddle.incubate.segment_mean(x, y)
+            res_max = paddle.incubate.segment_max(x, y)
+            res_min = paddle.incubate.segment_min(x, y)
+
+            np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32")
+            np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32")
+            np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32")
+            np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32")
+
+            ret = [res_sum, res_mean, res_max, res_min]
+
+            for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
+                self.assertTrue(
+                    np.allclose(
+                        np_res, ret_res.numpy(), atol=1e-6),
+                    "two value is\
+                    {}\n{}, check diff!".format(np_res, ret_res))
+
+
 if __name__ == '__main__':
     unittest.main()
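The expected arrays in these tests follow the usual segment-reduction semantics: rows of x that share a segment id are reduced together. A NumPy reference shown here only for illustration (it is not the implementation used by the tests or the op):

    import numpy as np

    def segment_reduce(data, segment_ids, op):
        # Reduce consecutive rows that share a segment id; ids are assumed
        # to be sorted and to start from 0, matching the segment_pool contract.
        reducers = {'sum': np.sum, 'mean': np.mean, 'max': np.max, 'min': np.min}
        n_segments = int(segment_ids.max()) + 1
        out = np.zeros((n_segments,) + data.shape[1:], dtype=data.dtype)
        for i in range(n_segments):
            out[i] = reducers[op](data[segment_ids == i], axis=0)
        return out

    x = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
    ids = np.array([0, 0, 1])
    assert (segment_reduce(x, ids, 'sum') == [[4, 4, 4], [4, 5, 6]]).all()
    assert (segment_reduce(x, ids, 'max') == [[3, 2, 3], [4, 5, 6]]).all()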
python/paddle/incubate/__init__.py (+13 -2)

@@ -18,7 +18,18 @@ from .checkpoint import auto_checkpoint  # noqa: F401
 from ..fluid.layer_helper import LayerHelper  # noqa: F401
 from .operators import softmax_mask_fuse_upper_triangle  # noqa: F401
 from .operators import softmax_mask_fuse  # noqa: F401
+from .tensor import segment_sum
+from .tensor import segment_mean
+from .tensor import segment_max
+from .tensor import segment_min

-__all__ = [  # noqa
-    'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle',
-    'softmax_mask_fuse'
-]
+__all__ = [
+    'LookAhead',
+    'ModelAverage',
+    'softmax_mask_fuse_upper_triangle',
+    'softmax_mask_fuse',
+    'segment_sum',
+    'segment_mean',
+    'segment_max',
+    'segment_min',
+]
python/paddle/incubate/tensor/__init__.py (new file, +25 -0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .math import segment_sum
from .math import segment_mean
from .math import segment_max
from .math import segment_min

__all__ = [
    'segment_sum',
    'segment_mean',
    'segment_max',
    'segment_min',
]
python/paddle/incubate/tensor/math.py (new file, +225 -0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    'segment_sum',
    'segment_mean',
    'segment_max',
    'segment_min',
]

import paddle
from paddle.fluid.layer_helper import LayerHelper, in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops


def segment_sum(data, segment_ids, name=None):
    """
    Segment Sum Operator.

    This operator sums the elements of input `data` which have
    the same index in `segment_ids`.
    It computes a tensor such that $out_i = \\sum_{j} data_{j}$
    where sum is over j such that `segment_ids[j] == i`.

    Args:
        data (Tensor): A tensor, available data type float32, float64.
        segment_ids (Tensor): A 1-D tensor, which has the same size
                              as the first dimension of input data.
                              Available data type is int32, int64.

    Returns:
        output (Tensor): the reduced result.

    Examples:
        .. code-block:: python

            import paddle
            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
            out = paddle.incubate.segment_sum(data, segment_ids)
            # Outputs: [[4., 4., 4.], [4., 5., 6.]]

    """
    if in_dygraph_mode():
        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
        return out

    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                             "segment_pool")

    helper = LayerHelper("segment_sum", **locals())
    out = helper.create_variable_for_type_inference(dtype=data.dtype)
    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
    helper.append_op(
        type="segment_pool",
        inputs={"X": data,
                "SegmentIds": segment_ids},
        outputs={"Out": out,
                 "SummedIds": summed_ids},
        attrs={"pooltype": "SUM"})
    return out


def segment_mean(data, segment_ids, name=None):
    """
    Segment Mean Operator.

    This operator calculates the mean value of the elements of input `data`
    which have the same index in `segment_ids`.
    It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$
    where sum is over j such that `segment_ids[j] == i` and $n_i$ is the number
    of indices with `segment_ids[j] == i`.

    Args:
        data (Tensor): A tensor, available data type float32, float64.
        segment_ids (Tensor): A 1-D tensor, which has the same size
                              as the first dimension of input data.
                              Available data type is int32, int64.

    Returns:
        output (Tensor): the reduced result.

    Examples:
        .. code-block:: python

            import paddle
            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
            out = paddle.incubate.segment_mean(data, segment_ids)
            # Outputs: [[2., 2., 2.], [4., 5., 6.]]

    """
    if in_dygraph_mode():
        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
        return out

    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                             "segment_pool")

    helper = LayerHelper("segment_mean", **locals())
    out = helper.create_variable_for_type_inference(dtype=data.dtype)
    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
    helper.append_op(
        type="segment_pool",
        inputs={"X": data,
                "SegmentIds": segment_ids},
        outputs={"Out": out,
                 "SummedIds": summed_ids},
        attrs={"pooltype": "MEAN"})
    return out


def segment_min(data, segment_ids, name=None):
    """
    Segment Min Operator.

    This operator calculates the minimum elements of input `data` which have
    the same index in `segment_ids`.
    It computes a tensor such that $out_i = \\min_{j} data_{j}$
    where min is over j such that `segment_ids[j] == i`.

    Args:
        data (Tensor): A tensor, available data type float32, float64.
        segment_ids (Tensor): A 1-D tensor, which has the same size
                              as the first dimension of input data.
                              Available data type is int32, int64.

    Returns:
        output (Tensor): the reduced result.

    Examples:
        .. code-block:: python

            import paddle
            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
            out = paddle.incubate.segment_min(data, segment_ids)
            # Outputs: [[1., 2., 1.], [4., 5., 6.]]

    """
    if in_dygraph_mode():
        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
        return out

    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                             "segment_pool")

    helper = LayerHelper("segment_min", **locals())
    out = helper.create_variable_for_type_inference(dtype=data.dtype)
    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
    helper.append_op(
        type="segment_pool",
        inputs={"X": data,
                "SegmentIds": segment_ids},
        outputs={"Out": out,
                 "SummedIds": summed_ids},
        attrs={"pooltype": "MIN"})
    return out


def segment_max(data, segment_ids, name=None):
    """
    Segment Max Operator.

    This operator calculates the maximum elements of input `data` which have
    the same index in `segment_ids`.
    It computes a tensor such that $out_i = \\max_{j} data_{j}$
    where max is over j such that `segment_ids[j] == i`.

    Args:
        data (Tensor): A tensor, available data type float32, float64.
        segment_ids (Tensor): A 1-D tensor, which has the same size
                              as the first dimension of input data.
                              Available data type is int32, int64.

    Returns:
        output (Tensor): the reduced result.

    Examples:
        .. code-block:: python

            import paddle
            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
            out = paddle.incubate.segment_max(data, segment_ids)
            # Outputs: [[3., 2., 3.], [4., 5., 6.]]

    """
    if in_dygraph_mode():
        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
        return out

    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                             "segment_pool")

    helper = LayerHelper("segment_max", **locals())
    out = helper.create_variable_for_type_inference(dtype=data.dtype)
    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
    helper.append_op(
        type="segment_pool",
        inputs={"X": data,
                "SegmentIds": segment_ids},
        outputs={"Out": out,
                 "SummedIds": summed_ids},
        attrs={"pooltype": "MAX"})
    return out
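For quick reference, all four new incubate APIs share the same calling convention; a minimal dygraph usage sketch based on the docstring examples above:

    import paddle

    data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
    segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')

    # Each call reduces the rows of `data` that share a segment id.
    print(paddle.incubate.segment_sum(data, segment_ids))   # [[4., 4., 4.], [4., 5., 6.]]
    print(paddle.incubate.segment_mean(data, segment_ids))  # [[2., 2., 2.], [4., 5., 6.]]
    print(paddle.incubate.segment_max(data, segment_ids))   # [[3., 2., 3.], [4., 5., 6.]]
    print(paddle.incubate.segment_min(data, segment_ids))   # [[1., 2., 1.], [4., 5., 6.]]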
python/setup.py.in (+1 -0)

@@ -162,6 +162,7 @@ packages=['paddle',
           'paddle.incubate.optimizer',
           'paddle.incubate.checkpoint',
           'paddle.incubate.operators',
+          'paddle.incubate.tensor',
           'paddle.distributed.fleet',
           'paddle.distributed.fleet.base',
           'paddle.distributed.fleet.elastic',