Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e0e0c0fa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e0e0c0fa
编写于
3年前
作者:
Y
Yuang Liu
提交者:
GitHub
3年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sync calc stream and add ut for fuse on gpu (#33580)
上级
2d7ef7ad
develop
1.8.5
2.4.1
Ligoml-patch-1
OliverLPH-patch-1
OliverLPH-patch-2
PaddlePM-patch-1
PaddlePM-patch-2
ZHUI-patch-1
add_kylinv10
add_some_yaml_config
addfile
ascendrelease
bugfix-eval-frame-leakgae
cherry-pick-fix-customOP-random-fail
cherry_undefined_var
cp_2.4_fix_numpy
delete_delete_addfile
delete_disable_iterable_dataset_unittest
delete_fix_dataloader_memory_leak
delete_fix_retry_ci
delete_fix_undefined_var
delete_improve_sccache
delete_paralleltest
delete_prv-disable-more-cache
delete_revert-34159-add_npu_bce_logical_dev
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix-run-program-grad-node-mem
fix_check
fix_concat_slice
fix_custom_device_copy_sync
fix_dataloader_memory_leak
fix_dlpack_for
fix_newexe_gc
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fix_var_stop_gradient_error
fixiscan
fixiscan1
fixiscan2
fixiscan3
hack_event
improve_sccache
incuabte/new_frl
incubate/frl_train_eval
incubate/infrt
incubate/new_frl
incubate/new_frl_rc
incubate/stride
inplace_addto
layer_norm
make_flag_adding_easier
matmul_double_grad
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
operator_opt
paralleltest
pass-compile-eval-frame
preln_ernie
prv-disable-more-cache
prv-md-even-more
prv-onednn-2.5
prv-reshape-mkldnn-ut2
pten_tensor_refactor
release-deleted/2.5
release-rc/2.5
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
release/2.5
release/llm_2.5
revert-33475-fix_cifar_label_dimension
revert-34159-add_npu_bce_logical_dev
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
revert-47325-remove_cudnn_hardcode
revert-47645-add_npu_storage_dims
revert-48815-set_free_when_no_cache_hit_default_value_true
revert-49499-test_ninja_on_ci
revert-49654-prim_api_gen
revert-49673-modify_get_single_cov
revert-49763-fix_static_composite_gen
revert-50158-fix_found_inf_bug_for_custom_optimizer
revert-50188-refine_optimizer_create_accumulators
revert-50335-fix_optminizer_set_auxiliary_var_bug
revert-51676-flag_delete
revert-51850-fix_softmaxce_dev
revert-52175-dev_peak_memory
revert-52186-deve
revert-52523-test_py38
revert-52912-develop
revert-53248-set_cmake_policy
revert-54029-fix_windows_compile_bug
revert-54068-support_translating_op_attribute
revert-54214-modify_cmake_dependencies
revert-54370-offline_pslib
revert-54391-fix_cmake_md5error
revert-54411-fix_cpp17_compile
revert-54466-offline_pslib
revert-54480-cmake-rocksdb
revert-55568-fix_BF16_bug1
revert-56328-new_ir_support_vector_type_place_transfer
revert-56366-fix_openssl_bug
revert-56545-revert-56366-fix_openssl_bug
revert-56620-fix_new_ir_ocr_bug
revert-56925-check_inputs_grad_semantic
revert-57005-refine_stride_flag
sd_conv_linear_autocast
semi-auto/rule-base
support-0D-sort
support_weight_transpose
test_for_Filtetfiles
zhiqiu-patch-1
v2.5.1
v2.5.0
v2.5.0-rc1
v2.5.0-rc0
v2.4.2
v2.4.1
v2.4.0
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
无相关合并请求
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
291 addition
and
130 deletion
+291
-130
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+24
-0
python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
...istributed/fleet/meta_optimizers/raw_program_optimizer.py
+108
-130
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
...ttests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
+112
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py
...s/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py
+45
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
e0e0c0fa
...
...
@@ -177,6 +177,7 @@ message DistributedStrategy {
optional
bool
tensor_parallel
=
29
[
default
=
false
];
optional
bool
without_graph_optimization
=
30
[
default
=
false
];
optional
int32
fuse_grad_size_in_num
=
31
[
default
=
1
];
optional
bool
calc_comm_same_stream
=
32
[
default
=
false
];
optional
RecomputeConfig
recompute_configs
=
101
;
optional
AMPConfig
amp_configs
=
102
;
...
...
This diff is collapsed.
Click to expand it.
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
e0e0c0fa
...
...
@@ -853,6 +853,30 @@ class DistributedStrategy(object):
"WARNING: without_graph_optimization should have value of bool type"
)
@
property
def
_calc_comm_same_stream
(
self
):
"""
This based on raw_program_optimizer program
Set whether use same stream for calc and comm when fuse allreduce
The default value for the calc_comm_same_stream is False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.calc_comm_same_stream = True
"""
return
self
.
strategy
.
calc_comm_same_stream
@
_calc_comm_same_stream
.
setter
@
is_strict_auto
def
_calc_comm_same_stream
(
self
,
same
):
if
isinstance
(
same
,
bool
):
self
.
strategy
.
calc_comm_same_stream
=
same
else
:
print
(
"WARNING: calc_comm_same_stream should have value of boolean type"
)
@
property
def
fuse_grad_size_in_num
(
self
):
"""
...
...
This diff is collapsed.
Click to expand it.
python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
浏览文件 @
e0e0c0fa
...
...
@@ -44,6 +44,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
self
.
fuse_all_reduce_ops
=
user_defined_strategy
.
fuse_all_reduce_ops
if
self
.
fuse_all_reduce_ops
:
self
.
fuse_grad_size_in_num
=
user_defined_strategy
.
fuse_grad_size_in_num
self
.
calc_comm_same_stream
=
user_defined_strategy
.
_calc_comm_same_stream
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
...
...
@@ -130,8 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
def
_transpile_main_program
(
self
,
loss
):
self
.
_insert_loss_grad_ops
(
loss
)
if
self
.
fuse_all_reduce_ops
and
core
.
is_compiled_with_npu
():
self
.
_calc_stream
=
True
if
self
.
fuse_all_reduce_ops
:
self
.
_allreduce_fusion_program
()
else
:
self
.
_insert_allreduce_ops
()
...
...
@@ -206,22 +206,30 @@ class RawProgramOptimizer(MetaOptimizerBase):
OP_ROLE_KEY
:
OpRole
.
Backward
})
break
# TODO(Liu yuang): ADD CUDA allreduce_fusion fuction.
# This function helps reduce the input of allreduce by integrating can save communication time.
# This function helps reduce the number of allreduce by integrating op, which can save communication time.
# to use allreduce fuse, follow these codes:
# strategy = paddle.distributed.fleet.DistributedStrategy()
# strategy.without_graph_optimization = True
# strategy.fuse_all_reduce_ops = True
# strategy.calc_comm_same_stream = False
# strategy.fuse_grad_size_in_num = 8
def
_allreduce_fusion_program
(
self
):
block
=
self
.
main_program
.
global_block
()
ring_id
=
self
.
global_ring_id
record_idx
,
allreduce_input_vars
,
allreduce_output_vars
=
[],
[],
[]
block_ops
=
len
(
list
(
enumerate
(
block
.
ops
)
))
ops
=
list
(
enumerate
(
block
.
ops
))
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
ops
):
# we travers the ops reversely
if
is_backward_op
(
op
)
and
\
OP_ROLE_VAR_KEY
in
op
.
attr_names
:
op_role_var
=
op
.
attr
(
OP_ROLE_VAR_KEY
)
if
len
(
op_role_var
)
==
0
:
continue
assert
len
(
op_role_var
)
%
2
==
0
assert
len
(
op_role_var
)
%
2
==
0
,
"vars need to be one param var followed by one grad var, "
\
"but got odd number of vars"
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
# handle vars in each op, each time handle a param and a grad
param_name
=
op_role_var
[
i
]
param
=
block
.
var
(
param_name
)
grad_name
=
op_role_var
[
i
+
1
]
...
...
@@ -229,6 +237,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
if
param
.
is_distributed
:
continue
if
".cast_fp16@GRAD"
in
grad_name
:
# when amp=True get the fp16 param
param_name
=
param_name
+
".cast_fp16"
if
not
block
.
has_var
(
param_name
):
raise
ValueError
(
"op cast name error {}"
.
format
(
...
...
@@ -236,154 +245,102 @@ class RawProgramOptimizer(MetaOptimizerBase):
else
:
param
=
block
.
var
(
param_name
)
if
len
(
allreduce_output_vars
)
==
0
:
allreduce_output_vars
.
append
([
grad
])
allreduce_input_vars
.
append
([
param
])
if
self
.
fuse_grad_size_in_num
==
1
:
record_idx
.
append
([
idx
,
idx
])
continue
record_idx
.
append
([
-
2
,
idx
])
elif
len
(
allreduce_output_vars
[
-
1
])
==
self
.
fuse_grad_size_in_num
:
if
len
(
allreduce_output_vars
)
==
0
or
\
len
(
allreduce_output_vars
[
-
1
])
==
\
self
.
fuse_grad_size_in_num
:
# start of the fusion or last group meets the config size
allreduce_output_vars
.
append
([
grad
])
allreduce_input_vars
.
append
([
param
])
if
self
.
fuse_grad_size_in_num
==
1
:
record_idx
.
append
([
idx
,
idx
])
continue
if
idx
!=
block_ops
-
1
:
record_idx
.
append
([
-
2
,
idx
])
# add the start and end idx to the record idx
record_idx
.
append
([
idx
,
idx
])
else
:
# Current group's size is below the config size
# append grad and param to the last group (current group)
# update the start idx to current op's idx
# Since we travers the ops reversely, the idx is descending
# we update the first entry of each entry for record_idx
allreduce_output_vars
[
-
1
].
append
(
grad
)
allreduce_input_vars
[
-
1
].
append
(
param
)
record_idx
[
-
1
][
0
]
=
idx
if
record_idx
[
-
1
][
0
]
==
-
2
:
record_idx
[
-
1
][
0
]
=
record_idx
[
-
1
][
1
]
assert
len
(
allreduce_output_vars
)
==
len
(
record_idx
),
"It has different lens between the allreduce_output_vars and record_idx."
if
not
allreduce_output_vars
or
not
allreduce_input_vars
:
# nothing needs to be allreduced
return
self
.
vars
=
collections
.
OrderedDict
()
index
,
offset_pos
,
pos
,
offset
=
0
,
0
,
0
,
0
index
,
pos
,
offset
=
0
,
0
,
0
start
,
end
=
record_idx
[
index
]
men_list
=
[
end
,
start
]
# Here we need to explain the flag. When integrating OP, we will encounter different groups of the same Op.
# Because we insert coalesce tensor in reverse ops,
# we need to use flag to record whether the current OP has been inserted into coalesce tensor。
# For example:
# [(3, 2), (2, 2), (1, 0)], (3, 2), (2, 2) using same op, but in different groups.
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
ops
):
if
idx
==
start
:
pos
=
0
flag
=
True
if
end
==
men_list
[
-
1
]
else
False
offset
=
offset_pos
if
flag
else
0
done_output_vars
,
done_input_vars
=
self
.
_split_fuction
(
allreduce_output_vars
[
index
],
allreduce_input_vars
[
index
])
allreduce_output_vars
[
index
],
# grad
allreduce_input_vars
[
index
]
# param
)
for
id_
,
done_output_var
in
enumerate
(
done_output_vars
):
if
flag
:
tmp_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'FusedOutput_{}_{}'
.
format
(
start
,
id_
+
offset
)),
dtype
=
done_output_var
[
0
].
dtype
,
persistable
=
False
,
stop_gradient
=
True
)
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
+
offset
)]
=
tmp_var
tmp_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'FusedOutput_{}'
.
format
(
done_output_var
[
0
].
name
)),
dtype
=
done_output_var
[
0
].
dtype
,
persistable
=
False
,
stop_gradient
=
True
)
self
.
vars
[
'FusedOutput_{}'
.
format
(
done_output_var
[
0
]
.
name
)]
=
tmp_var
block
.
_insert_op
(
idx
+
id_
+
offset
,
type
=
"coalesce_tensor"
,
inputs
=
{
"Input"
:
done_input_vars
[
id_
]},
outputs
=
{
"Output"
:
done_output_var
,
"FusedOutput"
:
tmp_var
},
attrs
=
{
"copy_data"
:
False
,
"use_align"
:
True
,
"dtype"
:
done_output_var
[
0
].
dtype
})
pos
+=
1
else
:
tmp_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'FusedOutput_{}_{}'
.
format
(
start
,
id_
)),
dtype
=
done_output_var
[
0
].
dtype
,
persistable
=
False
,
stop_gradient
=
True
)
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
)]
=
tmp_var
block
.
_insert_op
(
idx
+
id_
,
type
=
"coalesce_tensor"
,
inputs
=
{
"Input"
:
done_input_vars
[
id_
]},
outputs
=
{
"Output"
:
done_output_var
,
"FusedOutput"
:
tmp_var
},
attrs
=
{
"copy_data"
:
False
,
"use_align"
:
True
,
"dtype"
:
done_output_var
[
0
].
dtype
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
pos
+=
1
block
.
_insert_op
(
idx
+
id_
,
type
=
"coalesce_tensor"
,
inputs
=
{
"Input"
:
done_input_vars
[
id_
]},
outputs
=
{
"Output"
:
done_output_var
,
"FusedOutput"
:
tmp_var
},
attrs
=
{
"copy_data"
:
False
,
"use_align"
:
True
,
"dtype"
:
done_output_var
[
0
].
dtype
})
pos
+=
1
offset_pos
=
pos
# TODO(Liu yuang): ADD CUDA and NPU's EVENT and c_allreduce_sum.
for
id_
in
range
(
len
(
done_output_vars
)):
if
flag
:
block
.
_insert_op
(
end
+
id_
+
pos
+
1
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
+
offset
)]
},
outputs
=
{
'Out'
:
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
+
offset
)]
},
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
True
if
self
.
_calc_stream
else
False
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
else
:
x
=
self
.
vars
[
'FusedOutput_{}'
.
format
(
done_output_vars
[
id_
][
0
].
name
)]
out
=
x
# NOTE: there still some optimize space if use EVENT instead of sync
if
not
self
.
calc_comm_same_stream
:
# need sync if the calc and comm stream are not the same
block
.
_insert_op
(
end
+
id_
+
pos
+
1
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
)]
},
outputs
=
{
'Out'
:
self
.
vars
[
'FusedOutput_{}_{}'
.
format
(
start
,
id_
)]
},
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
True
if
self
.
_calc_stream
else
False
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Backward
})
block
.
_insert_op
(
end
+
id_
+
pos
+
1
if
self
.
calc_comm_same_stream
else
end
+
id_
+
pos
+
2
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
self
.
calc_comm_same_stream
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
index
+=
1
men_list
.
append
(
end
)
men_list
.
append
(
start
)
if
len
(
record_idx
)
==
index
:
start
=
end
=
-
1
continue
break
start
,
end
=
record_idx
[
index
]
if
not
self
.
_calc_stream
:
if
not
self
.
calc_comm_same_stream
:
# need sync if the calc and comm stream are not the same
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
is_optimizer_op
(
op
):
block
.
_insert_op
(
...
...
@@ -397,34 +354,50 @@ class RawProgramOptimizer(MetaOptimizerBase):
})
break
# Integrate grads of the same type to form a combination. If skip_comb is selected, will return grads of the same group.
# Integrate grads of the same type to form a combination.
# If combination is selected, will return grads of the same type in a groups.
# For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
def
_split_fuction
(
self
,
allreduce_output_vars
,
allreduce_input_vars
,
skip_comb
=
True
):
combination
=
True
):
input_vars
,
final_input_vars
,
output_vars
,
final_output_vars
=
[],
[],
[],
[]
if
len
(
allreduce_output_vars
)
-
1
==
0
:
if
len
(
allreduce_output_vars
)
==
1
:
# only have one var to handle
final_output_vars
.
append
(
allreduce_output_vars
)
final_input_vars
.
append
(
allreduce_input_vars
)
return
final_output_vars
,
final_input_vars
for
idx
in
range
(
len
(
allreduce_input_vars
)
-
1
):
# the last var needs to be handled differently
if
allreduce_input_vars
[
idx
].
dtype
==
allreduce_input_vars
[
idx
+
1
].
dtype
:
# if current var and next var are in same type
# append current var to input_vars
input_vars
.
append
(
allreduce_input_vars
[
idx
])
if
idx
==
len
(
allreduce_input_vars
)
-
2
:
# if current var is the second last var
# append the last var to input_vars
# and update the final_input_vars
input_vars
.
append
(
allreduce_input_vars
[
idx
+
1
])
final_input_vars
.
append
(
input_vars
)
else
:
# the current var and next var are in different types
# append current var to input_vars
# update the final_input_vars
# reset input_vars to receive a new type
input_vars
.
append
(
allreduce_input_vars
[
idx
])
final_input_vars
.
append
(
input_vars
)
input_vars
=
[]
if
idx
==
len
(
allreduce_input_vars
)
-
2
:
# if current var is the second last var
# append the last var to a reset input_vars since they are in different types
# and update the final_input_vars
input_vars
.
append
(
allreduce_input_vars
[
idx
+
1
])
final_input_vars
.
append
(
input_vars
)
for
idx
in
range
(
len
(
allreduce_output_vars
)
-
1
):
# the procedure for the output vars is the same with that for the input vars
if
allreduce_output_vars
[
idx
].
dtype
==
allreduce_output_vars
[
idx
+
1
].
dtype
:
output_vars
.
append
(
allreduce_output_vars
[
idx
])
...
...
@@ -438,10 +411,14 @@ class RawProgramOptimizer(MetaOptimizerBase):
if
idx
==
len
(
allreduce_output_vars
)
-
2
:
output_vars
.
append
(
allreduce_output_vars
[
idx
+
1
])
final_output_vars
.
append
(
output_vars
)
if
skip_comb
:
# at this time, all vars in each group in final_input_vars and final_output_vars are in the same type
if
combination
:
input_fp16_vars
,
input_fp32_vars
,
output_fp16_vars
,
output_fp32_vars
=
[],
[],
[],
[]
for
final_input_var
in
final_input_vars
:
if
final_input_var
[
0
].
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
# extend the group
input_fp16_vars
.
extend
(
final_input_var
)
else
:
input_fp32_vars
.
extend
(
final_input_var
)
...
...
@@ -451,6 +428,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
output_fp16_vars
.
extend
(
final_output_var
)
else
:
output_fp32_vars
.
extend
(
final_output_var
)
final_output_vars
,
final_input_vars
=
[],
[]
if
output_fp16_vars
:
final_output_vars
.
append
(
output_fp16_vars
)
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
e0e0c0fa
...
...
@@ -718,6 +718,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties
(
test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_dist_fleet_infer PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60
)
endif
()
if
(
WITH_DISTRIBUTE AND NOT APPLE
)
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
0 → 100644
浏览文件 @
e0e0c0fa
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
import
unittest
import
paddle
import
os
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
numpy
as
np
from
functools
import
reduce
import
paddle.fluid
as
fluid
paddle
.
enable_static
()
DTYPE
=
"float32"
paddle
.
dataset
.
mnist
.
fetch
()
# Fix seed for test
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
def
cnn_model
(
data
):
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
data
,
filter_size
=
5
,
num_filters
=
20
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)))
conv_pool_2
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
conv_pool_1
,
filter_size
=
5
,
num_filters
=
50
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)))
SIZE
=
10
input_shape
=
conv_pool_2
.
shape
param_shape
=
[
reduce
(
lambda
a
,
b
:
a
*
b
,
input_shape
[
1
:],
1
)]
+
[
SIZE
]
scale
=
(
2.0
/
(
param_shape
[
0
]
**
2
*
SIZE
))
**
0.5
predict
=
fluid
.
layers
.
fc
(
input
=
conv_pool_2
,
size
=
SIZE
,
act
=
"softmax"
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)))
return
predict
class
TestFleetMetaOptimizerFuseAllReducePrecision
(
TestDistRunnerBase
):
def
get_model
(
self
,
batch_size
=
2
,
single_device
=
False
):
# Input data
images
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
DTYPE
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
# Train program
predict
=
cnn_model
(
images
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
# Evaluator
batch_size_tensor
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
batch_acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
,
total
=
batch_size_tensor
)
test_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
# Reader
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
)
optimizer
=
paddle
.
fluid
.
optimizer
.
Adam
(
0.01
)
if
single_device
:
optimizer
.
minimize
(
avg_cost
)
else
:
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
strategy
.
without_graph_optimization
=
True
strategy
.
fuse_all_reduce_ops
=
True
strategy
.
_calc_comm_same_stream
=
False
strategy
.
fuse_grad_size_in_num
=
8
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
return
test_program
,
avg_cost
,
train_reader
,
test_reader
,
batch_acc
,
predict
if
__name__
==
"__main__"
:
runtime_main
(
TestFleetMetaOptimizerFuseAllReducePrecision
)
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py
0 → 100644
浏览文件 @
e0e0c0fa
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
test_dist_base
import
TestDistBase
import
paddle
import
os
paddle
.
enable_static
()
flag_name
=
os
.
path
.
splitext
(
__file__
)[
0
]
class
TestFleetMetaOptimizerAllReduceFusePrecision
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_use_reduce
=
False
self
.
_use_reader_alloc
=
False
self
.
_nccl2_mode
=
True
self
.
_nccl2_reduce_layer
=
True
self
.
_use_fleet_api
=
True
self
.
_use_fleet_api_20
=
True
def
test_dist_train
(
self
):
import
paddle.fluid
as
fluid
if
fluid
.
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
"dist_fleet_raw_program_optimizer_fuse_allreduce.py"
,
delta
=
1e-5
,
check_error_log
=
True
,
log_name
=
flag_name
)
if
__name__
==
'__main__'
:
unittest
.
main
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部