机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 7aeec4ed (unverified)
Authored on Aug 12, 2022 by JZ-LIANG; committed by GitHub on Aug 12, 2022

[Auto Parallel] Data Parallel Optimization Pass 1 (#44882)

* bugfix
* remove scaling
* support rescale_grad opt

Parent commit: cf17ae8a
Showing 10 changed files with 382 additions and 20 deletions (+382, -20)
Changed files:

python/paddle/distributed/auto_parallel/operators/common.py (+10, -0)
python/paddle/distributed/auto_parallel/parallelizer_v2.py (+8, -0)
python/paddle/distributed/auto_parallel/process_group.py (+11, -8)
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py (+6, -6)
python/paddle/distributed/auto_parallel/utils.py (+15, -0)
python/paddle/distributed/passes/__init__.py (+1, -0)
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py (+207, -0)
python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt (+2, -0)
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py (+12, -6)
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py (+110, -0)
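The "remove scaling" item works because, under pure data parallelism, summing gradients with allreduce_sum and then scaling by 1/dp_degree is numerically the same as seeding the backward pass with an already-scaled initial loss gradient (or folding the factor into the optimizer, as the new pass file below does). A minimal numeric sketch of that equivalence, in plain Python with made-up values:

# Minimal sketch: explicit scale op after allreduce vs. pre-scaling the initial loss grad.
# dp_degree and the per-rank "gradients" below are illustrative values, not from the patch.
dp_degree = 4
per_rank_grads = [1.0, 2.0, 3.0, 4.0]        # the same gradient element as seen on 4 ranks

# Path 1: allreduce_sum followed by an explicit scale op (what the pass removes).
allreduced = sum(per_rank_grads)
path1 = allreduced * (1.0 / dp_degree)

# Path 2: scale the initial loss gradient (the fill_constant value) by 1/dp_degree,
# so every local gradient already carries the factor before allreduce_sum.
pre_scaled = [g * (1.0 / dp_degree) for g in per_rank_grads]
path2 = sum(pre_scaled)

assert abs(path1 - path2) < 1e-12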
python/paddle/distributed/auto_parallel/operators/common.py

@@ -435,3 +435,13 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
         return
 
     sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)
+
+
+def is_data_parallel_scale_op(op):
+    return op.type == "scale" and op.desc.has_attr("op_namescope") \
+        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
+
+
+def is_data_parallel_reduce_op(op):
+    return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \
+        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
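These two predicates are what the new pass uses to spot the data-parallel communication and scaling ops it later prunes or rewrites. A minimal sketch of how they could be used to scan an already-parallelized main program; the function name and the dist_main_prog argument are illustrative, not part of the patch:

from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op, is_data_parallel_scale_op)


def collect_dp_comm_ops(dist_main_prog):
    """Sketch: gather the DataParallel allreduce/reduce and scale ops from a
    parallelized static Program (argument and name are hypothetical)."""
    dp_reduces, dp_scales = [], []
    for op in dist_main_prog.global_block().ops:
        if is_data_parallel_reduce_op(op):
            # c_allreduce_sum / c_reduce_sum tagged with ParallelMode.DataParallel
            dp_reduces.append(op)
        elif is_data_parallel_scale_op(op):
            # scale op tagged with ParallelMode.DataParallel
            dp_scales.append(op)
    return dp_reduces, dp_scales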
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -195,6 +195,14 @@ class Parallelizer:
                                  params_grads):
         if self._strategy is None:
             return
 
+        # data parallel optimization
+        config = {}
+        config["dist_context"] = self._dist_context
+        config["global_rank"] = rank
+        dp_pass = new_pass("auto_parallel_data_parallel_optimization",
+                           config)
+        dp_pass.apply([main_program], [startup_program], self._pass_context)
+
         if self._strategy.sharding:
             config = copy.deepcopy(self._strategy.sharding_configs)
             config["dist_context"] = self._dist_context
python/paddle/distributed/auto_parallel/process_group.py

@@ -160,21 +160,24 @@ class ProcessGroup:
     def is_member(self):
         return True
 
-    # def __eq__(self, other):
-    #     if not isinstance(other, ProcessGroup):
-    #         return False
-    #     if self.id != other.id:
-    #         return False
-    #     return True
+    def __eq__(self, other):
+        if not isinstance(other, ProcessGroup):
+            return False
+        if self.id != other.id:
+            return False
+        return True
 
-    # def __ne__(self, other):
-    #     return not self.__eq__(other)
+    def __ne__(self, other):
+        return not self.__eq__(other)
 
     def __str__(self):
         string = "id: {}, nranks: {}, ranks: {}.".format(
             self.id, self.nranks, ", ".join(map(str, self.ranks)))
         return string
 
+    def __hash__(self):
+        return hash(self.__str__())
+
 
 # Note that Process group 0 is reserved for representing all ranks.
 # At the beginning, group 0 is empty and new ranks will be added automatically.
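Defining __eq__ and __ne__ together with __hash__ matters here because the new pass keys an OrderedDict by ProcessGroup objects (its _group_to_grad_name_map). A toy stand-in class, not the real ProcessGroup, showing the behaviour these methods guarantee:

# Toy stand-in: why objects used as dict keys need consistent __eq__/__hash__.
class Group:
    def __init__(self, id, ranks):
        self.id, self.ranks = id, ranks

    def __eq__(self, other):
        return isinstance(other, Group) and self.id == other.id

    def __hash__(self):
        # Defining __eq__ without __hash__ would make the class unhashable in
        # Python 3, so instances could no longer serve as OrderedDict keys.
        return hash(self.id)


group_to_grads = {}
g = Group(1, [0, 1])
group_to_grads[g] = ["w@GRAD"]
# A second object describing the same group still resolves to the same entry.
assert Group(1, [0, 1]) in group_to_grads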
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py

@@ -266,7 +266,7 @@ class OptimizationTuner:
         config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \
             + self._baseline_dist_context.serial_feed_vars["labels"]
         if config["use_pure_fp16"]:
-            config["base_opt"] = dist_context.optimizer
+            config["base_opt"] = dist_context.serial_optimizer
             auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
             auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                           pass_context)

@@ -363,11 +363,11 @@ class OptimizationTuner:
        profile_args = " ".join([
            "--rank",
            str(self.rank),
            "--device_id",
            str(self.device_id),
            "--ctx_filename",
            ctx_path,
            "--profile_start_step",
            str(self._config.profile_start_step),
            "--profile_end_step",
            str(self._config.profile_end_step)
        ])
        cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args

        cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
python/paddle/distributed/auto_parallel/utils.py

@@ -23,6 +23,7 @@ from functools import reduce
 import paddle.fluid.core as core
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
+from paddle.distributed.auto_parallel.process_group import get_all_process_groups
 from paddle.fluid.io import is_parameter, is_belong_to_optimizer
 from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute

@@ -1123,6 +1124,13 @@ def is_loss_op(op):
         int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))
 
 
+def is_loss_grad_op(op):
+    if OP_ROLE_KEY not in op.attr_names:
+        return False
+    op_role = int(op.all_attrs()[OP_ROLE_KEY])
+    return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)
+
+
 def is_prim_op(op):
     return op.type.endswith("_p")

@@ -1481,3 +1489,10 @@ def debug_program(program, path, name):
                             path, name + '_program' + ".%d" % (paddle.distributed.get_rank()))
     with open(filename, 'w') as f:
         f.write(str(program))
+
+
+def ring_id_to_process_group(ring_id):
+    for g in get_all_process_groups():
+        if g.id == ring_id:
+            return g
+    return None
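The new is_loss_grad_op helper relies on OpRole being a bit mask, so the loss-gradient op carries both the Backward bit and the Loss bit. A small illustration of that bitwise test with stand-in flag values (illustrative only, not read from the patch):

# Illustrative role flags; Paddle's OpRole is a bit mask so roles can be combined.
BACKWARD = 0x0001
LOSS = 0x0100   # made-up values, combined the same way the real enum is combined

loss_grad_role = BACKWARD | LOSS      # role carried by the loss-gradient fill_constant op
plain_backward_role = BACKWARD


def looks_like_loss_grad(op_role):
    # mirrors the bitwise test in is_loss_grad_op
    return bool(op_role & BACKWARD) and bool(op_role & LOSS)


assert looks_like_loss_grad(loss_grad_role)
assert not looks_like_loss_grad(plain_backward_role)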
python/paddle/distributed/passes/__init__.py

@@ -19,6 +19,7 @@ from .auto_parallel_sharding import *
 from .auto_parallel_amp import *
 from .auto_parallel_fp16 import *
 from .auto_parallel_recompute import *
+from .auto_parallel_data_parallel_optimization import *
 from .cpp_pass import *
 import os
 from .ps_trainer_pass import *
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
new file (0 → 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
from paddle.fluid.framework import default_main_program
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
    'merge_momentum'
]


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply Optimizations that specialized for data parallelism in Auto Parallel.
    1. prune grad scaling
    2. overlap comm and calc
    3. fuse allreduce
    """

    def __init__(self):
        super(DataParallelOptimizationPass, self).__init__()
        # NOTE not use depence on loss and param_grads
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        # {grad1: group1, grad2: group1, grad3: group2}
        # record the order for fuse grad data memory
        self._grad_name_to_group_map = OrderedDict()
        # {group1:[grad1, grad2] , group2:[grad3]}
        self._group_to_grad_name_map = OrderedDict()
        self._support_rescale_grad = False
    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False

        if (not isinstance(self.get_attr("global_rank"),
                           int)) or self.get_attr("global_rank") < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        return PassType.COMM_OPT

    def _apply_single_impl(self, main_program, startup_program, context):

        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()
            self._prune_grad_scaling()
            self._overlap_comm()
            self._fuse_allreduce()

    def _prune_grad_scaling(self):

        if not self._could_be_prune():
            return

        if self._all_dp_groups_same_degree():
            self._scale_backward_initial_grad()
        else:
            self._update_opt_rescale_grad()

        self._remove_grad_scaling()

    def _overlap_comm(self):
        pass

    def _fuse_allreduce(self):
        pass
    def _analyze_program(self):
        """
        {param_grad_name: data_parallel_group}
        {pdata_parallel_group: aram_grad_name}
        """

        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []

        for op in ops:

            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert group is not None, "Unexception: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op))

                self._grad_name_to_group_map[grad_name] = group

                if group not in self._group_to_grad_name_map:
                    self._group_to_grad_name_map[group] = [grad_name]
                else:
                    self._group_to_grad_name_map[group].append(grad_name)

            elif is_data_parallel_scale_op(op):
                grad_name = op.output_arg_names[0]
                scaled_grads.append(grad_name)

            # TODO support multiple optimizers in on network in future.
            # here we assume that the optimizer is unique in network.
            elif is_optimize_op(op) and op.type in __rescale_grad_supported_opts__:
                self._support_rescale_grad = True

        not_synchronized_grads = []
        for grad_name in scaled_grads:
            if grad_name not in self._grad_name_to_group_map:
                not_synchronized_grads.append(grad_name)
        assert len(
            not_synchronized_grads
        ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads)
    def _could_be_prune(self):

        return self._support_rescale_grad or self._all_dp_groups_same_degree()

    def _all_dp_groups_same_degree(self):
        return len(
            set([
                len(group.ranks)
                for group in self._group_to_grad_name_map.keys()
            ])) == 1

    def _scale_backward_initial_grad(self):

        block = default_main_program().global_block()
        dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', \
                    "loss_grad_op must be fill_constant op, " \
                    "but this op is {}".format(op.type)
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree
                op._set_attr('value', loss_scale)
                break
    def _remove_grad_scaling(self):
        block = default_main_program().global_block()

        for op_idx, op in reversed(list(enumerate(block.ops))):
            if is_data_parallel_scale_op(op):
                block._remove_op(op_idx, False)

        block._sync_with_cpp()

    def _update_opt_rescale_grad(self):

        block = default_main_program().global_block()
        scaled_grads = set()

        for idx, op in reversed(list(enumerate(block.ops))):
            if is_optimize_op(op) and op.type in __rescale_grad_supported_opts__:
                assert op.has_attr(
                    'rescale_grad'
                ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
                    str(op))
                assert len(
                    op.input("Grad")
                ) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format(
                    str(op))

                grad_name = op.input("Grad")[0]
                dp_degree = len(
                    list(self._grad_name_to_group_map[grad_name].ranks))
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(self._grad_name_to_group_map.keys(
        )), "Unexception: gradients [{}] are unscaled.".format(
            set(self._grad_name_to_group_map.keys()) - scaled_grads)
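When the data-parallel groups do not all share the same degree, the pass leaves the initial loss gradient alone and instead divides each supported optimizer's rescale_grad by that gradient's dp_degree. A hedged numeric sketch, with made-up values, of why that substitution matches the explicit scale op being removed:

# Illustrative check: folding 1/dp_degree into rescale_grad is equivalent to
# keeping an explicit scale op after the gradient allreduce. Values are made up.
dp_degree = 2
allreduced_grad = 6.0          # summed gradient across the dp group
base_rescale_grad = 1.0        # value the optimizer carried before the pass

# Before the pass: explicit scale op, then the optimizer rescales as configured.
effective_pre = (allreduced_grad * (1.0 / dp_degree)) * base_rescale_grad

# After the pass: the scale op is removed and rescale_grad absorbs the factor.
new_rescale_grad = base_rescale_grad / dp_degree
effective_post = allreduced_grad * new_rescale_grad

assert abs(effective_pre - effective_post) < 1e-12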
python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt

@@ -20,6 +20,8 @@ if((NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass")
+  list(REMOVE_ITEM TEST_OPS
+       "test_auto_parallel_data_parallel_optimization_pass")
 endif()

 foreach(TEST_OP ${TEST_OPS})
python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py

@@ -108,7 +108,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
             pickle.dump(all_fetch_values, f)
 
     def get_gpt_model(self, strategy, place, batch_size, sequence_len,
-                      vocab_size):
+                      vocab_size, **kwargs):
         modeling.init_global()
         if strategy == "dp":
             modeling._global_parallel_strategy = "dp"

@@ -179,11 +179,17 @@ class AutoPallelPassTestBase(DistPassTestBase):
         criterion = GPTPretrainingCriterion()
         loss = criterion(preds, labels, loss_mask)
         clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
-        optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
-                                                         beta1=0.9,
-                                                         beta2=0.999,
-                                                         epsilon=1e-08,
-                                                         grad_clip=clip)
+        if kwargs.get('optimizer', None) == "LarsMomentum":
+            optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
+                learning_rate=0.001,
+                momentum=0.9)
+        else:
+            optimizer = paddle.fluid.optimizer.AdamOptimizer(
+                learning_rate=0.00001,
+                beta1=0.9,
+                beta2=0.999,
+                epsilon=1e-08,
+                grad_clip=clip)
         optimizer = fleet.distributed_optimizer(optimizer)
         startup_program = paddle.static.default_startup_program()
         _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py
new file (0 → 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import random
import numpy as np

import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.passes import new_pass, PassManager, PassContext
from auto_parallel_pass_test_base import AutoPallelPassTestBase

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion


class TestDataParallelPassWithScale1(AutoPallelPassTestBase):

    def init(self):
        if paddle.is_compiled_with_cuda():
            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        self.rtol = 1e-5
        self.atol = 1e-8
        # NOTE a hack to compare pass apply or not, since there is no
        # setting of this pass in dist_strategy
        self._apply_pass = False

        rank = paddle.distributed.get_rank()
        paddle.seed(rank + 2021)
        random.seed(rank + 2021)
        np.random.seed(rank + 2021)

    def apply_passes(self):
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        self._apply_pass = True

    def apply_no_passes(self):
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        self._apply_pass = False

    def test_bs_8(self):
        self.check_main(gpus=[0, 1],
                        batch_size=8,
                        sequence_len=512,
                        vocab_size=1000)

    # test scaling with fillconstant
    def get_model(self, place, batch_size, sequence_len, vocab_size):

        dist_main_prog, dist_startup_prog, data_holder, [
            loss
        ], gen_data = self.get_gpt_model('dp', place, batch_size, sequence_len,
                                         vocab_size)
        if self._apply_pass:
            config = {}
            config["dist_context"] = get_default_distributed_context()
            config["global_rank"] = paddle.distributed.get_rank()
            dp_pass = new_pass("auto_parallel_data_parallel_optimization",
                               config)
            dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())

        return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data


class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1):

    # test scaling with optimizer rescale_grad
    def get_model(self, place, batch_size, sequence_len, vocab_size):

        dist_main_prog, dist_startup_prog, data_holder, [
            loss
        ], gen_data = self.get_gpt_model('dp',
                                         place,
                                         batch_size,
                                         sequence_len,
                                         vocab_size,
                                         optimizer='LarsMomentum')

        if self._apply_pass:
            config = {}
            config["dist_context"] = get_default_distributed_context()
            config["global_rank"] = paddle.distributed.get_rank()
            dp_pass = new_pass("auto_parallel_data_parallel_optimization",
                               config)
            dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())

        return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data


if __name__ == "__main__":
    unittest.main()