Unverified commit 765a2ada, authored by ziyoujiyi and committed by GitHub

Unify PS: heter PS two-stage unit tests pass (#39468)

* delete gloo connect retry
* reconstruct the_one_ps dirs
* create the_one_ps dirs
* modify the_one_ps dirs
* refactor ps optimize
* refactor the_one_ps
* add ps pass unittest
* ps unittest frame
* add cpu_async_ps_mode test
* ps unittest ready
* solve dist_pass init conflict
* solve import CommContext error
* unittest ok
* implement AllocateFrom
* solve setup.py.in conflict
* solve conflict
* cpu-async-ps minimize test ok & gpu minimize test ok
* add heter 2stage unittest

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Parent 2f642159
@@ -21,25 +21,6 @@ from .pass_base import PassBase, register_pass
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
    "lookup_table_grad": "W",
    "lookup_table_v2_grad": "W"
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
@register_pass("append_send_ops_pass")
class AppendSendOpsPass(PassBase):  # this pass is reused by multiple modes

@@ -894,6 +875,100 @@ class SplitTrainerOpsPass(PassBase):
    def _check_conflict(self, other_pass):
        return True
    def _replace_ops_by_communicate_op(self, program, attrs, heter_block_index,
                                       ops_list, block_var_detail):
        # replace the ops of this heter stage with a single send_and_recv op
        # that hands the stage's entrance vars over to the next heter worker
        all_op = program.global_block().ops
        start_op = ops_list[0]
        first_op_idx = -1
        for op in all_op:
            if str(op) == str(start_op):
                first_op_idx = all_op.index(op)
                break
        assert first_op_idx != -1
        self._delete_same_ops(program.global_block(), ops_list)

        entrance_var = []
        role_maker = attrs['role_maker']
        if heter_block_index == 1:
            next_heter_worker_endpoints = get_next_stage_trainers(role_maker)
            entrance_var = block_var_detail[heter_block_index]["forward"][
                "entrance"]

            comm_info = get_communicate_var_info(program, heter_block_index + 1,
                                                 entrance_var)
            program.global_block()._insert_op(
                index=first_op_idx,
                type="send_and_recv",
                inputs={"X": program.global_block().vars[entrance_var[0]]},
                outputs={"Out": []},
                attrs={
                    "mode": "forward",
                    "send_var_name": entrance_var + ["microbatch_id"],
                    "recv_var_name": [],
                    "message_name": comm_info["block_input_var_name"],
                    "next_endpoints": next_heter_worker_endpoints,
                    "previous_endpoints": [],
                    "trainer_id": get_role_id(role_maker),
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                })

        return entrance_var

    def _delete_same_ops(self, block, ops):
        for op in ops:
            try:
                for origin_op in block.ops:
                    if str(origin_op) == str(op):
                        idx = list(block.ops).index(origin_op)
                        block._remove_op(idx)
                        break
            except Exception as e:
                print(e)

    def _remove_var_pair_by_grad(self, var_name, attrs):
        for index, pair in enumerate(attrs['merged_variables_pairs']):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del attrs['merged_variables_pairs'][index]

        for index, pair in enumerate(attrs['merged_dense_pairs']):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del attrs['merged_dense_pairs'][index]
                return

        for index, pair in enumerate(attrs['merged_sparse_pairs']):
            var = pair[0]
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del attrs['merged_sparse_pairs'][index]
                return

    def _remove_trainer_send_op(self, program, attrs, heter_block_index,
                                block_var_detail):
        # if the trainer does FF -> BP -> SEND, it holds both vars: var and var@GRAD
        # if the trainer only does SEND, it holds a single var: var@GRAD
        # delete the send op when the trainer no longer holds the pair (var <-> var@GRAD)
        persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \
                       block_var_detail[heter_block_index]["backward"]["persistables"]
        need_remove_send_op = []
        need_remove_grad_var = []
        for op in find_send_op(program):
            input_list, _ = find_op_input_output(program,
                                                 program.global_block(), op)
            for var_name in input_list:
                origin_var_name = var_name.split("@GRAD")[0]
                if origin_var_name in persistables:
                    need_remove_send_op.append(op)
                    need_remove_grad_var.append(var_name)
        need_remove_send_op = list(set(need_remove_send_op))
        delete_ops(program.global_block(), need_remove_send_op)
        for grad_var_name in need_remove_grad_var:
            self._remove_var_pair_by_grad(grad_var_name, attrs)
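Note: the pruning above hinges on mapping a gradient input name back to its base variable via split("@GRAD"). A minimal, self-contained sketch of that rule (the helper name and sample data below are illustrative only, not part of the patch):

# Illustrative sketch of the var <-> var@GRAD pairing rule used by
# _remove_trainer_send_op: a send op input such as "fc_0.w_0@GRAD" maps back
# to "fc_0.w_0"; if that base name is among the stage's persistables, the
# trainer no longer holds the pair locally and the send op should be removed.
def grad_inputs_to_prune(send_input_names, persistables):
    pruned = []
    for var_name in send_input_names:
        origin_var_name = var_name.split("@GRAD")[0]
        if origin_var_name in persistables:
            pruned.append(var_name)
    return pruned

print(grad_inputs_to_prune(["fc_0.w_0@GRAD"], {"fc_0.w_0"}))  # ['fc_0.w_0@GRAD']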
    def _create_trainer_program(self, program, origin_program, attrs,
                                program_block_ops_list, block_var_detail):
        # This function mainly includes the following contents:
@@ -911,18 +986,18 @@ class SplitTrainerOpsPass(PassBase):
            ops_list = program_block_ops_list[heter_block_index][
                "forward"] + program_block_ops_list[heter_block_index][
                    "backward"]
-           static_var += replace_ops_by_communicate_op(
+           static_var += self._replace_ops_by_communicate_op(
                program, attrs, heter_block_index, ops_list, block_var_detail)
-           remove_trainer_send_op(program, attrs, heter_block_index,
-                                  block_var_detail)
+           self._remove_trainer_send_op(program, attrs, heter_block_index,
+                                        block_var_detail)

        optimizer_block = []
        grad_to_block_id = []

        bp_ops_list = program_block_ops_list[0]["backward"]
-       delete_same_ops(program.global_block(), bp_ops_list)
-       delete_trainer_useless_var(attrs, program, static_var)
-       backward_block = create_backward_block(program, origin_program, attrs,
-                                              bp_ops_list, block_var_detail)
+       self._delete_same_ops(program.global_block(), bp_ops_list)
+       delete_trainer_useless_var(program, static_var)
+       backward_block = create_backward_block(program, origin_program,
+                                              bp_ops_list, block_var_detail)

        bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
......
@@ -186,10 +186,10 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
        add_lr_decay_table_pass.apply([], [], self.pass_ctx)

        distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
-       distributed_ops_pass.apply([self.cloned_main], [], self.pass_ctx)
+       distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)

        delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
-       delete_optimizer_pass.apply([None], [_startup], self.pass_ctx)
+       delete_optimizer_pass.apply([self.cloned_main], [None], self.pass_ctx)

        append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
        append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)

@@ -210,12 +210,13 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
        else:
            split_trainer_ops_pass = new_pass("split_trainer_ops_pass",
                                              self.attrs)
-           split_trainer_ops_pass([self.cloned_main], [], self.pass_ctx)
+           split_trainer_ops_pass.apply([self.cloned_main], [None],
+                                        self.pass_ctx)

        set_heter_pipeline_opt_pass = new_pass('set_heter_pipeline_opt_pass',
                                               self.attrs)
        set_heter_pipeline_opt_pass.apply([self.cloned_main],
-                                          [self.cloned_startup], pass_ctx)
+                                          [self.cloned_startup], self.pass_ctx)

        if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(server_endpoints)

@@ -228,7 +229,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
            ps_set_heter_pipeline_opt_pass = new_pass(
                "set_heter_pipeline_opt_pass", self.attrs)
            ps_set_heter_pipeline_opt_pass.apply(
-               [self.loss.block.program], [startup_program], self.pass_ctx)
+               [self.cloned_main], [self.cloned_startup], self.pass_ctx)
        elif self.attrs['is_server']:
            self._build_pserver_programs()
......
@@ -42,9 +42,17 @@ RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
backward = core.op_proto_and_checker_maker.OpRole.Backward
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
    "lookup_table_grad": "W",
    "lookup_table_v2_grad": "W"
}
DEFAULT_DEVICE = 'cpu'
def logger_config(log_path, logging_name):

@@ -640,6 +648,20 @@ def find_block_joints(program, program_block_ops_list, heter_ops):
    return block_var_detail
def find_ops_list_input_output(program, ops_list):
    input_var_list = []
    output_var_list = []
    for op in ops_list:
        inputs = _get_input_map_from_op(program.global_block().vars, op)
        input_var_list += get_varlist_from_op_map(inputs)
        outputs = _get_output_map_from_op(program.global_block().vars, op)
        output_var_list += get_varlist_from_op_map(outputs)
    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list
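A brief usage sketch for the helper added above: it returns the de-duplicated input and output variable names of an op list, which the split/heter passes use when wiring block entrances and exits. The snippet below is illustrative only and builds a trivial static program just to have some ops to inspect.

# Assumed usage only (not part of the patch): collect the unique input/output
# var names of all ops in the global block of a small static Program.
import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name="x", shape=[-1, 4], dtype="float32")
    y = paddle.static.nn.fc(x, size=2)

input_names, output_names = find_ops_list_input_output(
    prog, prog.global_block().ops)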
def find_entrance_exit_private(program, program_block_ops_list):
    block_var_detail = []
    persistables = []

@@ -850,6 +872,54 @@ def _get_output_map_from_op(varmap, op):
    return iomap


def get_varlist_from_op_map(var_map):
    var_list = []
    for key, varlist in six.iteritems(var_map):
        if not isinstance(varlist, list):
            varlist = [varlist]
        for i in range(len(varlist)):
            var = varlist[i]
            var_list.append(var.name)
    return var_list


def _get_input_map_from_op(varmap, op):
    """Returns a dict from op input name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.input_names:
        vars = []
        for varname in op.input(key):
            if varname == "@EMPTY@":
                continue
            if "lod_tensor_blocking_queue" in varname:
                continue
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def screen_persistables(program, var_list):
    need_remove = []
    for var_name in var_list:
        if "@GRAD" in var_name:
            if "GRAD" != var_name.split("@")[-1]:
                continue
            origin_var_name = var_name.split("@GRAD")[0]
            var = program.global_block().vars[origin_var_name]
        else:
            var = program.global_block().vars[var_name]

        if fluid.io.is_persistable(var):
            need_remove.append(var_name)

    for var_name in need_remove:
        var_list.remove(var_name)
    return need_remove
def block_append_op(program, origin_program, block, op):
    merge_ordereddict = origin_program.global_block().vars.copy()
    merge_ordereddict.update(block.vars)

@@ -1154,6 +1224,84 @@ def get_param_grads(origin_program):
    return sparse_param_grads, dense_param_grads


def delete_ops(block, ops):
    for op in ops:
        try:
            idx = list(block.ops).index(op)
            block._remove_op(idx)
        except Exception as e:
            print(e)


def find_send_op(program):
    send_op_list = []
    for op in program.global_block().ops:
        if op.type == "send":
            send_op_list.append(op)
    return send_op_list


def find_op_input_output(program, block, op):
    input_var_list = []
    output_var_list = []
    inputs = _get_input_map_from_op(block.vars, op)
    input_var_list += get_varlist_from_op_map(inputs)
    outputs = _get_output_map_from_op(block.vars, op)
    output_var_list += get_varlist_from_op_map(outputs)
    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list


def get_vars_name_in_block(block):
    vars_list = block.vars.keys()
    vars_name_list = [var_name for var_name in vars_list]
    return vars_name_list


def delete_trainer_useless_var(program, static_var):
    # keep only the vars still referenced by the remaining ops (plus static_var)
    static_var = list(set(static_var))
    program_useful_var_list = []
    for op in program.global_block().ops:
        input_var_list, output_var_list = find_op_input_output(
            program, program.global_block(), op)
        op_var_list = list(set(input_var_list).union(set(output_var_list)))
        program_useful_var_list = list(
            set(program_useful_var_list).union(set(op_var_list)))
    program_useful_var_list += static_var
    program_useless_var_list = list(
        set(get_vars_name_in_block(program.global_block())).difference(
            set(program_useful_var_list)))
    for var in program_useless_var_list:
        program.global_block()._remove_var(var)
    return program_useless_var_list


def create_backward_block(program, origin_program, bp_ops_list,
                          block_var_detail):
    # collect the stage-0 backward ops into a new block, skipping send ops
    # whose vars are unknown to both the global block and the new block
    pre_block_idx = program.num_blocks - 1
    heter_block = program._create_block(pre_block_idx)

    for _, op in enumerate(bp_ops_list):
        if op.type == "send":
            send_varnames = op.attr('send_varnames')
            is_skip = False
            for varname in send_varnames:
                if varname not in program.global_block(
                ).vars and varname not in heter_block.vars:
                    is_skip = True
                    break
            if is_skip:
                continue
        block_append_op(program, origin_program, heter_block, op)

    entrance_vars = block_var_detail[0]["backward"]["entrance"]
    add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
    exit_vars = block_var_detail[0]["backward"]["exit"]
    add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
    return heter_block
def debug_program(file, program, is_trainer):
    if is_trainer:
        with open(file, 'w+') as f:
......
@@ -22,6 +22,7 @@ import inspect
import unittest
import numpy as np
from collections import OrderedDict
from paddle.distributed.ps.utils.public import logger
from dist_pass_test_base import prepare_python_path_and_return_module, remove_path_if_exists
import paddle.distributed.fleet as fleet

@@ -37,7 +38,7 @@ class PsPassTestBase(unittest.TestCase):
        print('Ps tearDown...')

    def ps_launch(self, config, ps_mode="cpu-ps"):
-       if ps_mode == "cpu-ps":
+       if ps_mode == "cpu-ps" or ps_mode == 'heter-ps':
            os.environ['WITH_DISTRIBUTE'] = 'ON'

            cmd = [
@@ -45,7 +46,16 @@ class PsPassTestBase(unittest.TestCase):
                "-u",
            ] + [
                "-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
-               config['worker_num'], "--server_num", config['server_num'],
+               config['worker_num'], "--server_num", config['server_num']
            ]
            if ps_mode == 'heter-ps':
                os.environ['FLAGS_START_PORT'] = '12004'
                cmd += [
                    '--heter_worker_num', config['heter_worker_num'],
                    '--heter_devices', config['heter_devices']
                ]
            cmd += [
                "../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
                "--run_minimize", config['run_minimize'], "--run_single_pass",
                config['run_single_pass'], "--debug_new_pass",
......
@@ -63,6 +63,27 @@ class TestPsTrainerPass(PsPassTestBase):
        self.check()
    # heter ps: the three-stage case is still to be tested
    def test_ps_optimizer_minimize_heter(self):
        self.init()
        self.config['worker_num'] = "2"
        self.config['server_num'] = "2"
        self.config['heter_worker_num'] = '2'
        self.config['heter_devices'] = 'gpu'

        self.config['run_minimize'] = '1'
        self.config['ps_mode_config'] = "../ps/heter_ps_config.yaml"
        self.config['debug_new_minimize'] = '0'
        self.config['log_dir'] = "/heter_log_old_minimize"
        remove_path_if_exists(self.config['log_dir'])
        self.ps_launch(self.config, 'heter-ps')

        self.config['debug_new_minimize'] = '1'
        self.config['log_dir'] = "/heter_log_new_minimize"
        remove_path_if_exists(self.config['log_dir'])
        self.ps_launch(self.config, 'heter-ps')
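With the config values above, ps_launch(self.config, 'heter-ps') assembles a fleet launch command along the lines of the sketch below. This is an approximation only: the interpreter prefix is elided in the test-base hunk above, FLAGS_START_PORT=12004 is exported first, and the trailing --run_single_pass/--debug_new_* flags are omitted here.

# Approximate flags built by PsPassTestBase.ps_launch for the second run of
# this test (debug_new_minimize == '1'); not a verbatim copy of the command.
launch_args = [
    "-m", "launch", "--log_dir", "/heter_log_new_minimize",
    "--worker_num", "2", "--server_num", "2",
    "--heter_worker_num", "2", "--heter_devices", "gpu",
    "../ps/ps_dnn_trainer.py", "-m", "../ps/heter_ps_config.yaml",
    "--run_minimize", "1",
]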
    def test_ps_optimizer_minimize_gpu(self):
        self.init()
        self.config['run_minimize'] = '1'
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
hyper_parameters:
    optimizer:
        class: Adam
        learning_rate: 0.0001
        strategy: async  # useful
    sparse_inputs_slots: 27
    sparse_feature_number: 1024
    sparse_feature_dim: 11
    dense_input_dim: 13
    fc_sizes: [512, 256, 128, 32]
    distributed_embedding: 0

runner:
    sync_mode: "heter"
    thread_num: 8
    micro_num: 8  # micro batch num for each thread
    pipeline: True
    model_path: "../ps_dnn_model.py"
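These runner fields are what ps_dnn_trainer.py turns into a DistributedStrategy: sync_mode: "heter" selects the heter a_sync branch, and runner.micro_num feeds the pipeline accumulate_steps (see the get_user_defined_strategy hunk below). A minimal sketch of that mapping, assuming the flattened dotted-key config dict the trainer works with:

# Sketch only (mirrors the get_user_defined_strategy change below); `config`
# is assumed to be a flattened dict with dotted keys, e.g.
# {"runner.sync_mode": "heter", "runner.micro_num": 8, ...}.
import paddle

def heter_strategy_from_config(config):
    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.a_sync = True
    strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
    strategy.pipeline = True
    # micro_num micro-batches per thread -> gradient accumulation steps
    strategy.pipeline_configs = {"accumulate_steps": config.get("runner.micro_num")}
    return strategy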
@@ -23,7 +23,6 @@ import yaml, six, copy
import paddle
import os
import warnings
import logging
import ast
import numpy as np
import struct
@@ -176,6 +175,10 @@ def get_user_defined_strategy(config):
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.a_sync = True
        strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
        strategy.pipeline = True
        strategy.pipeline_configs = {
            "accumulate_steps": config.get('runner.micro_num')
        }
    elif sync_mode == "gpubox":
        print("sync_mode = {}".format(sync_mode))
        strategy = paddle.distributed.fleet.DistributedStrategy()

@@ -328,6 +331,7 @@ class DnnTrainer(object):
        if self.config['debug_new_minimize'] == 1:
            logger.info("entering run_minimize -- new")
            self.role_maker._generate_role()  # required
            from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
            ps_optimizer = ParameterServerOptimizer(inner_optimizer)
            ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
......
@@ -17,6 +17,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
import math
import paddle.distributed.fleet as fleet
from paddle.distributed.ps.utils.public import logger


class DNNLayer(nn.Layer):
@@ -77,8 +78,13 @@ class DNNLayer(nn.Layer):
        y_dnn = paddle.concat(x=sparse_embs + [dense_inputs], axis=1)
-       for n_layer in self._mlp_layers:
-           y_dnn = n_layer(y_dnn)
+       if self.sync_mode == 'heter':
+           with paddle.fluid.device_guard('gpu'):
+               for n_layer in self._mlp_layers:
+                   y_dnn = n_layer(y_dnn)
+       else:
+           for n_layer in self._mlp_layers:
+               y_dnn = n_layer(y_dnn)

        return y_dnn
......