Unverified commit 196dbfc2, authored by ziyoujiyi, committed by GitHub

ps optimize refactor (#38982)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* create the_one_ps dirs

* the one ps dirs modify

* refactor ps optimize

* refactor theoneps

* the_one_ps

* add ps pass unittest

* ps unittest frame

* add cpu_async_ps_mode test

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Parent de0bad2a
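For context, the refactored optimizer is still reached through the regular fleet entry points. The sketch below is not part of this commit; it assumes the public paddle.distributed.fleet API and a fleetrun-style launch environment, and only illustrates how a CPU async parameter-server job would end up in ParameterServerOptimizer.minimize_impl and the new ps program builders.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=False)  # parameter-server (non-collective) role maker

strategy = fleet.DistributedStrategy()
strategy.a_sync = True  # CPU async ps mode; a_sync_configs can tune k_steps etc.

x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
pred = paddle.static.nn.fc(x, size=1)
loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)  # dispatches into ParameterServerOptimizer.minimize_impl

if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
else:
    fleet.init_worker()
    # ... run the training loop with a paddle.static.Executor ...
    fleet.stop_worker()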
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
import paddle.distributed.passes
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import core
import subprocess
import re
import os
import platform
from paddle.distributed.ps.utils.public import *
from paddle.distributed.passes import PassContext
from ..base.private_helper_function import wait_server_ready
from paddle.distributed.ps.utils.ps_factory import PsProgramBuilderFactory
class ParameterServerOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ParameterServerOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.attrs = {}
self.pass_ctx = PassContext()
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ParameterServerOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_ps_pass_context(self, loss, startup_program):
# trainer
self.attrs["env"] = get_dist_env()
self.attrs['loss'] = loss
self.attrs['min_block_size'] = 81920
self.attrs['origin_main_program'] = loss.block.program
self.attrs['origin_startup_program'] = startup_program
self.attrs['cloned_main'] = loss.block.program.clone()
self.attrs['cloned_startup'] = startup_program.clone()
self.attrs['user_defined_strategy'] = self.user_defined_strategy
self.attrs['trainer'] = TrainerRuntimeConfig(self.user_defined_strategy)
self.attrs['ps_mode'] = self.attrs['trainer'].mode
self.attrs['role_maker'] = self.role_maker
self.attrs[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
self.attrs['is_worker'] = self.role_maker._is_worker()
self.attrs['is_server'] = self.role_maker._is_server()
self.attrs['is_heter_worker'] = self.role_maker._is_heter_worker()
self.attrs['use_ps_gpu'] = self.user_defined_strategy.a_sync_configs[
"use_ps_gpu"]
self.attrs[
'lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
"lr_decay_steps"]
self.attrs['k_steps'] = self.user_defined_strategy.a_sync_configs[
"k_steps"]
self.attrs[
'launch_barrier'] = self.user_defined_strategy.a_sync_configs[
"launch_barrier"]
self.attrs['launch_barrier_flag'] = int(
os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
build_var_distributed(self.attrs)
# server
self.attrs['_main_server'] = fluid.Program()
self.attrs['_startup_server'] = fluid.Program()
self.attrs['tensor_table'] = {}
self.pass_ctx._attrs = self.attrs
def _is_graph_out(self):
return False
def _can_apply(self):
if self.attrs['role_maker']._is_collective or self.attrs['k_steps'] < 0:
return False
return True
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
self.inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set)
if startup_program is None:
startup_program = paddle.static.default_startup_program()
self._init_ps_pass_context(loss, startup_program)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx)
ps_builder._build_programs()
return None, None
def _can_apply_geo(self, program):
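# Note: heuristic gate for GEO mode, based on the checks below: GEO is only
# suggested when the inner optimizer is plain SGD and the estimated memory
# footprint (roughly 5x the size of the persistable LoDTensor parameters,
# plus per-op temporary outputs with any unknown -1 dim priced at
# eval_batch_size) fits into the currently free system memory.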
def get_sys_free_mem():
plat = platform.system()
if platform.system() == "Darwin":
vm = subprocess.Popen(
['vm_stat'], stdout=subprocess.PIPE).communicate()[0]
# Process vm_stat output (bytes on Python 3, so decode before splitting)
vm = vm.decode('utf-8', errors='ignore')
vmLines = vm.split('\n')
sep = re.compile(r':[\s]+')
vmStats = {}
for row in range(1, len(vmLines) - 2):
rowText = vmLines[row].strip()
rowElements = sep.split(rowText)
vmStats[(rowElements[0]
)] = int(rowElements[1].strip(r'\.')) * 4096
return vmStats["Pages free"]
elif platform.system() == "Linux":
mems = {}
with open('/proc/meminfo', 'rb') as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
free = mems[b'MemFree:']
return free
else:
raise ValueError(
"%s platform is unsupported in parameter server optimizer" %
(platform.system()))
if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
return False
free = get_sys_free_mem()
processed_var_names = set(["@EMPTY@"])
param_memory_size = 0
for varname in program.global_block().vars:
var = program.global_block().vars[varname]
if not var.persistable or var.desc.type(
) != core.VarDesc.VarType.LOD_TENSOR:
continue
set_var_lod_type(var)
param_memory_size += get_var_mem_size(var)
processed_var_names.add(varname)
upper_mem_use = param_memory_size * 5.0
program_tmp_vars = dict()
eval_batch_size = 1024
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
continue
processed_var_names.add(var_name)
var = program.global_block().vars[var_name]
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
continue
data_count = 1
neg_dim_count = 0
for x in var.shape:
if x < 0:
if neg_dim_count >= 1:
raise ValueError(
"Var %s has more than one negative dim." %
(var_name))
neg_dim_count += 1
data_count *= (-x)
else:
data_count *= x
program_tmp_vars[var_name] = (
data_count, neg_dim_count,
vars_metatools.dtype_to_size[var.dtype])
for varname in program_tmp_vars:
data_count, neg_dim_count, type_size = program_tmp_vars[varname]
if neg_dim_count == 1:
data_count *= eval_batch_size
var_memory = data_count * type_size
upper_mem_use += var_memory
if upper_mem_use < free:
return True
else:
return False
def _enable_strategy(self, dist_strategy, context):
if dist_strategy.a_sync_configs["k_steps"] >= 0:
return
dist_strategy.a_sync = True
is_geo = self._can_apply_geo(context["origin_main_program"])
dist_strategy.a_sync_configs["k_steps"] = 800 if is_geo else 0
def _disable_strategy(self, dist_strategy):
dist_strategy.a_sync = False
dist_strategy.a_sync_configs["k_steps"] = -1
@@ -19,6 +19,10 @@
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *
import os
if os.getenv("WITH_DISTRIBUTE") == "ON":
from .ps_trainer_pass import *
from .ps_server_pass import *
__all__ = [
'new_pass',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
import paddle.fluid as fluid
from ..ps.utils.public import *
from paddle.framework import core
from .pass_base import PassBase, register_pass
from paddle.optimizer.lr import LRScheduler
from paddle.optimizer.lr import ExponentialDecay, NoamDecay, PiecewiseDecay, NaturalExpDecay, InverseTimeDecay
from paddle.fluid.layers.learning_rate_scheduler import exponential_decay, noam_decay, piecewise_decay, natural_exp_decay, inverse_time_decay
@register_pass("add_lr_decay_table_pass")
class AddLrDecayTablePass(PassBase):
def __init__(self):
super(AddLrDecayTablePass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _add_tensor_table(self,
attrs,
feed_var_name,
fetch_var_name="",
startup_program=None,
main_program=None,
tensor_table_class=""):
tensor_table_dict = {}
tensor_table_dict[feed_var_name] = {}
tensor_table_dict[feed_var_name]["feed_var_name"] = feed_var_name
tensor_table_dict[feed_var_name]["fetch_var_name"] = fetch_var_name
tensor_table_dict[feed_var_name]["startup_program"] = startup_program
tensor_table_dict[feed_var_name]["main_program"] = main_program
tensor_table_dict[feed_var_name][
"tensor_table_class"] = tensor_table_class
attrs['tensor_table'] = tensor_table_dict
def _get_lr_sheduler_program(self, lr_sheduler, lr_decay_steps):
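# Builds a small standalone (main, startup) program pair that computes the
# decayed learning rate; the caller registers it as a "GlobalStepTable"
# tensor table fed by @LR_DECAY_COUNTER@ so the decay can run on the server.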
scheduler_decay = [
'NoamDecay', 'NaturalExpDecay', 'InverseTimeDecay',
'ExponentialDecay'
]
decay_main_program = fluid.framework.Program()
decay_startup_program = fluid.framework.Program()
lr_name = ""
if isinstance(lr_sheduler, ExponentialDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = exponential_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"ExponentialDecay is set, staircase = True, global learning rate decay step is [ %d ], change decay steps as follows: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
elif isinstance(lr_sheduler, NoamDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = noam_decay(lr_sheduler.d_model, lr_sheduler.warmup_steps,
1.0)
lr_name = lr.name
logging.warn("NoamDecay is set, warmup steps is [ %d ]" %
lr_sheduler.warmup_steps)
elif isinstance(lr_sheduler, NaturalExpDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = natural_exp_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"NaturalExpDecay is set, staircase = True, global learning rate decay step is [ %d ], change decay steps as follows: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
elif isinstance(lr_sheduler, InverseTimeDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = inverse_time_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"InverseTimeDecay is set, staircase = True, global learning rate decay step is [ %d ], change decay steps as follows: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n"
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs = { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
else:
raise ValueError(
"Unsupported LearningRate strategy, please use one of the following decay strategies: {}".
format(scheduler_decay))
return decay_main_program, decay_startup_program, lr_name
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
if not hasattr(attrs['origin_main_program'], 'lr_sheduler'):
return
assert isinstance(attrs['origin_main_program'].lr_sheduler,
LRScheduler), "must be LRScheduler"
ops = get_optimize_ops(attrs['origin_main_program'])
lr_decay_main_program, lr_decay_startup_program, lr_name = self._get_lr_sheduler_program(
attrs['origin_main_program'].lr_sheduler, attrs['lr_decay_steps'])
self._add_tensor_table(attrs, "@LR_DECAY_COUNTER@", lr_name,
lr_decay_startup_program, lr_decay_main_program,
"GlobalStepTable")
return
@register_pass("add_listen_and_serv_pass")
class AddListenAndServPass(PassBase):
def __init__(self):
super(AddListenAndServPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
opt = {
"grad_to_block_id": None,
"sparse_grad_to_param": None,
"lr_decay_block_id": None,
"dense_optimize_blocks": None,
"sparse_optimize_blocks": None,
# runtime attribute
"endpoint": get_ps_endpoint(attrs['role_maker']),
"pserver_id": get_role_id(attrs['role_maker']),
"Fanin": get_trainers(attrs['role_maker']),
"distributed_mode": attrs['ps_mode'],
"rpc_get_thread_num": -1,
"rpc_send_thread_num": -1,
"rpc_prefetch_thread_num": -1
}
main_program.global_block().append_op(
type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=opt)
attrs['cloned_main'] = main_program
@register_pass("add_rpc_global_flags_pass")
class AddRpcGlobalFlagsPass(PassBase):
def __init__(self):
super(AddRpcGlobalFlagsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("add_optimizer_pass")
class AddOptimizerPass(PassBase):
def __init__(self):
super(AddOptimizerPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("add_geo_optimizer_pass")
class AddGeoOptimizerPass(PassBase):
def __init__(self):
super(AddGeoOptimizerPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("build_pserver_startup_program_pass")
class BuildPserverStartupProgramPass(PassBase):
def __init__(self):
super(BuildPserverStartupProgramPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("delete_unused_in_startup_pass")
class DeleteUnusedInStartupPass(PassBase):
def __init__(self):
super(DeleteUnusedInStartupPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import paddle
import paddle.compat as cpt
from ..ps.utils.public import *
from paddle.framework import core
from .pass_base import PassBase, register_pass
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
"lookup_table_grad": "W",
"lookup_table_v2_grad": "W"
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
@register_pass("append_send_ops_pass")
class AppendSendOpsPass(PassBase):  # this pass is reused by several ps modes
def __init__(self):
super(AppendSendOpsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _append_send_op(self, program, union_vars, queue, is_sparse, table_id,
ps_mode):
if queue == STEP_COUNTER:
send_input_vars = []
else:
send_input_vars = [
program.global_block().vars[union_var]
for union_var in union_vars
]
dummy_output = []
if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
program.global_block().append_op(
type="send",
inputs={"X": send_input_vars},
outputs={"Out": dummy_output},
attrs={
"send_varnames": [queue],
"is_sparse": is_sparse,
"table_id": table_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return dummy_output
def _append_barrier_op(self, program, dummys, trainer_id):
program.global_block().append_op(
type="send_barrier",
inputs={"X": dummys},
outputs={"Out": []},
attrs={
"trainer_id": trainer_id,
"half_async": True,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
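# For every merged gradient in the send context, append a send op shipping it
# to the parameter server; sparse gradients are skipped outside GEO mode
# (GEO builds its own send context). In SYNC / HALF_ASYNC modes a trailing
# send_barrier op is fed by the dummy control-dependency outputs of the sends.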
attrs = pass_ctx._attrs
ps_mode = attrs['ps_mode']
if ps_mode == DistributedMode.GEO:
send_ctx = get_geo_trainer_send_context(attrs)  # GEO mode
else:
send_ctx = get_the_one_send_context(attrs)  # async, sync and other modes
dummys = []
for merged_name, send in send_ctx.items():
if send.is_sparse() and ps_mode != DistributedMode.GEO:
continue
is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse
dummys.append(
self._append_send_op(main_program,
send.origin_varnames(), merged_name,
is_sparse, send.table_id(), ps_mode))
if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
trainer_id = get_role_id(attrs['role_maker'])
self._append_barrier_op(main_program, dummys, trainer_id)
@register_pass("distributed_ops_pass")
class DistributedOpsPass(PassBase):
def __init__(self):
super(DistributedOpsPass, self).__init__()
self.w_2_table_id = {}
self.emb_size = {}
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _push_sparse_fuse(self, _program, push_sparse_ops, attrs):
if attrs['use_ps_gpu']:
return
if len(push_sparse_ops) == 0:
return
show = None
clk = None
use_entry = False
for param, ops in push_sparse_ops.items():
op_first = ops[0]
break
if op_first.has_attr("entry"):
entry = op_first.attr("entry")
entry = entry.split(':')
if len(entry) == 3 and entry[0] == 'show_click_entry':
show_var_name = entry[1]
click_var_name = entry[2]
if show_var_name in _program.global_block(
).vars and click_var_name in _program.global_block().vars:
show = _program.global_block().vars[show_var_name]
clk = _program.global_block().vars[click_var_name]
use_entry = True
else:
warnings.warn(
'ShowClickEntry configured, but cannot find show/click var, will not use'
)
if not use_entry:
print('ShowClickEntry not configured, will not use')
show = _program.global_block().create_var(
name="show",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True)
_program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': show},
attrs={
'shape': [1],
'dtype': show.dtype,
'value': 1,
})
clk = _program.global_block().create_var(
name="clk",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True)
_program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': clk},
attrs={
'shape': [1],
'dtype': clk.dtype,
'value': 0,
})
for param, ops in push_sparse_ops.items():
all_ops = _program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
_program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = _program.global_block().vars[ops[0].output("W@GRAD")[0]]
table_id = self.w_2_table_id[param]
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
_program.global_block().vars[op.input("Out@GRAD")[0]]
for op in ops
]
for idx in op_idxs[::-1]:
_program.global_block()._remove_op(idx)
_program.global_block().append_op(
type="distributed_push_sparse",
inputs={
"Ids": inputs,
'W': w,
"Outputs": outputs,
"Shows": show,
"Clicks": clk
},
outputs={"Outputs": outputs},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"size": self.emb_size[param]
})
def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx):
def dag_check_up_and_reorder(program, inputs, outputs):
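# Best-effort reordering of the block's op list: ops that produce any of the
# fused lookup's inputs are moved ahead of ops that consume any of its
# outputs, so a single distributed_lookup_table op can later be inserted at
# one index. If an op both consumes an output and produces an input, a
# warning is issued and the program is left unchanged.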
global_block = program.global_block()
min_output_index = len(global_block.ops)
max_input_index = -1
input_indexes = [0] * len(global_block.ops)
output_indexes = [0] * len(global_block.ops)
for idx, op in enumerate(global_block.ops):
for i in range(0, len(op.output_names)):
if input_indexes[idx] == 1:
break
outs = op.output(op.output_names[i])
for in_id, in_var in enumerate(inputs):
if in_var.name in outs:
input_indexes[idx] = 1
max_input_index = max(max_input_index, idx)
break
for i in range(0, len(op.input_names)):
if output_indexes[idx] == 1:
break
ins = op.input(op.input_names[i])
for out_id, out_var in enumerate(outputs):
if out_var.name in ins:
output_indexes[idx] = 1
min_output_index = min(min_output_index, idx)
for i in range(len(global_block.ops)):
if input_indexes[i] == 1 and output_indexes[i] == 1:
warnings.warn(
"unable to re-arrange dag order to combine distributed embedding ops, because an op both consumes the embedding table's output and produces ids used as input to the same embedding table"
)
return
if min_output_index < max_input_index:
move_ops = []
for i in range(min_output_index + 1, len(input_indexes)):
if input_indexes[i] == 1:
move_ops.append((global_block.ops[i], i))
for i, op in enumerate(move_ops):
queue = list()
visited = set()
queue.append(op[1])
visited.add(op[0])
start = 0
while start < len(queue):
pos = queue[start]
op = global_block.ops[pos]
op_inputs = []
for k in range(0, len(op.input_names)):
ins = op.input(op.input_names[k])
op_inputs.append(ins)
for j in range(pos - 1, min_output_index - 1, -1):
op1 = global_block.ops[j]
if op1 in visited:
continue
found = False
for k in range(0, len(op1.output_names)):
outs = op1.output(op1.output_names[k])
for t in range(len(op_inputs)):
for y in op_inputs[t]:
if y in outs:
found = True
break
if found:
break
if found:
break
if found:
if output_indexes[j] == 1:
warnings.warn(
"unable to re-arrange dag order to combine distributed embedding ops"
)
return
queue.append(j)
visited.add(global_block.ops[j])
start = start + 1
queue.sort()
for index in queue:
desc = global_block.desc._insert_op(min_output_index)
desc.copy_from(global_block.ops[index].desc)
global_block.desc._remove_op(index + 1, index + 2)
global_block.ops[index].desc = desc
insert_op = global_block.ops.pop(index)
input_state = input_indexes.pop(index)
output_state = output_indexes.pop(index)
global_block.ops.insert(min_output_index, insert_op)
input_indexes.insert(min_output_index, input_state)
output_indexes.insert(min_output_index, output_state)
min_output_index = min_output_index + 1
assert global_block.desc.op_size() == len(global_block.ops)
for i in range(len(global_block.ops)):
assert global_block.desc.op(i) == global_block.ops[i].desc
for param, ops in pull_sparse_ops.items():
all_ops = _program.global_block().ops
op_device = ""
if attrs['is_heter_ps_mode']:
op_device = ops[0].attr("op_device")
inputs = [
_program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = _program.global_block().vars[ops[0].input("W")[0]]
self.emb_size[param] = w.shape[1]
grad_name = attrs['param_name_to_grad_name'][w.name]
table_id = -1
for name, ctx in send_ctx.items():
if grad_name in ctx.origin_varnames():
table_id = ctx.table_id()
if table_id == -1:
raise ValueError(
"can not find suitable sparse table, please check")
self.w_2_table_id[param] = table_id
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
_program.global_block().vars[op.output("Out")[0]] for op in ops
]
dag_check_up_and_reorder(_program, inputs, outputs)
op_idxs = [all_ops.index(op) for op in ops]
for idx in op_idxs[::-1]:
_program.global_block()._remove_op(idx)
inputs_idxs = [-1] * len(inputs)
outputs_idxs = [len(_program.global_block().ops) + 1] * len(outputs)
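# inputs_idxs[i]: last op index that produces input i;
# outputs_idxs[j]: first op index that consumes output j.
# A single fused op can only be inserted if every producer comes before
# every consumer; otherwise fall back to one lookup op per original op.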
for idx, op in enumerate(_program.global_block().ops):
for i in range(0, len(op.output_names)):
outs = op.output(op.output_names[i])
for in_id, in_var in enumerate(inputs):
if in_var.name in outs:
inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
for i in range(0, len(op.input_names)):
ins = op.input(op.input_names[i])
for out_id, out_var in enumerate(outputs):
if out_var.name in ins:
outputs_idxs[out_id] = min(idx,
outputs_idxs[out_id])
if min(outputs_idxs) - max(inputs_idxs) >= 1:
if max(inputs_idxs) == -1:
distributed_idx = min(op_idxs)
else:
distributed_idx = max(inputs_idxs) + 1
if attrs['use_ps_gpu']:
_program.global_block()._insert_op(
index=distributed_idx,
type="pull_box_sparse",
inputs={"Ids": inputs,
'W': w},
outputs={"Out": outputs},
attrs={
"size": w.shape[1],
"is_distributed": True,
"is_sparse": True
})
else:
_program.global_block()._insert_op(
index=distributed_idx,
type="distributed_lookup_table",
inputs={"Ids": inputs,
'W': w},
outputs={"Outputs": outputs},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"lookup_table_version": op_type,
"op_device": op_device
})
else:
for i in range(len(inputs_idxs)):
distributed_idx = op_idxs[i]
_program.global_block()._insert_op(
index=distributed_idx,
type="distributed_lookup_table",
inputs={"Ids": [inputs[i]],
'W': w},
outputs={"Outputs": [outputs[i]]},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"lookup_table_version": op_type,
"op_device": op_device
})
def _get_pull_sparse_ops(self, _program, attrs):
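# Group remote-prefetch sparse lookup ops by their parameter (table), record
# the Ids they read, and collect the matching lookup_table*_grad ops so the
# forward pulls and backward pushes can be fused consistently per table.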
pull_sparse_ops = {}
pull_sparse_ids = {}
push_sparse_ops = {}
ops = {}
for op in _program.global_block().ops:
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
if attrs['is_heter_ps_mode']:
# trick for matchnet, need to modify
param_name += op.input("Ids")[0][0]
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
ids = pull_sparse_ids.get(param_name, [])
ids.append(op.input("Ids")[0])
pull_sparse_ids[param_name] = ids
for op in _program.global_block().ops:
if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
if param_name in pull_sparse_ids and op.input("Ids")[
0] in pull_sparse_ids[param_name]:
ops = push_sparse_ops.get(param_name, [])
ops.append(op)
push_sparse_ops[param_name] = ops
return pull_sparse_ops, push_sparse_ops
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
pull_sparse_ops, push_sparse_ops = self._get_pull_sparse_ops(
main_program, attrs)
send_ctx = get_the_one_send_context(
attrs, split_dense_table=attrs['is_heter_ps_mode'])
self._pull_sparse_fuse(main_program, pull_sparse_ops, attrs, send_ctx)
self._push_sparse_fuse(main_program, push_sparse_ops, attrs)
@register_pass("delete_optimizer_pass")
class DeleteOptimizesPass(PassBase):
def __init__(self):
super(DeleteOptimizesPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _delete_optimizer_op_and_vars(self, _program, optimize_ops):
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
for op in optimize_ops:
optimize_vars.extend(op.input_arg_names)
optimize_op_role_vars.extend(op.attr("op_role_var"))
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
delete_ops(_program.global_block(), optimize_ops)
for var in need_delete_optimize_vars:
if _program.global_block().has_var(var):
_program.global_block()._remove_var(var)
def _add_lr_var(self, main_program, attrs):
# Todo: hard code for pe
lr_var = attrs['origin_main_program'].global_block().vars[
"learning_rate_0"]
main_program.global_block().create_var(
name=lr_var.name,
shape=lr_var.shape,
dtype=lr_var.dtype,
type=lr_var.type,
lod_level=lr_var.lod_level,
persistable=True)
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
optimizer_ops = get_optimize_ops(main_program)
lr_ops = get_lr_ops(main_program)
optimizer_ops.extend(lr_ops)
self._delete_optimizer_op_and_vars(main_program, optimizer_ops)
if hasattr(attrs['origin_main_program'], 'lr_sheduler'):
self._add_lr_var(main_program, attrs)
@register_pass("delete_extra_optimizer_pass")
class DeleteExtraOptimizerPass(PassBase):
def __init__(self):
super(DeleteExtraOptimizerPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
for op in get_optimize_ops(main_program):
optimize_vars.extend(op.input_arg_names)
optimize_op_role_vars.extend(op.attr("op_role_var"))
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
init_ops = []
for var in need_delete_optimize_vars:
param_init_op = []
for op in startup_program.global_block().ops:
if var in op.output_arg_names:
param_init_op.append(op)
init_ops.extend(param_init_op)
delete_ops(startup_program.global_block(), init_ops)
for var in need_delete_optimize_vars:
if startup_program.global_block().has_var(var):
startup_program.global_block()._remove_var(var)
@register_pass("fake_init_ops_pass")
class FakeInitOpsPass(PassBase):
def __init__(self):
super(FakeInitOpsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'],
True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'],
False)
return list(set(dist_varnames + sparse_varnames))
def _fake_init_sparsetable(self, program, sparse_table_names):
# delete table init op
for table_name in sparse_table_names:
table_var = program.global_block().vars[table_name]
table_param_init_op = []
for op in program.global_block().ops:
if table_name in op.output_arg_names:
table_param_init_op.append(op)
init_op_num = len(table_param_init_op)
if init_op_num != 1:
raise ValueError("table init op num should be 1, now is " + str(
init_op_num))
table_init_op = table_param_init_op[0]
program.global_block().append_op(
type="fake_init",
inputs={},
outputs={"Out": table_var},
attrs={"shape": table_init_op.attr('shape')})
delete_ops(program.global_block(), table_param_init_op)
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
sparse_tables = self._get_sparse_table_names(attrs)
self._fake_init_sparsetable(startup_program, sparse_tables)
@register_pass("ps_gpu_pass")
class PsGpuPass(PassBase):
def __init__(self):
super(PsGpuPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _add_push_box_sparse_op(self, program):
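# For every pull_box_sparse op, generate its grad op desc (push_box_sparse)
# via core.get_grad_op_desc and append it to the block with the Backward
# op role.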
for op in program.global_block().ops:
if op.type != "pull_box_sparse":
continue
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(set()), [])
for op_desc in grad_op_desc:
new_op_desc = program.global_block().desc.append_op()
new_op_desc.copy_from(op_desc)
new_op_desc._set_attr(op_role_attr_name, backward)
def _remove_optimizer_var(self, program):
embedding_w = {}
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "lookup_table_grad":
for name in op.input("W"):
embedding_w[name] = 1
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
for op in get_optimize_ops(program):
for name in op.input("Param"):
if name in embedding_w:
optimize_op_role_vars.extend(op.attr("op_role_var"))
for key_name in op.input_names:
if key_name == "LearningRate":
continue
for var in op.input(key_name):
optimize_vars.append(var)
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
for name in need_delete_optimize_vars:
if program.global_block().has_var(name):
program.global_block()._remove_var(name)
def _remove_lookup_table_grad_op_and_var(self, program):
lookup_table_grad_var = {}
remove_op_index = []
remove_var = []
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "lookup_table_grad":
for name in op.output("W@GRAD"):
lookup_table_grad_var[name] = 1
remove_op_index.append(idx)
remove_var.append(name)
for name in op.input("W"):
lookup_table_grad_var[name] = 1
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "pull_box_sparse":
continue
for key_name in op.input_names:
for var in op.input(key_name):
if var in lookup_table_grad_var:
remove_op_index.append(idx)
break
remove_op_index = list(set(remove_op_index))
remove_op_index.sort(reverse=True)
for idx in remove_op_index:
program.global_block()._remove_op(idx)
for name in remove_var:
program.global_block()._remove_var(name)
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
self._add_push_box_sparse_op(main_program)
self._remove_optimizer_var(main_program)
self._remove_lookup_table_grad_op_and_var(main_program)
@register_pass("ps_transpile_pass")
class PsTranspilePass(PassBase):
def __init__(self):
super(PsTranspilePass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
t = SingleProcessMultiThread()
env = get_dist_env()
t.transpile(
startup_program=startup_program,
main_program=main_program,
rank=env["trainer_id"],
endpoints=env["trainer_endpoints"],
current_endpoint=env['current_endpoint'],
wait_port=False)
@register_pass("split_heter_worker_ops_pass")
class SplitHeterWorkerOpsPass(PassBase):
def __init__(self):
super(SplitHeterWorkerOpsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _create_heter_program(self, program, attrs, heter_program,
program_block_ops_list, heter_ops,
block_var_detail):
# This function mainly includes the following contents:
# 1. For every heter block:
# a) copy heter device op from origin program
# b) create variables which belong to heter op:
# -> if variable is persistable, clone it in global_scope
# -> if variable is temp, create it in heter block
# c) create communication-related ops as follows:
# joint_var.0_1 -> slice -> reshape -> origin_var
# origin_var -> origin_program
# reshape -> concat -> joint_var.1_2
# d) copy send op from origin program for var@grad located in the current heter block
# e) re-check every op in the current block whose device is not the current heter device
# 2. Create send op for step counter in last heter-block
# 3. Create Listen&Serv OP and Send&Recv OP for distributed training
# 4. update CompileTimeStrategy for heter_program
optimizer_block = []
grad_to_block_id = []
send_grad_var_list = []
pre_block_idx = heter_program.num_blocks - 1
role_maker = attrs['role_maker']
current_device = role_maker._heter_device_type().lower()
stage_id = int(role_maker._get_stage_id())
heter_block_ops_forward = program_block_ops_list[stage_id - 1][
"forward"]
heter_block_ops_backward = program_block_ops_list[stage_id - 1][
"backward"]
heter_block = heter_program._create_block(pre_block_idx)
optimizer_block.append(heter_block)
for _, op in enumerate(heter_block_ops_forward):
block_append_op(heter_program, program, heter_block, op)
entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
add_vars_by_var_list(exit_vars, program, heter_program, heter_block)
first_op_index_fp = len(heter_block.ops)
if stage_id < len(program_block_ops_list):
heter_block_bp = heter_program._create_block(pre_block_idx)
optimizer_block.append(heter_block_bp)
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block_bp, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"]
add_vars_by_var_list(bp_entrance_vars, program, heter_program,
heter_block_bp)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(bp_exit_vars, program, heter_program,
heter_block_bp)
backward_comm_info = get_communicate_var_info(
program, stage_id, bp_entrance_vars, type="backward")
grad_to_block_id.append(backward_comm_info["block_input_var_name"] +
":" + str(heter_block_bp.idx))
else:
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"]
add_vars_by_var_list(bp_entrance_vars, program, heter_program,
heter_block)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(bp_exit_vars, program, heter_program,
heter_block)
heter_block_bp = heter_block
forward_comm_info = get_communicate_var_info(
program, stage_id, entrance_vars, type="forward")
grad_to_block_id.append(forward_comm_info["block_input_var_name"] + ":"
+ str(heter_block.idx))
first_op_index_bp = len(heter_block_bp.ops)
if stage_id <= len(block_var_detail) - 1:
static_var = insert_communicate_op(program, role_maker, heter_block,
stage_id, first_op_index_fp,
block_var_detail, current_device)
static_var_bp = insert_communicate_op(
program, role_maker, heter_block_bp, stage_id, first_op_index_bp,
block_var_detail, current_device, False)
# add send op
send_grad_var_list = add_heter_send_op(program, heter_program,
heter_block_bp,
block_var_detail[stage_id - 1])
# add step counter
send_input_vars = []
dummy_output = []
pserver_endpoints = get_ps_endpoints(role_maker)
attrs = {
"message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block,
# runtime attribute
"endpoint": get_heter_worker_endpoint(role_maker),
"fanin": len(get_previous_stage_trainers(role_maker)),
"pserver_id": get_role_id(role_maker),
"distributed_mode": attrs['ps_mode'],
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}
# append the listen_and_serv op
heter_program.global_block().append_op(
type="heter_listen_and_serv",
inputs={'X': []},
outputs={},
attrs=attrs)
# TODO check heter program
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
"""
split heter worker program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create heter worker program, add listen&serv op
"""
attrs = pass_ctx._attrs
default_device = "cpu"
program, heter_ops, _, program_block_ops = find_heter_ops(
main_program, default_device)
if len(heter_ops) == 0:
warnings.warn(
"Currently running in heter parameter server mode, but no op runs on a heterogeneous device. Please check your code."
)
main_program = program
return
program_block_ops = union_forward_gradient_op(program_block_ops)
block_vars_detail = find_block_joints(program, program_block_ops,
heter_ops)
heter_program = framework.Program()
self._create_heter_program(program, attrs, heter_program,
program_block_ops, heter_ops,
block_vars_detail)
main_program = heter_program
@register_pass("split_trainer_ops_pass")
class SplitTrainerOpsPass(PassBase):
def __init__(self):
super(SplitTrainerOpsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _create_trainer_program(self, program, origin_program, attrs,
program_block_ops_list, block_var_detail):
# This function mainly includes the following contents:
# 1. For every heter block in origin program
# a) delete heter op and related variables
# b) add send&recv op
# c) add communicate ops as follows:
# origin_var -> reshape -> concat -> joint_var.0_1
# send&recv op(send joint_var.0_1; recv joint_var.1_2)
# joint_var.1_2 -> slice -> reshape -> origin_var
# d) remove send op which related var@grad is not in trainer program
# 2. check every op's device
static_var = []
for heter_block_index in range(1, len(program_block_ops_list)):
ops_list = program_block_ops_list[heter_block_index][
"forward"] + program_block_ops_list[heter_block_index][
"backward"]
static_var += replace_ops_by_communicate_op(
program, attrs, heter_block_index, ops_list, block_var_detail)
remove_trainer_send_op(program, attrs, heter_block_index,
block_var_detail)
optimizer_block = []
grad_to_block_id = []
bp_ops_list = program_block_ops_list[0]["backward"]
delete_same_ops(program.global_block(), bp_ops_list)
delete_trainer_useless_var(attrs, program, static_var)
backward_block = create_backward_block(program, origin_program, attrs,
bp_ops_list, block_var_detail)
bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
backward_comm_info = get_communicate_var_info(
origin_program, 1, bp_entrance_vars, type="backward")
grad_to_block_id.append(backward_comm_info["block_input_var_name"] + ":"
+ str(backward_block.idx))
optimizer_block.append(backward_block)
role_maker = attrs['role_maker']
attrs = {
"message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block,
# runtime attribute
"endpoint":
get_trainer_endpoint(role_maker), ## get trainer endpoint
"fanin": 0, ## get heter worker
"pserver_id": get_role_id(role_maker),
"distributed_mode": attrs['ps_mode'],
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}
# append the listen_and_serv op
program.global_block()._insert_op(
index=0,
type="heter_listen_and_serv",
inputs={'X': []},
outputs={},
attrs=attrs)
## TODO add check for bp block
#check_op_device(program.global_block(), DEFAULT_DEVICE)
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
"""
split cpu-trainer program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create cpu-trainer program, add send&recv op
"""
attrs = pass_ctx._attrs
default_device_ = 'cpu'
program, heter_ops, default_ops, program_block_ops = find_heter_ops(
main_program, default_device_)
program_block_ops = union_forward_gradient_op(program_block_ops)
block_vars_detail = find_block_joints(program, program_block_ops,
heter_ops)
trainer_program = program.clone()
self._create_trainer_program(trainer_program, program, attrs,
program_block_ops, block_vars_detail)
main_program = trainer_program
@register_pass("set_heter_pipeline_opt_pass")
class SetHeterPipelineOptPass(PassBase):
def __init__(self):
super(SetHeterPipelineOptPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
role_maker = attrs['role_maker']
num_microbatches = attrs['user_defined_strategy'].pipeline_configs[
'accumulate_steps']
attrs['origin_startup_program']._heter_pipeline_opt = {
"startup_program": startup_program,
"pipeline_stage": int(role_maker._get_stage_id()) - 1,
"heter_place": role_maker._heter_device(),
}
attrs['origin_main_program']._heter_pipeline_opt = {
"trainer": "HeterPipelineTrainer",
"device_worker": "HeterSection",
"trainers":
role_maker._get_stage_trainers(), ## trainer num in each stage
"trainer_id": int(role_maker._role_id()),
"pipeline_stage": int(role_maker._get_stage_id()) - 1,
"num_pipeline_stages": int(role_maker._get_num_stage()),
"section_program": main_program,
"num_microbatches": num_microbatches,
"heter_place": role_maker._heter_device(),
}
@@ -11,3 +11,1349 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import os
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle.fluid import core
from .utils.public import *
from paddle.fluid.framework import Program
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready
from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format
__all__ = []
def conv_indent(indent):
return "".join([" "] * indent)
PSERVER_SAVE_SUFFIX = ".shard"
def parse_table_class(varname, o_main_program):
for op in o_main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
param_name = op.input("W")[0]
if param_name == varname and op.type in ("lookup_table", "lookup_table_v2"):
if op.has_attr('table_class') and op.attr("table_class") != "none":
return op.attr('table_class')
else:
return "MemorySparseTable"
def get_default_accessor_proto(accessor, varname, o_main_program):
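# Fill any unset fields of the sparse table accessor proto with defaults
# derived from the embedding width of `varname` (fea_dim = dim + 2 and
# embedx_dim = dim - 1, matching check_embedding_dim below), plus default
# CTR accessor and SGD-rule hyper-parameters.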
embedding_dim = 0
for var in o_main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
if not accessor.HasField("accessor_class"):
accessor.accessor_class = "CtrCommonAccessor"
if not accessor.HasField("fea_dim"):
accessor.fea_dim = embedding_dim + 2
if not accessor.HasField("embedx_dim"):
accessor.embedx_dim = embedding_dim - 1
if not accessor.HasField("embedx_threshold"):
accessor.embedx_threshold = 0
ctr_accessor_param = accessor.ctr_accessor_param
if not ctr_accessor_param.HasField("nonclk_coeff"):
ctr_accessor_param.nonclk_coeff = 0.1
if not ctr_accessor_param.HasField("click_coeff"):
ctr_accessor_param.click_coeff = 1.0
if not ctr_accessor_param.HasField("base_threshold"):
ctr_accessor_param.base_threshold = 0
if not ctr_accessor_param.HasField("delta_threshold"):
ctr_accessor_param.delta_threshold = 0
if not ctr_accessor_param.HasField("delta_keep_days"):
ctr_accessor_param.delta_keep_days = 16
if not ctr_accessor_param.HasField("show_click_decay_rate"):
ctr_accessor_param.show_click_decay_rate = 1
if not ctr_accessor_param.HasField("delete_threshold"):
ctr_accessor_param.delete_threshold = 0
if not ctr_accessor_param.HasField("delete_after_unseen_days"):
ctr_accessor_param.delete_after_unseen_days = 30
if not ctr_accessor_param.HasField("ssd_unseenday_threshold"):
ctr_accessor_param.ssd_unseenday_threshold = 1
for sgd_param in [accessor.embed_sgd_param, accessor.embedx_sgd_param]:
if not sgd_param.HasField("name"):
sgd_param.name = "SparseAdaGradSGDRule"
if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule":
if not sgd_param.adagrad.HasField("learning_rate"):
sgd_param.adagrad.learning_rate = 0.05
if not sgd_param.adagrad.HasField("initial_g2sum"):
sgd_param.adagrad.initial_g2sum = 3.0
if not sgd_param.adagrad.HasField("initial_range"):
sgd_param.adagrad.initial_range = 0.0001
if len(sgd_param.adagrad.weight_bounds) == 0:
sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseNaiveSGDRule":
if not sgd_param.naive.HasField("learning_rate"):
sgd_param.naive.learning_rate = 0.05
if not sgd_param.naive.HasField("initial_range"):
sgd_param.naive.initial_range = 0.0001
if len(sgd_param.naive.weight_bounds) == 0:
sgd_param.naive.weight_bounds.extend([-10.0, 10.0])
if sgd_param.name == "SparseAdamSGDRule":
if not sgd_param.adam.HasField("learning_rate"):
sgd_param.adam.learning_rate = 0.001
if not sgd_param.adam.HasField("initial_range"):
sgd_param.adam.initial_range = 0.0001
if not sgd_param.adam.HasField("beta1_decay_rate"):
sgd_param.adam.beta1_decay_rate = 0.9
if not sgd_param.adam.HasField("beta2_decay_rate"):
sgd_param.adam.beta2_decay_rate = 0.999
if not sgd_param.adam.HasField("ada_epsilon"):
sgd_param.adam.ada_epsilon = 1e-08
if len(sgd_param.adam.weight_bounds) == 0:
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program):
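# Sanity check: a (possibly user-configured) accessor proto must stay
# consistent with the embedding variable's second dimension.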
embedding_dim = 0
for var in o_main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
fea_dim = accessor.fea_dim
if fea_dim != embedding_dim + 2:
raise ValueError(
"The fea_dim is wrong, it should be sparse_embedding_dim + 2: {}, but got {}".
format(embedding_dim + 2, fea_dim))
embedx_dim = accessor.embedx_dim
if embedx_dim != embedding_dim - 1:
raise ValueError(
"The embedx_dim is wrong, it should be sparse_embedding_dim - 1: {}, but got {}".
format(embedding_dim - 1, embedx_dim))
class Accessor:
def __init__(self):
self.accessor_class = ""
self.optimizer = None
self.feature_dim = -1
self.embedding_dim = -1
def to_string(self, indent):
accessor_str = "{}accessor {{{}\n{}}}"
attrs = ""
attrs += "accessor_class: \"{}\" ".format(self.accessor_class)
attrs += "fea_dim: {} ".format(self.feature_dim)
attrs += "embedx_dim: {} ".format(self.embedding_dim)
attrs += "\n"
if self.optimizer is not None:
attrs += self.optimizer.to_string(indent)
return accessor_str.format(
conv_indent(indent), attrs, conv_indent(indent))
class CommonAccessor:
def __init__(self):
self.accessor_class = ""
self.table_name = None
self.entry = None
self.attrs = []
self.params = []
self.dims = []
self.trainer_num = 0
self.sync = "false"
self.initializers = []
self.opt_input_map = {}
self.opt_attr_map = {}
self.opt_init_map = {}
self.define_optimize_map()
def define_optimize_map(self):
opt_input_map = {}
opt_input_map["sgd"] = [("Param", None), ("LearningRate", 1)]
opt_input_map["adam"] = [("Param", None), ("Moment1", None),
("Moment2", None), ("Beta1Pow", 1),
("Beta2Pow", 1), ("LearningRate", 1)]
opt_input_map["adam_d2sum"] = [
("Param", None), ("D2Sum", None), ("G2Sum", None), ("Moment", None),
("MomentDecayRate", 1), ("AdaDecayRate", 1), ("AdaEpsilon", 1),
("LearningRate", 1)
]
opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)]
opt_attr_map = {}
opt_attr_map["sgd"] = []
opt_attr_map["sum"] = []
opt_attr_map["naive_adagrad"] = []
opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
opt_init_map["fill_constant"] = ["value"]
opt_init_map["uniform_random"] = ["seed", "min", "max"]
opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"]
self.opt_attr_map = opt_attr_map
self.opt_input_map = opt_input_map
self.opt_init_map = opt_init_map
def parse_entry(self, varname, o_main_program):
for op in o_main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
param_name = op.input("W")[0]
if param_name == varname and op.type == "lookup_table":
self.entry = op.attr('entry')
break
if param_name == varname and op.type == "lookup_table_v2":
self.entry = "none"
break
def get_shard(self, total_dim, shard_num, pserver_id):
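# Static row sharding: every pserver owns a contiguous block of
# blocksize = int(total_dim / shard_num + 1) rows and the last non-empty
# shard takes the remainder, e.g. total_dim=10, shard_num=3 gives
# blocksize=4 and shards of 4, 4 and 2 rows.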
blocksize = int(total_dim / shard_num + 1)
if blocksize * (pserver_id + 1) <= total_dim:
return blocksize
else:
if blocksize * pserver_id < total_dim:
return total_dim - blocksize * pserver_id
else:
return 0
def get_initializer_attr(self, value_name, o_startup_program):
l_in = "&"
attr_str = ""
origin_var_name = value_name
for op in o_startup_program.global_block().ops:
if op.type in self.opt_init_map.keys(
) and origin_var_name == op.output("Out")[0]:
init_attr = [op.type]
for attr in self.opt_init_map[op.type]:
init_attr.append(str(op.attr(attr)))
attr_str = l_in.join(init_attr)
break
return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context,
adam_d2sum):
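# Map the optimizer op that updates the parameter behind `grad_name` onto the
# table's common accessor: pick the accessor class ("sum" for GEO,
# naive_adagrad reported as "sgd" for ps-gpu sparse tables, "adam_d2sum" when
# requested, otherwise the optimizer op type) and derive per-slot params,
# dims (sharded across pservers for dense tables) and initializer strings.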
main_program = context['origin_main_program']
startup_program = context['startup_main_program']
pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program)
oop = None
for op in optimizer_ops:
if ("Param" in op.input_names) and (
op.input("Param")[0] ==
context['grad_name_to_param_name'][grad_name]):
oop = op
break
if oop is None:
raise ValueError("can not find optimizer for {}".format(grad_name))
params = []
dims = []
attrs = []
initializers = []
self.trainer_num = get_trainers(context['role_maker'])
if oop.type != 'adam' and adam_d2sum:
print('optimizer is not adam, setting adam_d2sum to False')
adam_d2sum = False
print("adam_d2sum:", adam_d2sum)
if context['ps_mode'] == DistributedMode.GEO:
param_varnames = self.opt_input_map["sum"]
attr_varnames = self.opt_attr_map["sum"]
self.accessor_class = "sum"
elif context['use_ps_gpu'] and is_sparse:
param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd"
elif adam_d2sum:
param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum"
else:
param_varnames = self.opt_input_map[oop.type]
attr_varnames = self.opt_attr_map[oop.type]
self.accessor_class = oop.type
for (formal_name, shape) in param_varnames:
params.append(formal_name)
if self.accessor_class == "adam_d2sum":
#for dims
if shape is None:
if is_sparse:
shape = total_dims
else:
shape = self.get_shard(total_dims, pserver_num,
pserver_id)
dims.append(shape)
#for initializers
if formal_name == "Param" or formal_name == "LearningRate":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
#TODO: for dense learning_rate, can be different from sparse lr
if formal_name == "LearningRate" and param.name != "learning_rate_0":
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "MomentDecayRate":
initializer = "fill_constant&0.99"
elif formal_name == "AdaDecayRate":
initializer = "fill_constant&0.9999"
elif formal_name == "AdaEpsilon":
initializer = "fill_constant&1.0e-8"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else:
if formal_name == "G2Sum":
dims.append(1)
initializer = "fill_constant&0"
initializers.append(initializer)
else:
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
if formal_name == "LearningRate" and param.name != "learning_rate_0":
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
if shape is None:
if is_sparse:
shape = total_dims
else:
shape = self.get_shard(total_dims, pserver_num,
pserver_id)
dims.append(shape)
initializer = self.get_initializer_attr(param.name,
startup_program)
initializers.append(initializer)
for (attr_varname, type_) in attr_varnames:
value = oop.attr(attr_varname)
attrs.append("&".join([attr_varname, type_, str(value)]))
self.params = params
self.dims = dims
self.initializers = initializers
self.attrs = attrs
def to_string(self, indent):
accessor_str = "{}common {{{}\n{}}}"
attrs = ""
attrs += "name: \"{}\" ".format(self.accessor_class)
if self.table_name:
attrs += "table_name: \"{}\" ".format(self.table_name)
if self.entry:
attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync)
for param in self.params:
attrs += "params: \"{}\" ".format(param)
for dim in self.dims:
attrs += "dims: {} ".format(dim)
for initializer in self.initializers:
attrs += "initializers: \"{}\" ".format(initializer)
attrs += "\n"
return accessor_str.format(
conv_indent(indent), attrs, conv_indent(indent))
class Tensor:
def __init__(self):
self.main_program_id = None
self.startup_program_id = None
self.feed_var_name = None
self.fetch_var_name = None
self.tensor_table_class = False
def to_string(self, indent):
program_str = "{}tensor {{{}\n{}}}"
attrs = ""
attrs += "feed_var_name: \"{}\" ".format(str(self.feed_var_name))
attrs += "fetch_var_name: \"{}\" ".format(str(self.fetch_var_name))
attrs += "startup_program_id: {} ".format(str(self.startup_program_id))
attrs += "main_program_id: {} ".format(str(self.main_program_id))
attrs += "tensor_table_class: \"{}\" ".format(
str(self.tensor_table_class))
attrs += "\n"
return program_str.format(
conv_indent(indent), attrs, conv_indent(indent))
class Table:
def __init__(self):
self.id = -1
self.table_class = None
self.shard_num = -1
self.type = None
self.accessor = None
self.common = None
self.tensor = None
self.accessor_proto = None
def to_string(self, indent):
# if self.id == 1:
# proto_txt = ''
# with open('./sparse_table.prototxt') as f:
# proto_txt = f.read()
# return proto_txt
table_str = "{}downpour_table_param {{{}\n{}}}"
attrs = ""
attrs += "table_id: {} ".format(self.id)
attrs += "table_class: \"{}\" ".format(self.table_class)
attrs += "shard_num: {} ".format(self.shard_num)
attrs += "type: {}".format(self.type)
attrs += "\n"
indent += 2
if self.accessor_proto is not None:
accessor_str = "{}accessor {{{}\n{}}}"
accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n"
return table_str.format(
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
attrs += self.accessor.to_string(indent)
attrs += "\n"
if self.tensor is not None:
attrs += self.tensor.to_string(indent)
attrs += "\n"
if self.common is not None:
attrs += self.common.to_string(indent)
attrs += "\n"
return table_str.format(conv_indent(indent), attrs, conv_indent(indent))
class Service:
def __init__(self):
self.server_class = "BrpcPsServer"
self.client_class = "BrpcPsClient"
self.service_class = "BrpcPsService"
self.start_server_port = 0
self.server_thread_num = 12
def to_string(self, indent):
service_str = "{}service_param {{{}\n{}}}"
attrs = ""
attrs += "server_class: \"{}\" ".format(self.server_class)
attrs += "client_class: \"{}\" ".format(self.client_class)
attrs += "service_class: \"{}\" ".format(self.service_class)
attrs += "start_server_port: {} ".format(self.start_server_port)
attrs += "server_thread_num: {} ".format(self.server_thread_num)
return service_str.format(
conv_indent(indent), attrs, conv_indent(indent))
class DownpourServer:
def __init__(self):
self.service = None
self.tables = []
def set_service_param(self, service):
self.service = service
def append_tables(self, table):
if not isinstance(table, Table):
raise ValueError("only support instance Table")
self.tables.append(table)
def to_string(self, indent):
server_str = "{}downpour_server_param {{{}\n{}}}"
table_strs = ""
indent += 2
table_strs += "\n"
table_strs += self.service.to_string(indent)
for table in self.tables:
table_strs += "\n"
table_strs += table.to_string(indent)
return server_str.format(
conv_indent(indent), table_strs, conv_indent(indent))
class Server:
def __init__(self):
self.servers = []
def add_server(self, server):
if not isinstance(server, DownpourServer):
raise ValueError("only support instance DownpourServer")
self.servers.append(server)
def __str__(self):
server_str = "server_param {{{}\n}}"
indent = 2
servers_str = ""
for server in self.servers:
servers_str += "\n"
servers_str += server.to_string(indent)
return server_str.format(servers_str)
class DownpourWorker:
def __init__(self):
self.tables = []
def append_tables(self, table):
if not isinstance(table, Table):
raise ValueError("only support instance Table")
self.tables.append(table)
def to_string(self, indent):
worker_str = "{}downpour_worker_param {{{}\n{}}}"
table_strs = ""
indent += 2
for table in self.tables:
table_strs += "\n"
table_strs += table.to_string(indent)
return worker_str.format(
conv_indent(indent), table_strs, conv_indent(indent))
class Worker:
def __init__(self):
self.workers = []
def add_worker(self, worker):
if not isinstance(worker, DownpourWorker):
raise ValueError("only support instance DownpourWorker")
self.workers.append(worker)
def __str__(self):
worker_str = "worker_param {{{}\n}}"
indent = 2
workers_str = ""
for worker in self.workers:
workers_str += "\n"
workers_str += worker.to_string(indent)
return worker_str.format(workers_str)
class fsClient:
def __init__(self, proto):
self.proto = proto
self.uri = proto.uri
self.user = proto.user
self.passwd = proto.passwd
self.hadoop_bin = proto.hadoop_bin
def to_string(self):
proto_txt = text_format.MessageToString(self.proto)
if proto_txt:
fs_str = "fs_client_param {{\n{}}}"
return fs_str.format(proto_txt)
else:
return ""
class TheOnePSRuntime(RuntimeBase):
def __init__(self):
super(TheOnePSRuntime, self).__init__()
self._communicator = None
self._server = None
self._worker = fluid.core.DistFleetWrapper()
self._server_sub_program = []
self._heter_client = None
def _set_basic_info(self, context):
self.context = context
self.role_maker = context["role_maker"]
self.origin_main_program = context["origin_main_program"]
self.context[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
self.is_heter_ps_mode = self.context['is_heter_ps_mode']
self.context['trainer'] = TrainerRuntimeConfig(context[
'valid_strategy'])
self.context['ps_mode'] = self.context['trainer'].mode
self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu
self.is_sync = True if self.context[
'ps_mode'] == DistributedMode.SYNC else False
self.context['grad_name_to_param_name'] = {}
def _init_worker(self):
worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync)
server = self._get_fleet_proto(is_server=True, is_sync=self.is_sync)
if self.context['use_ps_gpu']:
main_program = self.context['loss'].block.program
if not main_program._fleet_opt:
main_program._fleet_opt = {}
main_program._fleet_opt["use_ps_gpu"] = True
gpus_env = os.getenv("FLAGS_selected_gpus")
main_program._fleet_opt[
"worker_places"] = [int(s) for s in gpus_env.split(",")]
def sync_strategy_envs():
kwargs = {}
kwargs[
"pserver_endpoints"] = self.role_maker._get_pserver_endpoints()
kwargs["trainer_id"] = self.role_maker._worker_index()
return kwargs
proto_txt = str(worker) + "\n" + str(server)
with open('proto_txt', 'w') as f:
f.write(proto_txt)
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("worker: \n{}".format(proto_txt))
endpoints = get_ps_endpoints(self.role_maker)
string_hosts = []
for idx, ep in enumerate(endpoints):
host, port = ep.split(":")
pshost = fluid.core.PSHost(host, int(port), idx)
string_hosts.append(pshost.serialize_to_string())
dense_map = get_the_one_recv_context(
self.context, split_dense_table=self.is_heter_ps_mode)
send_ctx = get_the_one_send_context(
self.context,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=self.is_heter_ps_mode,
ep_list=endpoints)
trainer_config = self.context['trainer']
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("worker: \n{}".format(proto_txt))
print("communicator send_ctx:")
for key in send_ctx:
print("{}: {}".format(key, send_ctx[key]))
for key in dense_map:
print("{}: {}".format(key, dense_map[key]))
kwargs = {}
kwargs['need_global_step'] = "0"
kwargs["trainer_id"] = self.role_maker._role_id()
kwargs["trainers"] = self.role_maker._worker_num()
for table in server.servers[0].tables:
if table.table_class == "BarrierTable":
kwargs["barrier_table_id"] = table.id
break
if self.context['ps_mode'] == DistributedMode.SYNC:
sync_kwargs = sync_strategy_envs()
kwargs.update(sync_kwargs)
self._communicator = Communicator(
trainer_config.mode, kwargs,
trainer_config.get_communicator_flags())
self._communicator.init_with_ctx(send_ctx, dense_map, proto_txt,
string_hosts, fluid.global_scope())
fleet.util.barrier()
info = self._communicator.get_client_info()
if isinstance(info, list) and len(info) > 0:
all_info = self.role_maker._all_gather(info[0])
# for unittest
if not isinstance(all_info, list):
warnings.warn("gloo may not initialize correctly")
all_info = [all_info]
self._communicator.set_clients(all_info)
self._communicator.create_client_to_client_connection()
print('create c2c connection done')
else:
print('cannot create c2c connection')
dist_strategy = self.context["valid_strategy"]
is_test = bool(int(os.getenv("TEST_MODE", "0")))
if self.role_maker._is_first_worker() and self.is_heter_ps_mode:
# for ps-heter mode load all parameters on first_worker
init_params = get_the_one_recv_context(
self.context, split_dense_table=True, use_origin_program=True)
else:
init_params = dense_map
if not is_test:
self._communicator.init_params(init_params)
fleet.util.barrier()
self._communicator.pull_dense(init_params)
fleet.util.barrier()
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
if launch_barrier and launch_barrier_flag:
# for trainer wait server ready
wait_server_ready(self.role_maker._get_pserver_endpoints())
if self.is_heter_ps_mode and self.role_maker._get_next_trainers(
) != []:
wait_server_ready(self.role_maker._get_next_trainers())
if self.is_heter_ps_mode:
previous_trainers = []
if self.role_maker._get_previous_trainers() != []:
previous_trainers = self.role_maker._get_previous_trainers()
next_trainers = []
if self.role_maker._get_next_trainers() != []:
next_trainers = self.role_maker._get_next_trainers()
self._heter_client = HeterClient(next_trainers,
previous_trainers,
self.role_maker._role_id())
def _push_sparse_param(self,
var_name,
table_id=-1,
scope=fluid.global_scope()):
self._communicator.push_sparse_param(var_name, table_id, scope)
def _get_executor(self):
executor = fluid.Executor(fluid.CPUPlace())
if self.is_heter_ps_mode:
if self.role_maker._is_heter_worker():
heter_device_type = self.role_maker._heter_device_type().upper()
if heter_device_type not in ["GPU", "XPU", "CPU"]:
raise ValueError("Heter Worker Not Support Device {}".
format(device_type))
if heter_device_type == "GPU":
executor = Executor(
fluid.CUDAPlace(
int(os.getenv("FLAGS_selected_gpus", "0"))))
elif heter_device_type == "XPU":
executor = Executor(
fluid.XPUPlace(
int(os.getenv("FLAGS_selected_xpus", "0"))))
return executor
def _get_fleet_proto(self, is_server, is_sync, **kwargs):
def _build_merge_accessor(ctx):
accessor = Accessor()
accessor.accessor_class = "CommMergeAccessor"
accessor.optimizer = None
if ctx.is_sparse():
accessor.feature_dim = ctx.sections()[0]
accessor.embedding_dim = ctx.sections()[1]
else:
accessor.feature_dim = ctx.sections()[0]
accessor.embedding_dim = 1
return accessor
def _build_barrier_table(idx):
table = Table()
table.id = idx
table.type = "PS_OTHER_TABLE"
table.table_class = "BarrierTable"
table.shard_num = 256
accessor = Accessor()
accessor.accessor_class = "CommMergeAccessor"
accessor.optimizer = None
accessor.feature_dim = 0
accessor.embedding_dim = 0
table.accessor = accessor
common = CommonAccessor()
common.table_name = "barrier_table"
trainer_num = get_trainers(self.context['role_maker'])
if self.is_heter_ps_mode:
trainer_num += len(self.role_maker._get_heter_worker_endpoints(
))
common.trainer_num = trainer_num
common.attrs = ""
common.dims = []
common.params = []
table.common = common
return table
def _build_tensor_table(idx, tensor_dict):
table = Table()
table.id = idx
table.type = "PS_OTHER_TABLE"
table.table_class = tensor_dict["tensor_table_class"]
table.shard_num = 256
accessor = Accessor()
accessor.accessor_class = "CommMergeAccessor"
accessor.optimizer = None
accessor.feature_dim = 0
accessor.embedding_dim = 0
table.accessor = accessor
common = CommonAccessor()
common.table_name = tensor_dict["feed_var_name"]
common.trainer_num = get_trainers(self.role_maker)
common.attrs = ""
common.dims = []
common.params = []
table.common = common
tensor = Tensor()
tensor.main_program_id = tensor_dict["main_program_id"]
tensor.startup_program_id = tensor_dict["startup_program_id"]
tensor.feed_var_name = tensor_dict["feed_var_name"]
tensor.fetch_var_name = tensor_dict["fetch_var_name"]
tensor.tensor_table_class = tensor_dict["tensor_table_class"]
table.tensor = tensor
return table
def _add_tensor_table(tables):
tensor_table_dict = {}
program_idx = 0
for table_name in tensor_table_dict:
if tensor_table_dict[table_name]["startup_program"] != None:
tensor_table_dict[table_name][
"startup_program_id"] = program_idx
self._server_sub_program.append(tensor_table_dict[
table_name]["startup_program"].desc)
program_idx += 1
if tensor_table_dict[table_name]["main_program"] != None:
tensor_table_dict[table_name][
"main_program_id"] = program_idx
self._server_sub_program.append(tensor_table_dict[
table_name]["main_program"].desc)
program_idx += 1
# Todo: Hard code for lr_decay table apply table id
new_table = _build_tensor_table(
len(tables), tensor_table_dict[table_name])
tables.append(new_table)
return tables
def _get_tables():
send_ctx = get_the_one_send_context(
self.context,
use_origin_program=True,
split_dense_table=self.is_heter_ps_mode)
tables = []
for idx, (name, ctx) in enumerate(send_ctx.items()):
print(" wxm python test send_ctx.items-->", idx, (name, ctx))
if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1:
continue
table = Table()
table.id = ctx.table_id()
common = CommonAccessor()
if ctx.is_sparse():
table.type = "PS_SPARSE_TABLE"
table.shard_num = 256
common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]]
                    if self.context['ps_mode'] == DistributedMode.GEO:
table.table_class = "SparseGeoTable"
else:
all_table_proto = self.context[
"user_defined_strategy"].sparse_table_configs
table_proto = all_table_proto.add()
for proto in all_table_proto:
if proto.table_name == common.table_name:
table_proto = proto
break
if table_proto.HasField("table_class"):
table.table_class = table_proto.table_class
else:
table.table_class = parse_table_class(
common.table_name, self.origin_main_program)
if table.table_class != 'MemorySparseTable':
table.table_class = 'MemorySparseTable'
warnings.warn(
"The PS mode must use MemorySparseTable.")
if table_proto.HasField("shard_num"):
table.shard_num = table_proto.shard_num
else:
table.shard_num = 1000
warnings.warn(
"The shard_num of sparse table is not set, use default value 1000."
)
if table_proto.accessor.ByteSize() == 0:
warnings.warn(
"The accessor of sparse table is not set, use default value."
)
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
check_embedding_dim(table_proto.accessor,
common.table_name,
self.origin_main_program)
table.accessor_proto = text_format.MessageToString(
table_proto.accessor)
else:
table.type = "PS_DENSE_TABLE"
table.table_class = "CommonDenseTable"
table.shard_num = 256
common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0],
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0], self.context,
adam_d2sum)
if ctx.is_sparse():
common.parse_entry(common.table_name,
self.origin_main_program)
if is_sync:
common.sync = "true"
else:
common.sync = "false"
table.common = common
if table.table_class != 'MemorySparseTable':
accessor = _build_merge_accessor(ctx)
table.accessor = accessor
tables.append(table)
tensor_table_dict = {}
if len(tensor_table_dict) > 0:
tables = _add_tensor_table(tables)
else:
                empty_program = Program()
                self._server_sub_program.append(empty_program.desc)
barrier_table = _build_barrier_table(len(tables))
tables.append(barrier_table)
return tables
if is_server:
server = Server()
downpour_server = DownpourServer()
service = Service()
dist_strategy = self.context["valid_strategy"]
use_ps_gpu = dist_strategy.a_sync_configs["use_ps_gpu"]
if use_ps_gpu:
service.server_class = "PsLocalServer"
service.client_class = "PsLocalClient"
downpour_server.set_service_param(service)
tables = _get_tables()
downpour_server.tables = tables
server.add_server(downpour_server)
return server
else:
worker = Worker()
downpour_worker = DownpourWorker()
tables = _get_tables()
downpour_worker.tables = tables
worker.add_worker(downpour_worker)
return worker
def _init_server(self, dirname=None, var_names=None, **kwargs):
role_id = get_role_id(self.role_maker)
endpoints = get_ps_endpoints(self.role_maker)
trainers = get_trainers(self.role_maker)
if self.is_heter_ps_mode:
trainers += len(self.role_maker._get_heter_worker_endpoints())
server = self._get_fleet_proto(is_server=True, is_sync=self.is_sync)
proto_txt = str(server)
fs_client = fsClient(self.context["user_defined_strategy"]
.fs_client_param)
proto_txt = proto_txt + "\n" + fs_client.to_string()
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("server: \n{}".format(proto_txt))
string_hosts = []
for idx, ep in enumerate(endpoints):
host, port = ep.split(":")
pshost = fluid.core.PSHost(host, int(port), idx)
string_hosts.append(pshost.serialize_to_string())
self._server = fluid.core.DistFleetWrapper()
self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_program, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_program, False)
distributed_varnames = dist_varnames + sparse_varnames
if var_names is None:
load_varnames = distributed_varnames
else:
for var_name in var_names:
if var_name not in distributed_varnames:
raise ValueError(
"fleet.init server can only load sparse variables in {}".
format(distributed_varnames))
load_varnames = var_names
if dirname is None or not load_varnames:
return
sparse_table_maps = {}
for table in server.servers[0].tables:
if table.type == "PS_SPARSE_TABLE" and table.common is not None:
sparse_table_maps[table.common.table_name] = table.id
dirname = os.path.normpath(dirname)
pserver_id = self.role_maker._role_id()
for var_name in load_varnames:
table_id = sparse_table_maps[var_name]
self._server.load_sparse(dirname, "0", table_id)
def _run_server(self):
ep = get_ps_endpoint(self.role_maker)
host, port = ep.split(":")
self._server.run_server(host, int(port))
def _stop_worker(self):
self._communicator.stop()
if self.is_heter_ps_mode:
            assert self._heter_client is not None, "heter client should not be None in heterps mode"
self._heter_client.stop()
@staticmethod
def __exclude_vars(exclude_var_names=[]):
def is_valid(var):
if var.name in exclude_var_names:
return False
origin_varname, _, _ = _get_varname_parts(var.name)
if origin_varname.endswith("@GRAD"):
return False
if origin_varname == "learning_rate_0":
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
return is_valid
def _save_sparse_params(self, executor, dirname, context, main_program,
mode):
distributed_varnames = get_sparse_tablenames(
self.context['origin_main_program'], True)
values = []
for id, names in context.items():
if names[0] not in distributed_varnames:
# only save sparse param to local
try:
self._worker.recv_and_save_model(id, dirname)
except:
pass
# save sparse & distributed param on server
self._worker.save_one_model(id, dirname, mode)
values.extend(names)
# self._worker.save_all_model(dirname, mode)
return values
def _save_distributed_persistables(self,
executor,
dirname,
main_program,
mode=0):
denses = get_the_one_recv_context(
self.context,
is_dense=True,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparses = get_the_one_recv_context(
self.context,
is_dense=False,
            split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
main_program, mode)
recv_dense_varnames = []
for id, names in denses.items():
recv_dense_varnames.extend(names)
self._communicator.pull_dense(denses)
saved_varnames = sparse_varnames
remaining_vars = list(
filter(
TheOnePSRuntime.__exclude_vars(saved_varnames),
main_program.list_vars()))
import paddle
for var in remaining_vars:
# if var.name not in recv_dense_varnames:
# continue
tensor = var.get_value()
paddle.save(
tensor, os.path.join(dirname, var.name), use_binary_format=True)
def _ps_inference_save_persistables(self,
executor,
dirname,
main_program=None,
mode=0,
**kwargs):
"""
This function filters out all variables with `persistable==True` from the
give `main_program` and then saves these variables to the folder `dirname`
or file `filename`.
The `dirname` is used to specify the folder where persistable variables
are going to be saved. If you would like to save variables in separate
files, set `filename` None; if you would like to save all variables in a
single file, use `filename` to specify the file name.
"""
if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be an Executor instance, ParallelExecutor is not allowed"
            )
if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be an Executor instance")
if main_program is None:
main_program = self.context['origin_ps_main_program']
if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be a Program instance, CompiledProgram is not allowed"
            )
# Todo(MrChengmo): Save optimizer status
# self._save_distributed_persistables(executor, dirname, main_program,
# mode)
self._worker.save_all_model(dirname, mode)
def _ps_inference_save_inference_model(self,
executor,
dirname,
feeded_var_names,
target_vars,
main_program=None,
export_for_deployment=True,
mode=0):
"""
Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` by the `executor`.
"""
if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be an Executor instance, ParallelExecutor is not allowed"
            )
if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be an Executor instance")
import paddle
program = self.origin_main_program if main_program is None else main_program
if isinstance(program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be a Program instance, CompiledProgram is not allowed"
            )
feed_vars = [
program.global_block().var(name) for name in feeded_var_names
]
infer_program = paddle.static.normalize_program(program, feed_vars,
target_vars)
infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_basename = "__model__"
model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename)
sparses = get_the_one_recv_context(
self.context,
is_dense=False,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparse_names = self._save_sparse_params(executor, dirname, sparses,
main_program, mode)
denses = get_the_one_recv_context(
self.context,
is_dense=True,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
self._communicator.pull_dense(denses)
generate_vars = self.context[
"user_defined_strategy"].trainer_desc_configs["stat_var_names"]
generate_vars = [var for var in generate_vars]
remaining_vars = list(
filter(
TheOnePSRuntime.__exclude_vars(sparse_names),
infer_program.list_vars()))
for var in remaining_vars:
tensor = var.get_value()
paddle.save(
tensor,
os.path.join(model_path, var.name),
use_binary_format=True)
def _save_inference_model(self, *args, **kwargs):
self._ps_inference_save_inference_model(*args, **kwargs)
def _save_persistables(self, *args, **kwargs):
self._ps_inference_save_persistables(*args, **kwargs)
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_program,
True)
values = []
for id, names in context.items():
if names[0] not in distributed_varnames:
# TODO: only load sparse param from local
warnings.warn("varname is not in distributed_varnames, pass")
# load sparse & distributed param on server
self._worker.load_one_table(id, dirname, mode)
values.extend(names)
return values
def _ps_inference_load_inference_model(self,
dirname,
mode=0,
main_program=None):
if main_program is None:
main_program = self.origin_main_program
if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be a Program instance, CompiledProgram is not allowed"
            )
denses = get_the_one_recv_context(
self.context,
is_dense=True,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparses = get_the_one_recv_context(
self.context,
is_dense=False,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparse_varnames = self._load_sparse_params(dirname, sparses,
main_program, mode)
recv_dense_varnames = []
for id, names in denses.items():
recv_dense_varnames.extend(names)
loaded_varnames = sparse_varnames
remaining_vars = list(
filter(
TheOnePSRuntime.__exclude_vars(loaded_varnames),
main_program.list_vars()))
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
import paddle
for var in remaining_vars:
if var.name not in recv_dense_varnames:
continue
tensor = paddle.load(os.path.join(model_path, var.name))
var.set_value(tensor)
self._communicator.init_params(denses)
def _load_distributed_persistables(self, path, mode):
self._worker.load_model(path, mode)
def load_model(self, path, mode):
if mode == 0 or mode == 3:
self._load_distributed_persistables(path, mode)
else:
self._ps_inference_load_inference_model(path, mode)
def _shrink(self, threshold=None):
if threshold is not None:
warnings.warn(
"The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor"
)
else:
threshold = 0
fleet.util.barrier()
if self.role_maker._is_first_worker():
            sparses = get_the_one_recv_context(
self.context,
is_dense=False,
split_dense_table=self.role_maker.
_is_heter_parameter_server_mode,
use_origin_program=True)
for id, names in sparses.items():
self._worker.shrink_sparse_table(id, threshold)
fleet.util.barrier()
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .ps_program_builder import *
from .public import *
__all__ = [
'PsProgramBuilder', 'GeoPsProgramBuilder', 'CpuSyncPsProgramBuilder',
'CpuAsyncPsProgramBuilder', 'GpuPsProgramBuilder',
'HeterAsyncPsProgramBuilder', 'FlPsProgramBuilder'
]
class PsProgramBuilderFactory(object):
def __init__(self):
pass
def _create_ps_program_builder(self, pass_ctx):
attrs = pass_ctx._attrs
if attrs['ps_mode'] == DistributedMode.GEO:
return globals()['GeoPsProgramBuilder'](pass_ctx)
elif attrs['use_ps_gpu']:
return globals()['GpuPsProgramBuilder'](pass_ctx)
elif attrs['is_heter_ps_mode']:
return globals()['HeterAsyncPsProgramBuilder'](pass_ctx)
elif 'is_fl_ps_mode' in attrs and attrs[
'is_fl_ps_mode'] == DistributedMode.FL:
return globals()['FlPsProgramBuilder'](pass_ctx)
else:
return globals()['CpuSyncPsProgramBuilder'](pass_ctx)
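# A minimal usage sketch (illustrative; in this PR the factory is driven by the
# ParameterServerOptimizer, which fills pass_ctx._attrs beforehand):
#   builder = PsProgramBuilderFactory()._create_ps_program_builder(pass_ctx)
#   builder._build_programs()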
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .public import *
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.distributed.passes import new_pass, PassContext
class PsProgramBuilder(object):
def __init__(self, pass_ctx):
self.pass_ctx = pass_ctx
self.attrs = self.pass_ctx._attrs
self.loss = self.attrs['loss']
self.cloned_main = self.attrs['cloned_main']
self.cloned_startup = self.attrs['cloned_startup']
self.use_ps_gpu = self.attrs['use_ps_gpu']
self.use_heter_ps = self.attrs['is_heter_ps_mode']
self.is_worker = self.attrs['is_worker']
self.is_heter_worker = self.attrs['is_heter_worker']
self.ps_mode = self.attrs['ps_mode']
self.launch_barrier = self.attrs['launch_barrier']
self.launch_barrier_flag = self.attrs['launch_barrier_flag']
self.server_endpoints = self.attrs['role_maker']._get_pserver_endpoints(
)
def _optimize_programs(self):
pass
def _build_trainer_programs(self):
pass
def _build_pserver_programs(self):
is_sgd_adam = False
ops = get_optimize_ops(self.attrs['origin_main_program'])
if len(ops) == 0:
return
add_lr_decay_table_pass = new_pass('add_lr_decay_table_pass',
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
for op in ops:
if op.type in ["sgd", "adam"]:
is_sgd_adam = True
break
if is_sgd_adam:
return
def _build_programs(self):
if self.attrs['is_worker']:
self._build_trainer_programs()
fluid.framework.switch_startup_program(self.cloned_startup)
self.loss.block.program = self.cloned_main
elif self.attrs['is_server']:
self._build_pserver_programs()
self.loss.block.program = self.attrs['_main_server']
fluid.framework.switch_startup_program(self.attrs[
'_startup_server'])
class GeoPsProgramBuilder(PsProgramBuilder):  # CPU mode only
def __init__(self, pass_ctx):
super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode == DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
delete_optimizer_pass.apply([self.cloned_main], [None], self.pass_ctx)
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_extra_optimizer_pass = new_pass("delete_extra_optimizer_pass",
self.attrs)
delete_extra_optimizer_pass.apply([self.attrs['origin_main_program']],
[self.cloned_startup], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class CpuAsyncPsProgramBuilder(CpuSyncPsProgramBuilder):
def __init__(self, pass_ctx):
super(CpuAsyncPsProgramBuilder, self).__init__(pass_ctx)
class GpuPsProgramBuilder(PsProgramBuilder):  # independent of the geo/sync/async modes
def __init__(self, pass_ctx):
super(GpuPsProgramBuilder, self).__init__(pass_ctx)
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
ps_gpu_pass = new_pass("ps_gpu_pass", self.attrs)
ps_gpu_pass.apply([self.cloned_main], [None], self.pass_ctx)
ps_transpile_pass = new_pass("ps_transpile_pass", self.attrs)
        ps_transpile_pass.apply([self.cloned_main], [self.cloned_startup],
                                self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class HeterAsyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(HeterAsyncPsProgramBuilder, self).__init__(pass_ctx)
        if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or not self.attrs[
                'is_heter_ps_mode']:
            raise ValueError("ps mode: {} not matched {}".format(
                self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [], self.pass_ctx)
delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
        delete_optimizer_pass.apply([None], [self.cloned_startup], self.pass_ctx)
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_extra_optimizer_pass = new_pass("delete_extra_optimizer_pass",
self.attrs)
delete_extra_optimizer_pass.apply([self.attrs['origin_main_program']],
[self.cloned_startup], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
if self.is_heter_worker:
split_heter_worker_ops_pass = new_pass(
"split_heter_worker_ops_pass", self.attrs)
split_heter_worker_ops_pass.apply([self.cloned_main], [None],
self.pass_ctx)
else:
split_trainer_ops_pass = new_pass("split_trainer_ops_pass",
self.attrs)
            split_trainer_ops_pass.apply([self.cloned_main], [], self.pass_ctx)
set_heter_pipeline_opt_pass = new_pass('set_heter_pipeline_opt_pass',
self.attrs)
        set_heter_pipeline_opt_pass.apply([self.cloned_main],
                                          [self.cloned_startup], self.pass_ctx)
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
def _build_programs(self):
if self.attrs['is_worker'] or self.attrs['is_heter_worker']:
self._build_trainer_programs()
ps_set_heter_pipeline_opt_pass = new_pass(
"set_heter_pipeline_opt_pass", self.attrs)
            ps_set_heter_pipeline_opt_pass.apply(
                [self.loss.block.program], [self.cloned_startup], self.pass_ctx)
elif self.attrs['is_server']:
self._build_pserver_programs()
self.loss.block.program = self.attrs['_main_server']
fluid.framework.switch_startup_program(self.attrs[
'_startup_server'])
class FlPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(FlPsProgramBuilder, self).__init__(pass_ctx)
def _build_trainer_programs(self):
pass
def _build_pserver_programs(self):
pass
def _build_programs(self):
pass
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from functools import reduce
import collections
import math
import os
import warnings
import logging
import six
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.core import CommContext
import paddle.fluid.framework as framework
import paddle.distributed.fleet as fleet
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
class DistributedMode:
SYNC = 0
ASYNC = 1
HALF_ASYNC = 2
GEO = 3
FL = 4
class TrainerRuntimeConfig(object):
def __init__(self, valid_strategy):
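        # DistributedMode is derived from the strategy: SYNC when a_sync is
        # off, ASYNC when a_sync is on with k_steps == 0, and GEO when a_sync
        # is on with k_steps > 0.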
        self.mode = None
        k_steps = valid_strategy.a_sync_configs["k_steps"]
        if not valid_strategy.a_sync and k_steps == 0:
            self.mode = DistributedMode.SYNC
        if valid_strategy.a_sync and k_steps == 0:
            self.mode = DistributedMode.ASYNC
        if valid_strategy.a_sync and k_steps > 0:
            self.mode = DistributedMode.GEO
num_threads = os.getenv("CPU_NUM", "1")
self.runtime_configs = {}
self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
self.runtime_configs['communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
self.runtime_configs[
'communicator_independent_recv_thread'] = os.getenv(
"FLAGS_communicator_independent_recv_thread", "1")
self.runtime_configs[
'communicator_min_send_grad_num_before_recv'] = os.getenv(
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
self.runtime_configs['communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "5")
self.runtime_configs['communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1")
def get_lr_ops(program):
lr_ops = []
for index, op in enumerate(program.global_block().ops):
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
int(OPT_OP_ROLE_ATTR_VALUE):
lr_ops.append(op)
return lr_ops
def get_optimize_ops(_program):
block = _program.global_block()
opt_ops = []
for op in block.ops:
if _is_opt_role_op(op):
# delete clip op from opt_ops when run in Parameter Server mode
if OP_NAME_SCOPE in op.all_attrs() \
and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
op._set_attr(
"op_role",
int(core.op_proto_and_checker_maker.OpRole.Backward))
continue
opt_ops.append(op)
return opt_ops
def get_dist_env():
trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
trainer_endpoints = ''
current_endpoint = ''
num_trainers = 0
if os.getenv('PADDLE_TRAINER_ENDPOINTS'):
trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS')
current_endpoint = trainer_endpoints.split(',')[trainer_id]
num_trainers = len(trainer_endpoints.split(','))
return {
'trainer_id': trainer_id,
'num_trainers': num_trainers,
'current_endpoint': current_endpoint,
'trainer_endpoints': trainer_endpoints
}
def get_ps_endpoint(role_maker):
try:
return role_maker._get_pserver_endpoints()[get_role_id(role_maker)]
except Exception:
return role_maker.get_pserver_endpoints()[get_role_id(role_maker)]
def get_heter_worker_endpoint(role_maker):
try:
return role_maker._get_heter_worker_endpoint()
except Exception:
return role_maker.get_heter_worker_endpoint()
def get_trainer_endpoint(role_maker):
try:
return role_maker._get_trainer_endpoint()
except Exception:
return role_maker.get_trainer_endpoint()
def get_previous_stage_trainers(role_maker):
try:
        return role_maker._get_previous_trainers()
except Exception:
return role_maker.get_previous_trainers()
def is_distributed_sparse_op(op):
if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True:
return True
if op.type == "distributed_lookup_table" and op.attr(
'is_distributed') is True:
return True
return False
def get_sparse_tablename(op):
return op.input("W")[0]
def is_sparse_op(op):
if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr(
'is_distributed') is False:
return True
if op.type == "distributed_lookup_table" and op.attr(
'is_distributed') is False:
return True
return False
def get_sparse_tablenames(program, is_distributed):
tablenames = set()
if is_distributed:
for op in program.global_block().ops:
if is_distributed_sparse_op(op):
tablenames.add(get_sparse_tablename(op))
else:
for op in program.global_block().ops:
if is_sparse_op(op):
tablenames.add(get_sparse_tablename(op))
return list(tablenames)
def get_role_id(role_maker):
try:
return role_maker._role_id()
except Exception:
return role_maker.role_id()
def get_ps_endpoints(role_maker):
    try:
        return role_maker._get_pserver_endpoints()
    except Exception:
        return role_maker.get_pserver_endpoints()
def get_trainers(role_maker):
try:
return role_maker._worker_num()
except Exception:
return role_maker.worker_num()
def get_dense_send_context(context,
send_ctx,
idx,
merged_dense_pairs,
trainer_id,
split_dense_table=False):
if len(merged_dense_pairs) < 1:
return idx
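    # When the dense table is not split, all merged dense grads are packed into
    # a single "Dense@Grad" CommContext; otherwise each grad gets its own
    # context keyed by its original var name.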
if not split_dense_table:
origin_varnames = []
var_numel = 0
for merged in merged_dense_pairs:
grad = merged[1]
origin_varnames.append(grad.merged_var.name)
var = context['origin_main_program'].global_block().vars[
grad.merged_var.name]
var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "Dense@Grad"
trainer_id = get_role_id(context['role_maker'])
aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False)
send_ctx[grad_name] = dense_ctx
idx += 1
else:
for merged in merged_dense_pairs:
grad = merged[1]
origin_varname = grad.merged_var.name
var = context['origin_main_program'].global_block().vars[
origin_varname]
var_numel = reduce(lambda x, y: x * y, var.shape)
grad_name = origin_varname
aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], [origin_varname], trainer_id,
aggregate, False, False, idx, False)
send_ctx[grad_name] = dense_ctx
idx += 1
return idx
def get_geo_trainer_send_context(context):
if context['ps_mode'] != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "get_geo_trainer_send_context"))
send_ctx = {}
return send_ctx
def _step_ctx(idx, role_maker):
name = STEP_COUNTER
trainer_id = get_role_id(role_maker)
endpoints = get_ps_endpoints(role_maker)
sections = [1] * len(endpoints)
names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True)
return name, ctx
def get_the_one_send_context(context,
split_dense_table=False,
use_origin_program=False,
ep_list=None):
if ep_list is None:
ep_list = ["127.0.0.1:6071"]
send_ctx = {}
trainer_id = get_role_id(context['role_maker'])
idx = 0
idx += get_dense_send_context(context, send_ctx, idx,
context['merged_dense_pairs'], trainer_id,
split_dense_table)
    distributed_varnames = get_sparse_tablenames(context['origin_main_program'],
                                                 True)
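    # Each merged sparse param is sharded across the pserver endpoints; the
    # per-shard var names take the form "<param_name>.block<i>".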
for merged in context['merged_sparse_pairs']:
param, grad = merged
grad_name = grad.merged_var.name
param_name = param.merged_var.name
splited_varname = []
for i in range(len(ep_list)):
splited_varname.append("{}.block{}".format(param_name, i))
        is_distributed = True if param_name in distributed_varnames else False
var = context['origin_main_program'].global_block().vars[
grad.merged_var.name]
shape = list(var.shape)
shape[0] = 0 if is_distributed else shape[0]
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
if len(context['tensor_table']) > 0 and context['is_worker']:
name, ctx = _step_ctx(idx, context['role_maker'])
send_ctx[name] = ctx
return send_ctx
def find_heter_ops(program, default_device="cpu"):
if default_device not in DEVICE_LIST:
raise ValueError("Given device {} is not in device list {}".format(
default_device, DEVICE_LIST))
def _is_heter_op(op, current_heter_device, default_device="cpu"):
heter_devices = list(DEVICE_LIST)
heter_devices.remove(default_device)
op_device = op.attr("op_device")
op_type = op.type
if op_device in heter_devices:
return True
elif op_type in COMMUNICATE_OPS_TYPE and current_heter_device != default_device:
            # for distributed communication ops: send & recv & barrier etc.
# Todo: need update this method
#op._set_attr('op_device', current_heter_device)
return True
        elif op_device is None or op_device == default_device:
op._set_attr('op_device', default_device)
return False
return False
def _is_same_device(op, pre_device, default_device="cpu"):
op_device = op.attr("op_device")
if op_device == pre_device:
return True
if pre_device == default_device:
return True
return False
def _append_heter_op(op, current_heter_block_ops, heter_ops):
op_device = op.attr("op_device")
if op_device not in heter_ops:
heter_ops[op_device] = {}
current_heter_block_ops.append(op)
    origin_program = program.clone()
block = program.global_block()
'''
re-place sum op to fix bug for union forward backward op
'''
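    # Sketch of the re-placement done below: walking the block backwards, when
    # a sparse grad op (lookup_table_grad / lookup_table_v2_grad) or an
    # elementwise_mul grad op is found whose var was recorded in var2idx, the
    # recorded "sum" op is re-inserted right after it and removed from its
    # original position, keeping forward and backward ops adjacent.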
var2idx = {}
op_list = list(block.ops)
op_size = len(op_list)
for i in range(op_size - 1, -1, -1):
op_list = list(block.ops)
op = op_list[i]
if "_grad" in op.type:
forward_op_type = op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
if param_name in var2idx:
## insert sum op & remove sum op from var2idx and origin place
op_list = list(block.ops)
sum_op = op_list[var2idx[param_name]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs())
block._remove_op(var2idx[param_name] + 1)
var2idx.pop(param_name)
for var_ in var2idx:
var2idx[var_] += 1
elif forward_op_type == "elementwise_mul":
"""
get output varname of pre op
"""
output_vars_no_grad = []
for key in op.output_names:
for varname in op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars_no_grad.append(varname.split("@GRAD")[0])
for no_grad_var in output_vars_no_grad:
if no_grad_var in var2idx:
"""
insert sum op & remove sum op from var2idx and origin place
"""
op_list = list(block.ops)
sum_op = op_list[var2idx[no_grad_var]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs())
block._remove_op(var2idx[no_grad_var] + 1)
var2idx.pop(no_grad_var)
for var_ in var2idx:
var2idx[var_] += 1
else:
if op.type == "sum":
var = op.output("Out")[0]
if "@GRAD" in var:
origin_var = var.split("@GRAD")[0]
pre_op = op_list[i - 1]
if "_grad" in pre_op.type:
forward_op_type = pre_op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and pre_op.attr('remote_prefetch') is True:
param_name = pre_op.input(SPARSE_OP_TYPE_DICT[
forward_op_type])[0]
if param_name == origin_var and op.attr(
"op_device") == pre_op.attr("op_device"):
continue
else:
var2idx[origin_var] = i
elif forward_op_type == "elementwise_mul":
output_vars = []
for key in pre_op.output_names:
for varname in pre_op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars.append(varname)
input_vars = []
for key in op.input_names:
for varname in op.input(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
input_vars.append(varname)
is_match = False
for varname in output_vars:
if varname in input_vars:
is_match = True
break
if is_match:
continue
else:
var2idx[origin_var] = i
else:
var2idx[origin_var] = i
    origin_program = program.clone()
block = program.global_block()
program_block_ops = []
default_ops = {default_device: {}}
heter_ops = {}
block_index = 0
current_heter_block_ops = []
current_default_block_ops = []
current_heter_device = default_device
is_heter = False
for op in block.ops:
if _is_heter_op(op, current_heter_device, default_device):
# for gpu/xpu-op
is_heter = True
# for cpu-op block append
if len(current_default_block_ops) > 1:
default_ops[default_device][
block_index] = current_default_block_ops
program_block_ops.append(current_default_block_ops)
current_default_block_ops = []
block_index += 1
if _is_same_device(op, current_heter_device, default_device):
# for gpu-op, gpu-op -> gpu-op,...
current_heter_device = op.attr("op_device")
_append_heter_op(op, current_heter_block_ops, heter_ops)
else:
# for gpu-op -> xpu-op, ...
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
block_index += 1
current_heter_block_ops = []
current_heter_device = op.attr("op_device")
_append_heter_op(op, current_heter_block_ops, heter_ops)
elif is_heter:
# for gpu/xpu-op -> cpu-op
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
block_index += 1
current_heter_block_ops = []
current_heter_device = default_device
is_heter = False
current_default_block_ops.append(op)
else:
# for cpu-op
current_default_block_ops.append(op)
if current_default_block_ops != []:
default_ops[default_device][block_index] = current_default_block_ops
program_block_ops.append(current_default_block_ops)
if current_heter_block_ops != []:
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
if len(heter_ops) == 0:
warnings.warn(
"No heterogeneous OP was found in your program , "
" please using fluid.device_guard() to run OPs on different device.")
total_heter_ops = 0
heter_blocks = 0
for device in heter_ops.keys():
heter_block_dict = heter_ops[device]
heter_blocks += len(heter_block_dict)
for _, heter_block in heter_block_dict.items():
total_heter_ops += len(heter_block)
print(
"There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.".
format(len(block.ops), total_heter_ops, heter_blocks))
    return origin_program, heter_ops, default_ops, program_block_ops
def union_forward_gradient_op(program_block_ops_list):
"""
    Before analyzing the input & output of each block in program_block_ops_list, we should
    union each forward op with its corresponding gradient op to eliminate unnecessary variable
    transmission.
"""
"""
fix for 2emb model, re-place sum op
"""
block_length = len(program_block_ops_list)
union_program_block_ops_list = []
assert block_length % 2 != 0, "the length of program_block_ops_list should be odd"
for i in range(0, block_length // 2):
block_op_list = {"forward": program_block_ops_list[i]}
block_op_list.update({
"backward": program_block_ops_list[block_length - 1 - i]
})
union_program_block_ops_list.append(block_op_list)
block_op_list = {"forward": [], "backward": []}
for op in program_block_ops_list[block_length // 2]:
if not "_grad" in op.type and not (op.type == "sum"):
block_op_list["forward"].append(op)
else:
block_op_list["backward"].append(op)
union_program_block_ops_list.append(block_op_list)
return union_program_block_ops_list
def find_block_joints(program, program_block_ops_list, heter_ops):
block_var_detail = find_entrance_exit_private(program,
program_block_ops_list)
block_var_detail = entrance_exit_check(program, program_block_ops_list,
block_var_detail, heter_ops)
block_var_detail = delete_block_useless_exit(
program, program_block_ops_list, block_var_detail)
return block_var_detail
def find_entrance_exit_private(program, program_block_ops_list):
block_var_detail = []
persistables = []
for index, block_op_list in enumerate(program_block_ops_list):
## forward
block_input, block_output = find_ops_list_input_output(
program, block_op_list["forward"])
persistables = screen_persistables(
program, block_input) + screen_persistables(program, block_output)
# find entrance & exit
block_private_vars = list(set(block_input) & set(block_output))
block_entrance = list(set(block_input) - set(block_private_vars))
block_exit = list(set(block_output) - set(block_private_vars))
detail = {
"forward": {
"entrance": block_entrance,
"exit": block_exit,
"private": block_private_vars,
"persistables": persistables
}
}
## backward
bp_block_input, bp_block_output = find_ops_list_input_output(
program, block_op_list["backward"])
bp_persistables = screen_persistables(
program, bp_block_input) + screen_persistables(program,
bp_block_output)
# find entrance & exit
bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output))
bp_block_entrance = list(
set(bp_block_input) - set(bp_block_private_vars))
bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars))
detail.update({
"backward": {
"entrance": bp_block_entrance,
"exit": bp_block_exit,
"private": bp_block_private_vars,
"persistables": bp_persistables
}
})
block_var_detail.append(detail)
return block_var_detail
def entrance_exit_check(program, program_block_ops_list, block_var_detail,
heter_ops):
for index in range(len(block_var_detail) - 1, -1, -1):
if index - 1 < 0:
break
previous_block_exit = block_var_detail[index - 1]["forward"]["exit"]
previous_block_exit.sort()
current_block_entrance = block_var_detail[index]["forward"]["entrance"]
backward_entrance = block_var_detail[index]["backward"]["entrance"]
forward_all = block_var_detail[index]["forward"][
"entrance"] + block_var_detail[index]["forward"][
"private"] + block_var_detail[index]["forward"]["exit"]
for var in backward_entrance:
if not ("@GRAD" in var) and not (var in forward_all):
current_block_entrance.append(var)
current_block_entrance.sort()
if previous_block_exit == current_block_entrance:
continue
exist_vars = list(
set(previous_block_exit) & set(current_block_entrance))
need_add_vars = list(set(current_block_entrance) - set(exist_vars))
# var in different stage should not be ignored, since they are not placed in the same program & device
#need_add_vars = find_need_var_from_previous_block(
# need_add_vars, block_var_detail, index, heter_ops)
previous_block_private = block_var_detail[index - 1]["forward"][
"private"]
previous_block_entrance = block_var_detail[index - 1]["forward"][
"entrance"]
for var in need_add_vars:
if var not in previous_block_private and var not in previous_block_entrance:
previous_block_entrance.append(var)
previous_block_exit.append(var)
if not var in current_block_entrance:
current_block_entrance.append(var)
for index in range(0, len(block_var_detail) - 1, 1):
previous_block_exit = block_var_detail[index + 1]["backward"]["exit"]
previous_block_exit.sort()
current_block_entrance = block_var_detail[index]["backward"]["entrance"]
current_block_entrance.sort()
if previous_block_exit == current_block_entrance:
continue
exist_vars = list(
set(previous_block_exit) & set(current_block_entrance))
need_add_vars = list(set(current_block_entrance) - set(exist_vars))
need_ignore_vars = []
for var in need_add_vars:
if not "@GRAD" in var:
need_ignore_vars.append(var)
need_add_vars = list(
set(need_add_vars).difference(set(need_ignore_vars)))
previous_block_private = block_var_detail[index + 1]["backward"][
"private"]
previous_block_entrance = block_var_detail[index + 1]["backward"][
"entrance"]
for var in need_add_vars:
if var not in previous_block_private and var not in previous_block_entrance:
previous_block_entrance.append(var)
previous_block_exit.append(var)
return block_var_detail
def delete_block_useless_exit(program, program_block_ops_list,
block_var_detail):
## forward
for index in range(len(block_var_detail)):
if index == len(block_var_detail) - 1:
break
current_block_exit = block_var_detail[index]["forward"]["exit"]
next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"]
need_delete_var = []
for var in current_block_exit:
if var not in next_block_entrance:
need_delete_var.append(var)
for var in need_delete_var:
current_block_exit.remove(var)
## backward
for index in range(len(block_var_detail) - 1, -1, -1):
if index - 1 < 0:
break
current_block_exit = block_var_detail[index]["backward"]["exit"]
next_block_entrance = block_var_detail[index - 1]["backward"][
"entrance"]
need_delete_var = []
for var in current_block_exit:
if var not in next_block_entrance:
need_delete_var.append(var)
for var in need_delete_var:
current_block_exit.remove(var)
return block_var_detail
def get_communicate_var_info(program,
block_index,
entrance_var_list,
type="forward"):
input_var_reshape_dim = []
input_var_reshape_name = []
if type == "forward":
block_input_var_name = "forward_joint_{}_{}@Heter".format(
block_index - 1, block_index)
else:
block_input_var_name = "backward_joint_{}_{}@Heter".format(
block_index + 1, block_index)
entrance_var_list.sort()
# input
# Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var
for name in entrance_var_list:
var = program.global_block().vars[name]
shape = var.shape
recv_var_dim = -1 * reduce(lambda x, y: x * y, shape)
input_var_reshape_dim.append(recv_var_dim)
input_var_reshape_name.append("{}.input_reshape@Heter".format(name))
info = {
"input_var_reshape_dim": input_var_reshape_dim,
"input_var_reshape_name": input_var_reshape_name,
"block_input_var_name": block_input_var_name,
}
return info
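# Illustrative sketch (hypothetical shapes): for an entrance var of shape
# (-1, 13), recv_var_dim = -1 * (-1 * 13) = 13 and the reshape var is named
# "<var>.input_reshape@Heter"; with block_index=2 and type="forward", the
# joint input var is "forward_joint_1_2@Heter".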
def add_vars_by_var_list(var_name_list, origin_program, program, block):
for var_name in var_name_list:
if var_name not in program.global_block(
).vars and var_name not in block.vars:
var = origin_program.global_block().vars[var_name]
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
def _get_output_map_from_op(varmap, op):
"""Returns a dict from op output name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.output_names:
vars = []
for varname in op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def block_append_op(program, origin_program, block, op):
merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars)
inputs = _get_input_map_from_op(merge_ordereddict, op)
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var.name not in program.global_block(
).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var.name not in program.global_block(
).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
if "_grad" not in op.type:
# for forward op
return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs())
else:
# for grad op
op_desc = op.desc
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
# append grad op
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op_desc)
new_op_desc._set_attr(op_role_attr_name, backward)
# set device guard
if op.desc.has_attr(device_attr_name):
op_device = op_desc.attr(device_attr_name)
new_op_desc._set_attr(device_attr_name, op_device)
block._sync_with_cpp()
def get_next_stage_trainers(role_maker):
try:
return role_maker._get_next_trainers()
except Exception:
return role_maker.get_next_trainers()
def insert_communicate_op(origin_program,
role_maker,
heter_block,
stage_id,
first_op_index,
block_var_detail,
device,
is_forward=True):
if is_forward:
next_heter_worker_endpoints = get_next_stage_trainers(role_maker)
previous_heter_worker_endpoints = get_previous_stage_trainers(
role_maker)
entrance_var = block_var_detail[stage_id]["forward"]["entrance"]
comm_info = get_communicate_var_info(origin_program, stage_id + 1,
entrance_var)
else:
next_heter_worker_endpoints = get_next_stage_trainers(role_maker)
previous_heter_worker_endpoints = get_previous_stage_trainers(
role_maker)
entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"]
comm_info = get_communicate_var_info(origin_program, stage_id - 1,
entrance_var, "backward")
heter_block._insert_op(
index=first_op_index,
type="send_and_recv",
inputs={"X": heter_block.vars[entrance_var[0]]},
outputs={"Out": []},
attrs={
"mode": "forward" if is_forward else "backward",
"send_var_name": entrance_var + ["microbatch_id"],
"recv_var_name": [],
"message_name": comm_info["block_input_var_name"],
"next_endpoints": next_heter_worker_endpoints,
"previous_endpoints": previous_heter_worker_endpoints,
"trainer_id": get_role_id(role_maker),
"op_device": device,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return entrance_var
def get_the_one_recv_context(context,
is_dense=True,
split_dense_table=False,
use_origin_program=False):
recv_id_maps = {}
grad_name_to_param_name = context["grad_name_to_param_name"]
if is_dense:
send_ctx = get_the_one_send_context(
context,
split_dense_table=split_dense_table,
use_origin_program=use_origin_program)
for idx, (name, ctx) in enumerate(send_ctx.items()):
if ctx.is_sparse():
continue
if ctx.is_tensor_table():
continue
origin_grad_varnames = ctx.origin_varnames()
param_names = []
for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
else:
send_ctx = get_the_one_send_context(
context,
split_dense_table=False,
use_origin_program=False,
ep_list=None)
for idx, (name, ctx) in enumerate(send_ctx.items()):
if not ctx.is_sparse():
continue
origin_grad_varnames = ctx.origin_varnames()
param_names = []
for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
return recv_id_maps
def _get_varname_parts(varname):
# returns origin, blockid, trainerid
orig_var_name = ""
trainer_part = ""
block_part = ""
trainer_idx = varname.find(".trainer_")
if trainer_idx >= 0:
trainer_part = varname[trainer_idx + 1:]
else:
trainer_idx = len(varname)
block_index = varname.find(".block")
if block_index >= 0:
block_part = varname[block_index + 1:trainer_idx]
else:
block_index = len(varname)
orig_var_name = varname[0:min(block_index, trainer_idx)]
return orig_var_name, block_part, trainer_part
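# Example (hypothetical name): "fc_0.w_0.block0.trainer_1" splits into
# ("fc_0.w_0", "block0", "trainer_1"); a plain name such as "fc_0.w_0"
# yields ("fc_0.w_0", "", "").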
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
def get_var_mem_size(var):
m_size = reduce(lambda x, y: x * y, var.shape)
m_size *= dtype_to_size[var.dtype]
return m_size
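# Example (hypothetical var): a float32 variable of shape (1000, 10)
# occupies 1000 * 10 * 4 = 40000 bytes according to dtype_to_size.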
class MergedVariable:
def __init__(self, merged, ordered, offsets):
self.merged_var = merged
self.ordered_vars = ordered
self.offsets = offsets
def build_var_distributed(context):
sparse_pairs, dense_pairs = get_param_grads(context['origin_main_program'])
origin_for_sparse = []
origin_for_dense = []
param_name_grad_name = {}
grad_name_to_param_name = {}
context["merged_variables_pairs"] = []
context["merged_sparse_pairs"] = []
context['merged_dense_pairs'] = []
context["merged_variable_map"] = {}
for param, grad in sparse_pairs:
origin_for_sparse.append((param, grad))
for param, grad in dense_pairs:
origin_for_dense.append((param, grad))
for dense_pair in origin_for_dense:
param, grad = dense_pair
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad))
context["merged_dense_pairs"].append((m_param, m_grad))
for sparse_pair in origin_for_sparse:
param, grad = sparse_pair
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad))
context["merged_sparse_pairs"].append((m_param, m_grad))
for merged in context["merged_variables_pairs"]:
m_param, m_grad = merged
context["merged_variable_map"][
m_param.merged_var.name] = m_param.merged_var
context["merged_variable_map"][
m_grad.merged_var.name] = m_grad.merged_var
param_merges = []
param_merges.extend(origin_for_sparse)
param_merges.extend(origin_for_dense)
for param, grad in param_merges:
param_name_grad_name[param.name] = grad.name
grad_name_to_param_name[grad.name] = param.name
context["origin_sparse_pairs"] = origin_for_sparse
context["origin_dense_pairs"] = origin_for_dense
context["param_name_to_grad_name"] = param_name_grad_name
context["grad_name_to_param_name"] = grad_name_to_param_name
def _is_opt_role_op(op):
# NOTE: depends on the op role attribute to find out whether this op is an
# optimizer op
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def get_param_grads(origin_program):
def _get_params_grads(sparse_varnames):
block = origin_program.global_block()
dense_param_grads = []
sparse_param_grads = []
optimize_params = set()
origin_var_dict = origin_program.global_block().vars
role_id = int(core.op_proto_and_checker_maker.OpRole.Backward)
for op in block.ops:
if _is_opt_role_op(op):
# delete clip op from opt_ops when run in Parameter Server mode
if OP_NAME_SCOPE in op.all_attrs() \
and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
op._set_attr("op_role", role_id)
continue
if op.attr(OP_ROLE_VAR_ATTR_NAME):
param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if param_name not in optimize_params:
optimize_params.add(param_name)
param_grad = (origin_var_dict[param_name],
origin_var_dict[grad_name])
if param_name in sparse_varnames:
sparse_param_grads.append(param_grad)
else:
dense_param_grads.append(param_grad)
return sparse_param_grads, dense_param_grads
def _get_sparse_varnames():
varnames = []
for op in origin_program.global_block().ops:
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
varnames.append(param_name)
return list(set(varnames))
sparse_varnames = _get_sparse_varnames()
sparse_param_grads, dense_param_grads = _get_params_grads(sparse_varnames)
return sparse_param_grads, dense_param_grads
def debug_program(file, program, is_trainer):
# is_trainer currently has no effect: trainer and server programs are dumped identically
with open(file, 'w+') as f:
f.write(str(program))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
import pickle
import shlex
import shutil
import inspect
import unittest
import numpy as np
from collections import OrderedDict
from dist_pass_test_base import prepare_python_path_and_return_module, remove_path_if_exists
import paddle.distributed.fleet as fleet
class PsPassTestBase(unittest.TestCase):
def init(self):
raise NotImplementedError
def setUp(self):
print('Ps setUp...')
def tearDown(self):
print('Ps tearDown...')
def ps_launch(self, config):
cmd = [
sys.executable,
"-u",
] + [
"-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
config['worker_num'], "--server_num", config['server_num'],
"../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
"--run_minimize", config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
config['debug_new_pass'], "--debug_new_minimize",
config['debug_new_minimize'], "--applied_pass_name",
config['applied_pass_name']
]
cmd = [shlex.quote(c) for c in cmd]
prepare_python_path_and_return_module(__file__)
exitcode = os.system(' '.join(cmd))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
from ps_pass_test_base import *
from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
class TestPsServerPass(PsPassTestBase):
def init(self):
pass
def setUp(self):
print('TestPsServerPass setUp...')
def tearDown(self):
print('TestPsServerPass tearDown...')
def test_add_lr_decay_table_pass(self):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
from ps_pass_test_base import *
from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
class TestPsTrainerPass(PsPassTestBase):
def init(self):
self.config = {}
self.config['ps_mode_config'] = "../ps/cpu_async_ps_config.yaml"
self.config['worker_num'] = "1"
self.config['server_num'] = "1"
self.config['run_minimize'] = "0"
self.config['run_single_pass'] = "0"
self.config['debug_new_minimize'] = "0"
self.config['debug_new_pass'] = "0"
self.config['log_dir'] = ""
self.config['applied_pass_name'] = ""
def setUp(self):
print('TestPsTrainerPass setUp...')
def tearDown(self):
print('TestPsTrainerPass tearDown...')
def check(self):
pass
def test_ps_optimizer_minimize(self):
self.init()
self.config['run_minimize'] = '1'
self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = "/log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = "/log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.check()
def test_append_send_ops_pass(self):
self.init()
self.config['run_single_pass'] = '1'
self.config['applied_pass_name'] = "append_send_ops_pass"
self.config['debug_new_pass'] = '0'
self.config['log_dir'] = "/log_old_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.config['debug_new_pass'] = '1'
self.config['log_dir'] = "/log_new_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.check()
def test_distributed_ops_pass(self):
pass
if __name__ == '__main__':
unittest.main()
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
set_tests_properties(test_the_one_ps PROPERTIES TIMEOUT 50)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: On Windows, when importing from subdirectories such as dirA()->dirB(),
# the current directory will still be dirA() although it should be dirB(),
# which raises ModuleNotFoundError.
# please refer to https://stackoverflow.com/questions/8953844/import-module-from-subfolder
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
hyper_parameters:
optimizer:
class: Adam
learning_rate: 0.0001
adam_lazy_mode: True
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 10
dense_input_dim: 13
fc_sizes: [400, 400, 400]
runner:
geo_step: 400
sync_mode: "async" # sync / async / geo / heter
thread_num: 16
use_gpu: 0
model_path: "../ps_dnn_model.py"
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.ps.utils.ps_program_builder import *
import paddle.distributed.fleet as fleet
import argparse
import time
import sys
import yaml, six, copy
import paddle
import os
import warnings
import logging
import ast
import numpy as np
import struct
sys.path.append("..")
from ps_dnn_model import StaticModel
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def is_distributed_env():
node_role = os.getenv("TRAINING_ROLE")
logger.info("-- Role: {} --".format(node_role))
if node_role is None:
return False
else:
return True
class YamlHelper(object):
def load_yaml(self, yaml_file, other_part=None):
part_list = ["runner", "hyper_parameters"]
if other_part:
part_list += other_part
running_config = self.get_all_inters_from_yaml(yaml_file, part_list)
running_config = self.workspace_adapter(running_config)
return running_config
def print_yaml(self, config):
print(self.pretty_print_envs(config))
def parse_yaml(self, config):
vs = [int(i) for i in yaml.__version__.split(".")]
if vs[0] < 5:
use_full_loader = False
elif vs[0] > 5:
use_full_loader = True
else:
if vs[1] >= 1:
use_full_loader = True
else:
use_full_loader = False
if os.path.isfile(config):
if six.PY2:
with open(config, 'r') as rb:
if use_full_loader:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_config = yaml.load(rb.read())
return _config
else:
with open(config, 'r', encoding="utf-8") as rb:
if use_full_loader:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_config = yaml.load(rb.read())
return _config
else:
raise ValueError("config {} can not be supported".format(config))
def get_all_inters_from_yaml(self, file, filters):
_envs = self.parse_yaml(file)
all_flattens = {}
def flatten_env_namespace(namespace_nests, local_envs):
for k, v in local_envs.items():
if isinstance(v, dict):
nests = copy.deepcopy(namespace_nests)
nests.append(k)
flatten_env_namespace(nests, v)
else:
global_k = ".".join(namespace_nests + [k])
all_flattens[global_k] = v
flatten_env_namespace([], _envs)
ret = {}
for k, v in all_flattens.items():
for f in filters:
if k.startswith(f):
ret[k] = v
return ret
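# Illustrative sketch (hypothetical YAML): {"runner": {"sync_mode": "async"}}
# is flattened to {"runner.sync_mode": "async"}, and only keys whose prefix
# matches one of the filters (e.g. "runner", "hyper_parameters") are kept.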
def workspace_adapter(self, config):
workspace = config.get("workspace")
for k, v in config.items():
if isinstance(v, str) and "{workspace}" in v:
config[k] = v.replace("{workspace}", workspace)
return config
def pretty_print_envs(self, envs, header=None):
spacing = 2
max_k = 40
max_v = 45
for k, v in envs.items():
max_k = max(max_k, len(k))
h_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(max_k, " " *
spacing, max_v)
l_format = " " + "|{{:>{}s}}{{}}{{:^{}s}}|\n".format(max_k, max_v)
length = max_k + max_v + spacing
border = " +" + "".join(["="] * length) + "+"
line = " +" + "".join(["-"] * length) + "+"
draws = ""
draws += border + "\n"
if header:
draws += h_format.format(header[0], header[1])
else:
draws += h_format.format("PaddleRec Benchmark Envs", "Value")
draws += line + "\n"
for k, v in sorted(envs.items()):
if isinstance(v, str) and len(v) >= max_v:
str_v = "... " + v[-41:]
else:
str_v = v
draws += l_format.format(k, " " * spacing, str(str_v))
draws += border
_str = "\n{}\n".format(draws)
return _str
def get_user_defined_strategy(config):
if not is_distributed_env():
logger.warning(
"Distributed env not found, falling back to local train mode. If you want to train with fleet, please use the [fleetrun] command."
)
return None
sync_mode = config.get("runner.sync_mode")
assert sync_mode in ["async", "sync", "geo", "heter", "gpubox"]
if sync_mode == "sync":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = False
elif sync_mode == "async":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
elif sync_mode == "geo":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"k_steps": config.get("runner.geo_step")}
elif sync_mode == "heter":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
elif sync_mode == "gpubox":
print("sync_mode = {}".format(sync_mode))
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"use_ps_gpu": 1}
strategy.trainer_desc_configs = {
"dump_fields_path": config.get("runner.dump_fields_path", ""),
"dump_fields": config.get("runner.dump_fields", []),
"dump_param": config.get("runner.dump_param", []),
"stat_var_names": config.get("stat_var_names", [])
}
print("strategy:", strategy.trainer_desc_configs)
if config.get("runner.fs_client.uri") is not None:
strategy.fs_client_param = {
"uri": config.get("runner.fs_client.uri", ""),
"user": config.get("runner.fs_client.user", ""),
"passwd": config.get("runner.fs_client.passwd", ""),
"hadoop_bin": config.get("runner.fs_client.hadoop_bin", "hadoop")
}
print("strategy:", strategy.fs_client_param)
strategy.adam_d2sum = config.get("hyper_parameters.adam_d2sum", True)
table_config = {}
for x in config:
if x.startswith("table_parameters"):
table_name = x.split('.')[1]
if table_name not in table_config:
table_config[table_name] = {}
table_config[table_name][x] = config[x]
print("table_config:", table_config)
strategy.sparse_table_configs = table_config
print("strategy table config:", strategy.sparse_table_configs)
a_sync_configs = strategy.a_sync_configs
a_sync_configs["launch_barrier"] = False
strategy.a_sync_configs = a_sync_configs
print("launch_barrier: ", strategy.a_sync_configs["launch_barrier"])
return strategy
def get_distributed_strategy(user_defined_strategy):
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
k_steps = user_defined_strategy.a_sync_configs["k_steps"]
strategy = None
if not user_defined_strategy.a_sync and k_steps == 0:
strategy = StrategyFactory.create_sync_strategy()
if user_defined_strategy.a_sync and k_steps == 0:
strategy = StrategyFactory.create_async_strategy()
if user_defined_strategy.a_sync and k_steps > 0:
strategy = StrategyFactory.create_geo_strategy(k_steps)
if not strategy:
raise ValueError("k_steps must be invalid value, please check")
return strategy
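# Illustrative mapping: a_sync=False with k_steps=0 -> sync strategy,
# a_sync=True with k_steps=0 -> async strategy, a_sync=True with k_steps>0
# -> geo strategy using that k_steps; any other combination raises.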
def get_model(config):
abs_dir = config['config_abs_dir']
sys.path.append(abs_dir)
static_model = StaticModel(config)
return static_model
def parse_args():
parser = argparse.ArgumentParser("PsTest train script")
parser.add_argument(
'-m', '--config_yaml', type=str, required=True, help='config file path')
parser.add_argument(
'-bf16',
'--pure_bf16',
type=ast.literal_eval,
default=False,
help="whether use bf16")
parser.add_argument(
'--run_minimize', type=int, default=0, help="test single pass")
parser.add_argument(
'--run_single_pass', type=int, default=0, help="test single pass")
parser.add_argument(
'--debug_new_minimize', type=int, default=0, help="test single pass")
parser.add_argument(
'--debug_new_pass', type=int, default=0, help="test single pass")
parser.add_argument(
'--applied_pass_name', type=str, default="", help="test single pass")
args = parser.parse_args()
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
yaml_helper = YamlHelper()
config = yaml_helper.load_yaml(args.config_yaml)
config["yaml_path"] = args.config_yaml
config["config_abs_dir"] = args.abs_dir
config["pure_bf16"] = args.pure_bf16
config['run_minimize'] = args.run_minimize
config['run_single_pass'] = args.run_single_pass
config['debug_new_minimize'] = args.debug_new_minimize
config['debug_new_pass'] = args.debug_new_pass
config['applied_pass_name'] = args.applied_pass_name
yaml_helper.print_yaml(config)
return config
def bf16_to_fp32(val):
return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
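# Example: the bf16 bit pattern 0x3F80 (1.0) becomes 0x3F800000 after the
# 16-bit left shift, which struct reinterprets as the float32 value 1.0.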
class DnnTrainer(object):
def __init__(self, config):
self.metrics = {}
self.config = config
self.input_data = None
self.reader = None
self.exe = None
self.train_result_dict = {}
self.train_result_dict["speed"] = []
self.model = None
self.pure_bf16 = self.config['pure_bf16']
self.role_maker = role_maker.PaddleCloudRoleMaker()
def init_fleet_with_gloo(self, use_gloo=False):
if use_gloo:
os.environ["PADDLE_WITH_GLOO"] = "1"
fleet.init(self.role_maker)
else:
fleet.init()
if fleet.is_server():
logger.info("server: {} started".format(fleet.server_index()))
else:
logger.info("worker: {} started".format(fleet.worker_index()))
def run_minimize(self):
logger.info("entering run_minimize")
self.init_fleet_with_gloo()
self.model = get_model(self.config)
logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
self.input_data = self.model.create_feeds()
self.metrics = self.model.net(self.input_data)
loss = self.model._cost
user_defined_strategy = get_user_defined_strategy(self.config)
learning_rate = self.config.get(
"hyper_parameters.optimizer.learning_rate")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
if self.config['debug_new_minimize'] == 1:
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer.minimize_impl(loss)
else:
fleet_obj = fleet.distributed_optimizer(
inner_optimizer, user_defined_strategy)  # Fleet object
fleet_obj.minimize(loss)
if fleet.is_server():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
self.config['debug_new_minimize']) + '_server_main.prototxt'
debug_program(_main_file, loss.block.program, 0)
elif fleet.is_worker():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
self.config['debug_new_minimize']) + '_worker_main.prototxt'
debug_program(_main_file, loss.block.program, 1)
'''
if fleet.is_server():
logger.info("Run Server Begin")
fleet.init_server()
fleet.run_server()
'''
def run_single_pass(self):
logger.info("entering run_single_pass")
self.init_fleet_with_gloo()
self.model = get_model(config)
input_data = self.model.create_feeds()
metrics = self.model.net(input_data)
loss = self.model._cost
user_defined_strategy = get_user_defined_strategy(config)
learning_rate = config.get("hyper_parameters.optimizer.learning_rate")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
startup_program = paddle.static.default_startup_program()
inner_optimizer.minimize(loss, startup_program)
if self.config['debug_new_pass'] == 1:
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer._init_ps_pass_context(loss, startup_program)
_main = ps_optimizer.attrs['cloned_main']
append_send_ops_pass = new_pass(config["applied_pass_name"],
ps_optimizer.attrs)
append_send_ops_pass.apply([_main], [None], ps_optimizer.pass_ctx)
else:
from paddle.fluid.incubate.fleet.parameter_server.ir import public as public
dist_strategy = get_distributed_strategy(user_defined_strategy)
compiled_config = public.CompileTimeStrategy(
loss.block.program, startup_program, dist_strategy,
self.role_maker)
_main = compiled_config.origin_main_program.clone()
_startup = compiled_config.origin_startup_program.clone()
from paddle.fluid.incubate.fleet.parameter_server.ir import trainer_pass as worker
_main = worker.append_send_ops_pass(_main, compiled_config)
if fleet.is_server():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
'debug_new_pass']) + '_server_main.prototxt'
debug_program(_main_file, _main, 0)
elif fleet.is_worker():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
'debug_new_pass']) + '_worker_main.prototxt'
debug_program(_main_file, _main, 1)
if __name__ == "__main__":
paddle.enable_static()
config = parse_args()
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
benchmark_main = DnnTrainer(config)
if config['run_single_pass'] == 1:
benchmark_main.run_single_pass()
elif config['run_minimize'] == 1:
benchmark_main.run_minimize()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestTheOnePs(unittest.TestCase):
def setUp(self):
print('setUp...')
def tearDown(self):
print('tearDown...')
def test_main(self):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math
import paddle.distributed.fleet as fleet
class DNNLayer(nn.Layer):
def __init__(self,
sparse_feature_number,
sparse_feature_dim,
dense_feature_dim,
num_field,
layer_sizes,
sync_mode=None):
super(DNNLayer, self).__init__()
self.sync_mode = sync_mode
self.sparse_feature_number = sparse_feature_number
self.sparse_feature_dim = sparse_feature_dim
self.dense_feature_dim = dense_feature_dim
self.num_field = num_field
self.layer_sizes = layer_sizes
self.embedding = paddle.nn.Embedding(
self.sparse_feature_number,
self.sparse_feature_dim,
sparse=True,
weight_attr=paddle.ParamAttr(
name="SparseFeatFactors",
initializer=paddle.nn.initializer.Uniform()))
sizes = [sparse_feature_dim * num_field + dense_feature_dim
] + self.layer_sizes + [2]
acts = ["relu" for _ in range(len(self.layer_sizes))] + [None]
self._mlp_layers = []
for i in range(len(layer_sizes) + 1):
linear = paddle.nn.Linear(
in_features=sizes[i],
out_features=sizes[i + 1],
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Normal(
std=1.0 / math.sqrt(sizes[i]))))
self.add_sublayer('linear_%d' % i, linear)
self._mlp_layers.append(linear)
if acts[i] == 'relu':
act = paddle.nn.ReLU()
self.add_sublayer('act_%d' % i, act)
self._mlp_layers.append(act)
def forward(self, sparse_inputs, dense_inputs):
sparse_embs = []
for s_input in sparse_inputs:
if self.sync_mode == "gpubox":
emb = paddle.fluid.contrib.sparse_embedding(
input=s_input,
size=[self.sparse_feature_number, self.sparse_feature_dim],
param_attr=paddle.ParamAttr(name="embedding"))
else:
emb = self.embedding(s_input)
emb = paddle.reshape(emb, shape=[-1, self.sparse_feature_dim])
sparse_embs.append(emb)
y_dnn = paddle.concat(x=sparse_embs + [dense_inputs], axis=1)
for n_layer in self._mlp_layers:
y_dnn = n_layer(y_dnn)
return y_dnn
class StaticModel():
def __init__(self, config):
self.cost = None
self.infer_target_var = None
self.config = config
self._init_hyper_parameters()
self.sync_mode = config.get("runner.sync_mode")
def _init_hyper_parameters(self):
self.is_distributed = False
self.distributed_embedding = False
if self.config.get("hyper_parameters.distributed_embedding", 0) == 1:
self.distributed_embedding = True
self.sparse_feature_number = self.config.get(
"hyper_parameters.sparse_feature_number")
self.sparse_feature_dim = self.config.get(
"hyper_parameters.sparse_feature_dim")
self.sparse_inputs_slots = self.config.get(
"hyper_parameters.sparse_inputs_slots")
self.dense_input_dim = self.config.get(
"hyper_parameters.dense_input_dim")
self.learning_rate = self.config.get(
"hyper_parameters.optimizer.learning_rate")
self.fc_sizes = self.config.get("hyper_parameters.fc_sizes")
def create_feeds(self, is_infer=False):
dense_input = paddle.static.data(
name="dense_input",
shape=[None, self.dense_input_dim],
dtype="float32")
sparse_input_ids = [
paddle.static.data(
name="C" + str(i), shape=[None, 1], dtype="int64")
for i in range(1, self.sparse_inputs_slots)
]
label = paddle.static.data(name="label", shape=[None, 1], dtype="int64")
feeds_list = [label] + sparse_input_ids + [dense_input]
return feeds_list
def net(self, input, is_infer=False):
self.label_input = input[0]
self.sparse_inputs = input[1:self.sparse_inputs_slots]
self.dense_input = input[-1]
sparse_number = self.sparse_inputs_slots - 1
dnn_model = DNNLayer(
self.sparse_feature_number,
self.sparse_feature_dim,
self.dense_input_dim,
sparse_number,
self.fc_sizes,
sync_mode=self.sync_mode)
raw_predict_2d = dnn_model.forward(self.sparse_inputs, self.dense_input)
predict_2d = paddle.nn.functional.softmax(raw_predict_2d)
self.predict = predict_2d
auc, batch_auc, [
self.batch_stat_pos, self.batch_stat_neg, self.stat_pos,
self.stat_neg
] = paddle.static.auc(input=self.predict,
label=self.label_input,
num_thresholds=2**12,
slide_steps=20)
self.inference_target_var = auc
if is_infer:
fetch_dict = {'auc': auc}
return fetch_dict
cost = paddle.nn.functional.cross_entropy(
input=raw_predict_2d, label=self.label_input)
avg_cost = paddle.mean(x=cost)
self._cost = avg_cost
fetch_dict = {'cost': avg_cost, 'auc': auc}
return fetch_dict
@@ -270,6 +270,8 @@ packages=['paddle',
'paddle.reader',
'paddle.distributed',
'paddle.distributed.metric',
'paddle.distributed.ps',
'paddle.distributed.ps.utils',
'paddle.incubate',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
......