PaddlePaddle / Paddle
Commit 7aeec4ed (unverified)
Authored Aug 12, 2022 by JZ-LIANG; committed via GitHub on Aug 12, 2022.
[Auto Parallel] Data Parallel Optimization Pass 1 (#44882)

* bugfix
* remove scaling
* support rescale_grad opt
Parent: cf17ae8a

Showing 10 changed files with 382 additions and 20 deletions (+382 -20):
python/paddle/distributed/auto_parallel/operators/common.py  +10 -0
python/paddle/distributed/auto_parallel/parallelizer_v2.py  +8 -0
python/paddle/distributed/auto_parallel/process_group.py  +11 -8
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py  +6 -6
python/paddle/distributed/auto_parallel/utils.py  +15 -0
python/paddle/distributed/passes/__init__.py  +1 -0
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py  +207 -0
python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt  +2 -0
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py  +12 -6
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py  +110 -0
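The "remove scaling" / "rescale_grad" items in the commit message refer to pruning the explicit 1/dp_degree scale ops that data parallelism inserts after gradient allreduce; the new pass folds that factor into the initial loss gradient or into the optimizer's rescale_grad attribute instead. The parallelizer_v2.py hunk below shows how the pass is wired into the auto-parallel pipeline, and the new unit test applies it directly. A minimal sketch of that direct usage (dist_main_prog / dist_startup_prog are placeholder programs assumed to come from the auto-parallel planner):

import paddle
from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

# dist_main_prog / dist_startup_prog are placeholders in this sketch; in the
# real flow they are the partitioned programs produced by auto parallel.
config = {
    "dist_context": get_default_distributed_context(),
    "global_rank": paddle.distributed.get_rank(),
}
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())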
python/paddle/distributed/auto_parallel/operators/common.py
@@ -435,3 +435,13 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,

        return

    sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)


def is_data_parallel_scale_op(op):
    return op.type == "scale" and op.desc.has_attr("op_namescope") \
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")


def is_data_parallel_reduce_op(op):
    return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
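These two predicates identify the communication and scaling ops that data parallelism injected: both check the op type plus an op_namescope attribute containing ParallelMode.DataParallel. The new pass's _analyze_program uses them in a single scan over the main block; a rough sketch of such a scan (block stands for the global block of an already-partitioned main program and is an assumption of this sketch):

from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op, is_data_parallel_scale_op)

dp_allreduced_grads = []   # grads synchronized across a data-parallel group
dp_scaled_grads = []       # grads followed by a 1/dp_degree scale op

for op in block.ops:       # block: assumed global block of the dist main program
    if is_data_parallel_reduce_op(op):
        dp_allreduced_grads.append(op.output_arg_names[0])
    elif is_data_parallel_scale_op(op):
        dp_scaled_grads.append(op.output_arg_names[0])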
python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -195,6 +195,14 @@ class Parallelizer:

                                 params_grads):
        if self._strategy is None:
            return

        # data parallel optimization
        config = {}
        config["dist_context"] = self._dist_context
        config["global_rank"] = rank
        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
        dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding:
            config = copy.deepcopy(self._strategy.sharding_configs)
            config["dist_context"] = self._dist_context
python/paddle/distributed/auto_parallel/process_group.py
@@ -160,21 +160,24 @@ class ProcessGroup:

(The previously commented-out __eq__/__ne__ stubs are deleted; active __eq__, __ne__, and __hash__ implementations take their place.)

    def is_member(self):
        return True

    def __eq__(self, other):
        if not isinstance(other, ProcessGroup):
            return False
        if self.id != other.id:
            return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        string = "id: {}, nranks: {}, ranks: {}.".format(
            self.id, self.nranks, ", ".join(map(str, self.ranks)))
        return string

    def __hash__(self):
        return hash(self.__str__())


# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
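A side note on why __hash__ is added together with __eq__: in Python 3, a class that defines __eq__ without __hash__ becomes unhashable, and the new data-parallel pass uses ProcessGroup objects as OrderedDict keys (see _group_to_grad_name_map in the new pass file below). A minimal, self-contained illustration with a stand-in class (Demo is hypothetical, not the real ProcessGroup):

class Demo:
    """Stand-in for a class that, like ProcessGroup, defines __eq__."""

    def __init__(self, id):
        self.id = id

    def __eq__(self, other):
        return isinstance(other, Demo) and self.id == other.id
    # No __hash__ defined: Python 3 sets __hash__ to None for this class.


try:
    {Demo(0): ["w@GRAD"]}            # raises TypeError: unhashable type
except TypeError as e:
    print(e)


class HashableDemo(Demo):

    def __hash__(self):              # mirrors the added ProcessGroup.__hash__
        return hash(self.id)


groups = {HashableDemo(0): ["w@GRAD"]}   # usable as a dict key again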
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
@@ -266,7 +266,7 @@ class OptimizationTuner:

        config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \
            + self._baseline_dist_context.serial_feed_vars["labels"]
        if config["use_pure_fp16"]:
            config["base_opt"] = dist_context.serial_optimizer  # was: dist_context.optimizer
            auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
            auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                          pass_context)

@@ -363,11 +363,11 @@ class OptimizationTuner:

        profile_args = " ".join([
            "--rank",
            str(self.rank),
            "--device_id",
            str(self.device_id),
            "--ctx_filename",
            ctx_path,
            "--profile_start_step",
            str(self._config.profile_start_step),
            "--profile_end_step",
            str(self._config.profile_end_step)
        ])
        cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
        cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
python/paddle/distributed/auto_parallel/utils.py
@@ -23,6 +23,7 @@ from functools import reduce

import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute

@@ -1123,6 +1124,13 @@ def is_loss_op(op):

        int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))


def is_loss_grad_op(op):
    if OP_ROLE_KEY not in op.attr_names:
        return False
    op_role = int(op.all_attrs()[OP_ROLE_KEY])
    return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)


def is_prim_op(op):
    return op.type.endswith("_p")

@@ -1481,3 +1489,10 @@ def debug_program(program, path, name):

        path, name + '_program' + ".%d" % (paddle.distributed.get_rank()))
    with open(filename, 'w') as f:
        f.write(str(program))


def ring_id_to_process_group(ring_id):
    for g in get_all_process_groups():
        if g.id == ring_id:
            return g
    return None
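The added is_loss_grad_op reads the op-role attribute as a bit field: a loss-gradient op must carry both the Backward and the Loss bits, while is_loss_op above requires exactly Forward | Loss. A small sanity sketch of that bit logic (it only assumes the OpRole members behave as disjoint bit flags, which the `|`/`&` usage in these helpers implies; concrete integer values are not assumed):

from paddle.distributed.fleet.meta_optimizers.common import OpRole

# An op marked as the loss gradient carries both flags:
role = int(OpRole.Backward) | int(OpRole.Loss)
assert role & int(OpRole.Backward) and role & int(OpRole.Loss)

# A plain backward op carries Backward but not Loss:
plain_backward = int(OpRole.Backward)
assert plain_backward & int(OpRole.Backward)
assert not (plain_backward & int(OpRole.Loss))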
python/paddle/distributed/passes/__init__.py
@@ -19,6 +19,7 @@ from .auto_parallel_sharding import *

from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .auto_parallel_data_parallel_optimization import *
from .cpp_pass import *
import os
from .ps_trainer_pass import *
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
new file (mode 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
from paddle.fluid.framework import default_main_program
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
    'merge_momentum'
]


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply optimizations specialized for data parallelism in Auto Parallel:
    1. prune grad scaling
    2. overlap comm and calc
    3. fuse allreduce
    """

    def __init__(self):
        super(DataParallelOptimizationPass, self).__init__()
        # NOTE: do not depend on loss and param_grads
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        # {grad1: group1, grad2: group1, grad3: group2}
        # record the order for fuse grad data memory
        self._grad_name_to_group_map = OrderedDict()
        # {group1: [grad1, grad2], group2: [grad3]}
        self._group_to_grad_name_map = OrderedDict()
        self._support_rescale_grad = False

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if (not isinstance(self.get_attr("global_rank"),
                           int)) or self.get_attr("global_rank") < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        return PassType.COMM_OPT

    def _apply_single_impl(self, main_program, startup_program, context):

        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()
            self._prune_grad_scaling()
            self._overlap_comm()
            self._fuse_allreduce()

    def _prune_grad_scaling(self):

        if not self._could_be_prune():
            return

        if self._all_dp_groups_same_degree():
            self._scale_backward_initial_grad()
        else:
            self._update_opt_rescale_grad()

        self._remove_grad_scaling()

    def _overlap_comm(self):
        pass

    def _fuse_allreduce(self):
        pass

    def _analyze_program(self):
        """
        Build {param_grad_name: data_parallel_group} and
        {data_parallel_group: [param_grad_names]} maps.
        """

        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []

        for op in ops:

            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), "Unexpected: comm op [{}] has NO ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op))

                self._grad_name_to_group_map[grad_name] = group

                if group not in self._group_to_grad_name_map:
                    self._group_to_grad_name_map[group] = [grad_name]
                else:
                    self._group_to_grad_name_map[group].append(grad_name)

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

            # TODO support multiple optimizers in one network in future.
            # here we assume that the optimizer is unique in network.
            elif is_optimize_op(op) and op.type in __rescale_grad_supported_opts__:
                self._support_rescale_grad = True

        not_synchronized_grads = []
        for grad_name in scaled_grads:
            if grad_name not in self._grad_name_to_group_map:
                not_synchronized_grads.append(grad_name)
        assert len(
            not_synchronized_grads
        ) == 0, "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
            not_synchronized_grads)

    def _could_be_prune(self):

        return self._support_rescale_grad or self._all_dp_groups_same_degree()

    def _all_dp_groups_same_degree(self):
        return len(
            set([
                len(group.ranks)
                for group in self._group_to_grad_name_map.keys()
            ])) == 1

    def _scale_backward_initial_grad(self):

        block = default_main_program().global_block()
        dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', \
                    "loss_grad_op must be fill_constant op, " \
                    "but this op is {}".format(op.type)
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree
                op._set_attr('value', loss_scale)
                break

    def _remove_grad_scaling(self):
        block = default_main_program().global_block()

        for op_idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_scale_op(op):
                block._remove_op(op_idx, False)

        block._sync_with_cpp()

    def _update_opt_rescale_grad(self):

        block = default_main_program().global_block()
        scaled_grads = set()

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_optimize_op(op) and op.type in __rescale_grad_supported_opts__:
                assert op.has_attr(
                    'rescale_grad'
                ), "Unexpected: op [{}] is expected to have [rescale_grad] attribute.".format(
                    str(op))
                assert len(
                    op.input("Grad")
                ) == 1, "Unexpected: op [{}] is expected to have only one input grad var.".format(
                    str(op))

                grad_name = op.input("Grad")[0]
                dp_degree = len(
                    list(self._grad_name_to_group_map[grad_name].ranks))
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(self._grad_name_to_group_map.keys()), \
            "Unexpected: gradients [{}] are unscaled.".format(
                set(self._grad_name_to_group_map.keys()) - scaled_grads)
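A worked-arithmetic view of the grad-scaling pruning implemented above (plain Python with made-up numbers, not code from this commit): with dp_degree ranks, plain data parallelism allreduce-sums each gradient and then multiplies it by 1/dp_degree with a separate scale op. Because gradients are linear in the initial loss gradient, the same mean is obtained by dividing the fill_constant loss-grad value by dp_degree (when all groups share one degree), or by dividing the optimizer's rescale_grad by dp_degree otherwise, after which the standalone scale op can be removed:

dp_degree = 4
local_grads = [1.0, 2.0, 3.0, 4.0]        # per-rank gradient of one parameter

# baseline: allreduce-sum followed by an explicit scale op
baseline = sum(local_grads) * (1.0 / dp_degree)

# variant 1: pre-scale the initial loss grad (fill_constant value); every local
# gradient is scaled before allreduce, so the scale op afterwards is unneeded
prescaled = sum(g * (1.0 / dp_degree) for g in local_grads)

# variant 2: keep the raw sum and let the optimizer divide via rescale_grad
rescaled_by_opt = sum(local_grads) * (1.0 / dp_degree)   # rescale_grad /= dp_degree

assert abs(baseline - prescaled) < 1e-12
assert abs(baseline - rescaled_by_opt) < 1e-12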
python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt
@@ -20,6 +20,8 @@ if((NOT WITH_GPU)

  list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
  list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
  list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass")
  list(REMOVE_ITEM TEST_OPS "test_auto_parallel_data_parallel_optimization_pass")
endif()

foreach(TEST_OP ${TEST_OPS})
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
@@ -108,7 +108,7 @@ class AutoPallelPassTestBase(DistPassTestBase):

            pickle.dump(all_fetch_values, f)

    def get_gpt_model(self, strategy, place, batch_size, sequence_len,
                      vocab_size, **kwargs):  # was: vocab_size):
        modeling.init_global()
        if strategy == "dp":
            modeling._global_parallel_strategy = "dp"

@@ -179,11 +179,17 @@ class AutoPallelPassTestBase(DistPassTestBase):

        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)
        clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)

        if kwargs.get('optimizer', None) == "LarsMomentum":
            optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
                learning_rate=0.001, momentum=0.9)
        else:
            optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                             beta1=0.9,
                                                             beta2=0.999,
                                                             epsilon=1e-08,
                                                             grad_clip=clip)
        optimizer = fleet.distributed_optimizer(optimizer)
        startup_program = paddle.static.default_startup_program()
        _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py
new file (mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import random
import numpy as np
import unittest

import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.passes import new_pass, PassManager, PassContext

from auto_parallel_pass_test_base import AutoPallelPassTestBase

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion


class TestDataParallelPassWithScale1(AutoPallelPassTestBase):

    def init(self):
        if paddle.is_compiled_with_cuda():
            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        self.rtol = 1e-5
        self.atol = 1e-8
        # NOTE a hack to compare pass apply or not, since there is no
        # setting of this pass in dist_strategy
        self._apply_pass = False

        rank = paddle.distributed.get_rank()
        paddle.seed(rank + 2021)
        random.seed(rank + 2021)
        np.random.seed(rank + 2021)

    def apply_passes(self):
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        self._apply_pass = True

    def apply_no_passes(self):
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        self._apply_pass = False

    def test_bs_8(self):
        self.check_main(gpus=[0, 1],
                        batch_size=8,
                        sequence_len=512,
                        vocab_size=1000)

    # test scaling with fillconstant
    def get_model(self, place, batch_size, sequence_len, vocab_size):

        dist_main_prog, dist_startup_prog, data_holder, [
            loss
        ], gen_data = self.get_gpt_model('dp', place, batch_size, sequence_len,
                                         vocab_size)
        if self._apply_pass:
            config = {}
            config["dist_context"] = get_default_distributed_context()
            config["global_rank"] = paddle.distributed.get_rank()
            dp_pass = new_pass("auto_parallel_data_parallel_optimization",
                               config)
            dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())

        return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data


class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1):

    # test scaling with optimizer rescale_grad
    def get_model(self, place, batch_size, sequence_len, vocab_size):

        dist_main_prog, dist_startup_prog, data_holder, [
            loss
        ], gen_data = self.get_gpt_model('dp',
                                         place,
                                         batch_size,
                                         sequence_len,
                                         vocab_size,
                                         optimizer='LarsMomentum')
        if self._apply_pass:
            config = {}
            config["dist_context"] = get_default_distributed_context()
            config["global_rank"] = paddle.distributed.get_rank()
            dp_pass = new_pass("auto_parallel_data_parallel_optimization",
                               config)
            dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())

        return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data


if __name__ == "__main__":
    unittest.main()