未验证 提交 4dbcb975 编写于 作者: W Wu Yi 提交者: GitHub

Enable dist se resnext (#12365)

* enable dist se resnext

* small batch size
上级 2409d0f7
...@@ -278,7 +278,7 @@ class DistSeResneXt2x2: ...@@ -278,7 +278,7 @@ class DistSeResneXt2x2:
def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True): def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model( test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
batch_size=20) batch_size=2)
if is_dist: if is_dist:
t = get_transpiler(trainer_id, t = get_transpiler(trainer_id,
fluid.default_main_program(), endpoints, fluid.default_main_program(), endpoints,
...@@ -294,11 +294,7 @@ class DistSeResneXt2x2: ...@@ -294,11 +294,7 @@ class DistSeResneXt2x2:
strategy.num_threads = 1 strategy.num_threads = 1
strategy.allow_op_delay = False strategy.allow_op_delay = False
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, True, loss_name=avg_cost.name, exec_strategy=strategy)
loss_name=avg_cost.name,
exec_strategy=strategy,
num_trainers=trainers,
trainer_id=trainer_id)
feed_var_list = [ feed_var_list = [
var for var in trainer_prog.global_block().vars.itervalues() var for var in trainer_prog.global_block().vars.itervalues()
......
...@@ -56,7 +56,7 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except os.error: except os.error:
retry_times -= 1 retry_times -= 1
def non_test_with_place(self): def test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = { required_envs = {
"PATH": os.getenv("PATH"), "PATH": os.getenv("PATH"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册