Unverified commit b2986bab, authored by ziyoujiyi, committed by GitHub

sync/geo test ok & fix heter_worker program ok (#39511)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* create the_one_ps dirs

* the one ps dirs modify

* refactor ps optimize

* refactor theoneps

* the_one_ps

* add ps pass unittest

* ps unittest frame

* add cpu_async_ps_mode test

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* cpu-async-ps minimize test ok & gpu minimize test ok

* add heter 2stage unittest

* sync/geo test ok & fix heter_worker program ok

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Parent 0bcf1365
@@ -54,14 +54,15 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         attrs['user_defined_strategy'] = self.user_defined_strategy
         attrs['trainer'] = TrainerRuntimeConfig(self.user_defined_strategy)
         attrs['ps_mode'] = attrs['trainer'].mode
+        logger.info("ps_mode: {}".format(attrs['ps_mode']))
         attrs['role_maker'] = self.role_maker
         attrs[
             'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
         attrs['is_worker'] = self.role_maker._is_worker()
         attrs['is_server'] = self.role_maker._is_server()
         attrs['is_heter_worker'] = self.role_maker._is_heter_worker()
+        logger.info("this process is heter? {}".format(attrs[
+            'is_heter_worker']))
         attrs['use_ps_gpu'] = self.user_defined_strategy.a_sync_configs[
             "use_ps_gpu"]
         attrs['lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
......
@@ -47,7 +47,7 @@ class AppendSendOpsPass(PassBase):  # this pass is reused by several ps modes
         if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
             dummy_output = program.global_block().create_var(
                 name=framework.generate_control_dev_var_name())
+            logger.info("dummy_output: {}".format(dummy_output))
         program.global_block().append_op(
             type="send",
             inputs={"X": send_input_vars},
@@ -61,7 +61,7 @@ class AppendSendOpsPass(PassBase):  # this pass is reused by several ps modes
         return dummy_output

-    def _append_barrier_op(self, program, dummys):
+    def _append_barrier_op(self, program, dummys, trainer_id):
         program.global_block().append_op(
             type="send_barrier",
             inputs={"X": dummys},
@@ -79,19 +79,24 @@ class AppendSendOpsPass(PassBase):  # this pass is reused by several ps modes
             send_ctx = get_geo_trainer_send_context(attrs)  # geo mode
         else:
             send_ctx = get_the_one_send_context(attrs)  # async, sync and other modes
+        logger.info("send_ctx: {}".format(send_ctx))
         dummys = []
         for merged_name, send in send_ctx.items():
             if send.is_sparse() and ps_mode != DistributedMode.GEO:
                 continue
+            logger.info('merged_name, send: {}, {}'.format(merged_name, send))
             is_sparse = 1 if send.is_sparse() else 0
             is_sparse = 2 if send.is_distributed() else is_sparse
             dummys.append(
                 self._append_send_op(main_program,
                                      send.origin_varnames(), merged_name,
                                      is_sparse, send.table_id(), ps_mode))
+        logger.info('ps trainer pass - ps mode: {}'.format(ps_mode))
+        logger.info('dummys: {}'.format(dummys))
         if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
-            self._append_barrier_op(main_program, dummys)
+            logger.info('insert send_barrier_op')
+            trainer_id = get_role_id(attrs['role_maker'])
+            self._append_barrier_op(main_program, dummys, trainer_id)


@register_pass("distributed_ops_pass")
......
@@ -97,7 +97,7 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
     def __init__(self, pass_ctx):
         logger.info("start building cpu-sync-ps program")
         super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
-        if self.ps_mode == DistributedMode.GEO:
+        if self.ps_mode != DistributedMode.SYNC:
             raise ValueError("ps mode: {} not matched {}",
                              format(ps_mode, "CpuSyncPsProgramBuilder"))
......
@@ -83,8 +83,10 @@ class DistributedMode:

 class TrainerRuntimeConfig(object):
     def __init__(self, valid_strategy):
+        self.mode = None
         k_steps = valid_strategy.a_sync_configs["k_steps"]
+        logger.info("ps mode in strategy: {}, {}".format(
+            valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"]))

         if not valid_strategy.a_sync and k_steps == 0:
             self.mode = DistributedMode.SYNC
@@ -94,7 +96,6 @@ class TrainerRuntimeConfig(object):
         if valid_strategy.a_sync and k_steps > 0:
             self.mode = DistributedMode.GEO

-        self.mode = None
         num_threads = os.getenv("CPU_NUM", "1")

         self.runtime_configs = {}
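Read together, the two hunks above change when self.mode gets its default: it is now initialised to None before the strategy checks instead of being reset to None after them. Below is a minimal sketch of the resulting mode selection, assuming the usual fleet a_sync / k_steps semantics; the ASYNC branch is elided by the diff and assumed here, and DistributedMode is stubbed so the sketch runs standalone.

    # Hedged sketch, not part of the commit.
    class DistributedMode:                       # stub values for illustration only
        SYNC, ASYNC, GEO = "sync", "async", "geo"

    def resolve_ps_mode(a_sync, k_steps):
        if not a_sync and k_steps == 0:
            return DistributedMode.SYNC          # mirrors the SYNC branch shown above
        if a_sync and k_steps == 0:
            return DistributedMode.ASYNC         # assumed: branch elided by the hunk
        if a_sync and k_steps > 0:
            return DistributedMode.GEO           # mirrors the GEO branch shown above
        return None                              # the new default set in __init__

    print(resolve_ps_mode(a_sync=True, k_steps=400))  # -> "geo"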
@@ -161,6 +162,13 @@ def get_dist_env():
     }


+def get_role_id(role_maker):
+    try:
+        return role_maker._role_id()
+    except Exception:
+        return role_maker.role_id()
+
+
 def get_ps_endpoint(role_maker):
     try:
         return role_maker._get_pserver_endpoints()[get_role_id(role_maker)]
@@ -184,7 +192,7 @@ def get_trainer_endpoint(role_maker):

 def get_previous_stage_trainers(role_maker):
     try:
-        return role_maker_get_previous_trainers()
+        return role_maker._get_previous_trainers()
     except Exception:
         return role_maker.get_previous_trainers()
@@ -229,18 +237,11 @@ def get_sparse_tablenames(program, is_distributed):
     return list(tablenames)


-def get_role_id(role_maker):
-    try:
-        return role_maker._role_id()
-    except Exception:
-        return role_maker.role_id()
-
-
 def get_ps_endpoints(role_maker):
     try:
-        return role_maker._get_pserver_endpoints()[get_role_id(role_maker)]
+        return role_maker._get_pserver_endpoints()
     except Exception:
-        return role_maker.get_pserver_endpoints()[get_role_id(role_maker)]
+        return role_maker.get_pserver_endpoints()


 def get_trainers(role_maker):
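The helper shuffle above moves get_role_id next to the other role-maker accessors and stops get_ps_endpoints from indexing by role id. A short usage sketch, assuming role_maker is a fleet role maker instance and using only the helpers defined in these hunks:

    # Illustrative only; each helper tries the private role-maker API first and
    # falls back to the public one.
    role_id = get_role_id(role_maker)       # index of this trainer/server
    my_ep = get_ps_endpoint(role_maker)     # this pserver's "ip:port", still picked by role id
    all_eps = get_ps_endpoints(role_maker)  # after this change: the full pserver endpoint list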
@@ -296,8 +297,35 @@ def get_geo_trainer_send_context(context):
     if context['ps_mode'] != DistributedMode.GEO:
         raise ValueError("ps mode: {} not matched {}",
                          format(ps_mode, "get_geo_trainer_send_context"))
+
     send_ctx = {}
+    trainer_id = get_role_id(context['role_maker'])
+    idx = 0
+
+    distibuted_varnames = get_sparse_tablenames(context['origin_main_program'],
+                                                True)
+    for merged in context['merged_sparse_pairs']:
+        param, grad = merged
+        grad_name = grad.merged_var.name
+        param_name = param.merged_var.name
+        is_distributed = True if param_name in distibuted_varnames else False
+
+        var = context['origin_main_program'].global_block().vars[
+            grad.merged_var.name]
+        var_numel = reduce(lambda x, y: x * y, var.shape[1:])
+
+        sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
+                                 [var_numel], [grad_name], trainer_id, True,
+                                 True, is_distributed, idx, False)
+        idx += 1
+        send_ctx[sparse_ctx.var_name()] = sparse_ctx
+
+    if len(send_ctx) == 0:
+        raise ValueError("GeoSGD require sparse parameters in your net.")
+
+    if len(context['tensor_table']) > 0 and context['is_worker']:
+        name, ctx = _step_ctx(idx, context['role_maker'])
+        send_ctx[name] = ctx
+
     return send_ctx
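In the new geo send context above, only the trailing dimensions of a sparse gradient contribute to the per-row payload. A small worked example with a hypothetical embedding-gradient shape, chosen to match sparse_feature_number / sparse_feature_dim in the test configs further down:

    # Hypothetical shape; not taken from the commit.
    from functools import reduce

    grad_shape = [1000001, 10]                              # rows x embedding dim
    var_numel = reduce(lambda x, y: x * y, grad_shape[1:])  # -> 10 elements per pushed row
    print(var_numel)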
@@ -1253,6 +1281,60 @@ def find_op_input_output(program, block, op):
     return input_var_list, output_var_list


+def add_heter_send_op(program, heter_program, block, block_var_detail):
+    def _get_send_op_dict():
+        send_op_dict = {}
+        send_op_list = find_send_op(program)
+        for op in send_op_list:
+            input_list, _ = find_op_input_output(program,
+                                                 program.global_block(), op)
+            for var in input_list:
+                send_op_dict[var] = op
+        return send_op_dict
+
+    send_grad_var_list = []
+    send_op_dict = _get_send_op_dict()
+    table_dict = {}
+    for persistable_var in block_var_detail["backward"]["persistables"]:
+        if "@GRAD" not in persistable_var:
+            continue
+        if "GRAD" != persistable_var.split("@")[-1]:
+            continue
+        if persistable_var not in send_op_dict:
+            continue
+        send_op = send_op_dict[persistable_var]
+        is_sparse = send_op.attr('is_sparse')
+        table_id = send_op.attr('table_id')
+        send_varnames = send_op.attr('send_varnames')
+        send_grad_var_list.append(persistable_var)
+        if table_id not in table_dict:
+            table_dict[table_id] = {}
+            table_dict[table_id]['var_list'] = []
+            table_dict[table_id]['is_sparse'] = is_sparse
+            table_dict[table_id]['send_varnames'] = send_varnames
+        table_dict[table_id]['var_list'].append(persistable_var)
+
+    for table_id in table_dict:
+        dummy_output = block.create_var(
+            name=framework.generate_control_dev_var_name())
+        send_input_vars = [
+            block.vars[union_var]
+            for union_var in table_dict[table_id]['var_list']
+        ]
+        block.append_op(
+            type="send",
+            inputs={"X": send_input_vars},
+            outputs={"Out": dummy_output},
+            attrs={
+                "send_varnames": table_dict[table_id]['send_varnames'],
+                "is_sparse": is_sparse,
+                "table_id": table_id,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
+
+    return send_grad_var_list
+
+
 def get_vars_name_in_block(block):
     vars_list = block.vars.keys()
     vars_name_list = [var_name for var_name in vars_list]
@@ -1302,10 +1384,6 @@ def create_backward_block(program, origin_program, bp_ops_list,
     return heter_block


-def debug_program(file, program, is_trainer):
-    if is_trainer:
-        with open(file, 'w+') as f:
-            f.write(str(program))
-    else:
-        with open(file, 'w+') as f:
-            f.write(str(program))
+def debug_program(file, program):
+    with open(file, 'w+') as f:
+        f.write(str(program))
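debug_program above loses its is_trainer flag, so trainer, server and heter-worker code can now dump a program identically. A minimal call sketch; the output path is arbitrary and not from the commit:

    import paddle

    paddle.enable_static()
    # Dump the default main program; the trainer further down builds the real
    # file name from sync_mode and debug_new_minimize.
    debug_program('/tmp/worker_main.prototxt', paddle.static.default_main_program())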
@@ -28,7 +28,7 @@ from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
 class TestPsTrainerPass(PsPassTestBase):
     def init(self):
         self.config = {}
-        self.config['ps_mode_config'] = "../ps/cpu_async_ps_config.yaml"
+        self.config['ps_mode_config'] = ""
         self.config['worker_num'] = "1"
         self.config['server_num'] = "1"
         self.config['run_minimize'] = "0"
@@ -47,23 +47,58 @@ class TestPsTrainerPass(PsPassTestBase):
     def check(self):
         pass

-    def test_ps_optimizer_minimize_cpu(self):
+    def test_ps_optimizer_minimize_cpu_async(self):
+        self.init()
+        self.config['ps_mode_config'] = "../ps/cpu_async_ps_config.yaml"
+        self.config['run_minimize'] = '1'
+
+        self.config['debug_new_minimize'] = '0'
+        self.config['log_dir'] = "/async_cpu_log_old_minimize"
+        remove_path_if_exists(self.config['log_dir'])
+        self.ps_launch(self.config)
+
+        self.config['debug_new_minimize'] = '1'
+        self.config['log_dir'] = "/async_cpu_log_new_minimize"
+        remove_path_if_exists(self.config['log_dir'])
+        self.ps_launch(self.config)
+
+        self.check()
+
+    def test_ps_optimizer_minimize_cpu_sync(self):
+        self.init()
+        self.config['ps_mode_config'] = "../ps/cpu_sync_ps_config.yaml"
+        self.config['run_minimize'] = '1'
+
+        self.config['debug_new_minimize'] = '0'
+        self.config['log_dir'] = "/sync_cpu_log_old_minimize"
+        remove_path_if_exists(self.config['log_dir'])
+        self.ps_launch(self.config)
+
+        self.config['debug_new_minimize'] = '1'
+        self.config['log_dir'] = "/sync_cpu_log_new_minimize"
+        remove_path_if_exists(self.config['log_dir'])
+        self.ps_launch(self.config)
+
+        self.check()
+
+    def test_ps_optimizer_minimize_cpu_geo(self):
         self.init()
+        self.config['ps_mode_config'] = "../ps/cpu_geo_ps_config.yaml"
         self.config['run_minimize'] = '1'

         self.config['debug_new_minimize'] = '0'
-        self.config['log_dir'] = "/cpu_log_old_minimize"
+        self.config['log_dir'] = "/geo_cpu_log_old_minimize"
         remove_path_if_exists(self.config['log_dir'])
         self.ps_launch(self.config)

         self.config['debug_new_minimize'] = '1'
-        self.config['log_dir'] = "/cpu_log_new_minimize"
+        self.config['log_dir'] = "/geo_cpu_log_new_minimize"
         remove_path_if_exists(self.config['log_dir'])
         self.ps_launch(self.config)

         self.check()

-    # heter ps: three-stage setup still to be tested
+    # heter ps: two-stage setup
     def test_ps_optimizer_minimize_heter(self):
         self.init()
         self.config['worker_num'] = "2"
......
@@ -26,7 +26,6 @@ hyper_parameters:
   fc_sizes: [400, 400, 400]

 runner:
-  geo_step: 400
   sync_mode: "async"  # sync / async / geo / heter
   thread_num: 16
   use_gpu: 0
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# refer to PaddleRec/models/rank/dnn/benchmark.yaml
hyper_parameters:
  optimizer:
    class: Adam
    learning_rate: 0.0001
    adam_lazy_mode: True
  sparse_inputs_slots: 27
  sparse_feature_number: 1000001
  sparse_feature_dim: 10
  dense_input_dim: 13
  fc_sizes: [400, 400, 400]

runner:
  geo_step: 400
  sync_mode: "geo"
  thread_num: 16
  use_gpu: 0
  model_path: "../ps_dnn_model.py"
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# refer to PaddleRec/models/rank/dnn/benchmark.yaml
hyper_parameters:
  optimizer:
    class: Adam
    learning_rate: 0.0001
    adam_lazy_mode: True
  sparse_inputs_slots: 27
  sparse_feature_number: 1000001
  sparse_feature_dim: 10
  dense_input_dim: 13
  fc_sizes: [400, 400, 400]

runner:
  sync_mode: "sync"
  thread_num: 16
  use_gpu: 0
  model_path: "../ps_dnn_model.py"
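The runner sections of the two new configs above are consumed by the DnnTrainer diff that follows; a hedged sketch of how those keys are read, assuming the test's config loader flattens the YAML into dotted keys:

    # Illustrative only; `config` is the flattened dict the trainer receives.
    sync_mode = config.get("runner.sync_mode")    # "sync" / "async" / "geo" / "heter"
    thread_num = config.get("runner.thread_num")  # exported as CPU_NUM in __main__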
@@ -329,9 +329,9 @@ class DnnTrainer(object):
         sync_mode = self.config.get("runner.sync_mode")
         inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
+        self.role_maker._generate_role()  # required
         if self.config['debug_new_minimize'] == 1:
             logger.info("entering run_minimize -- new")
-            self.role_maker._generate_role()  # required
             from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
             ps_optimizer = ParameterServerOptimizer(inner_optimizer)
             ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
@@ -346,11 +346,16 @@ class DnnTrainer(object):
         if fleet.is_server():
             _main_file = '/' + sync_mode + '_run_minimize' + '_debug:_' + str(
                 self.config['debug_new_minimize']) + '_server_main.prototxt'
-            debug_program(_main_file, loss.block.program, 0)
+            debug_program(_main_file, loss.block.program)
         elif fleet.is_worker():
             _main_file = '/' + sync_mode + '_run_minimize' + '_debug:_' + str(
                 self.config['debug_new_minimize']) + '_worker_main.prototxt'
-            debug_program(_main_file, loss.block.program, 1)
+            debug_program(_main_file, loss.block.program)
+        elif self.role_maker._is_heter_worker():
+            _main_file = '/' + sync_mode + '_run_minimize' + '_debug:_' + str(
+                self.config[
+                    'debug_new_minimize']) + '_heter_worker_main.prototxt'
+            debug_program(_main_file, loss.block.program)

     def run_single_pass(self):
         self.init_fleet_with_gloo()
@@ -395,17 +400,18 @@ class DnnTrainer(object):
             _main_file = '/' + sync_mode + "_" + str(config[
                 "applied_pass_name"]) + '_debug:_' + str(self.config[
                     'debug_new_pass']) + '_server_main.prototxt'
-            debug_program(_main_file, _main, 0)
+            debug_program(_main_file, _main)
         elif fleet.is_worker():
             _main_file = '/' + sync_mode + "_" + str(config[
                 "applied_pass_name"]) + '_debug:_' + str(self.config[
                     'debug_new_pass']) + '_worker_main.prototxt'
-            debug_program(_main_file, _main, 1)
+            debug_program(_main_file, _main)


 if __name__ == "__main__":
     paddle.enable_static()
     config = parse_args()
+    logger.info(">>>>>>>>>> python process started")
     os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
     benchmark_main = DnnTrainer(config)
     if config['run_single_pass'] == 1:
......