Unverified commit 1882a74f, authored by wangzhen38, committed by GitHub

[RM FLUID] rm ps ir (#50691)

Parent e84fa263
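This commit removes the fluid parameter-server IR modules and repoints their users at the non-fluid packages. The hunks below are mechanical import rewrites; a minimal before/after sketch of the pattern (symbol names taken directly from the hunks):

# before: imports routed through paddle.fluid (removed by this commit)
from paddle.fluid import core
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    CompileTimeStrategy,
)

# after: the same symbols from the non-fluid locations
from paddle.framework import core
from paddle.incubate.fleet.parameter_server.ir.public import (
    CompileTimeStrategy,
)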
@@ -16,8 +16,7 @@
 import os
 import warnings
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid import core
+from paddle.framework import core
 from paddle.static import (
     CompiledProgram,
     Executor,
@@ -73,7 +72,7 @@ class ParameterServerRuntime(RuntimeBase):
         return strategy
     def build_compiled_startegy(self):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             CompileTimeStrategy,
         )
@@ -102,7 +101,7 @@ class ParameterServerRuntime(RuntimeBase):
         if main_program is None:
             main_program = self.origin_main_program
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_varname_parts,
         )
@@ -111,7 +110,7 @@ class ParameterServerRuntime(RuntimeBase):
             origin_varname, _, _ = _get_varname_parts(each_var.name)
-            new_var = fluid.io._clone_var_in_block_(load_block, each_var)
+            new_var = paddle.static.io._clone_var_in_block(load_block, each_var)
             var_path = os.path.join(dirname, origin_varname)
             if not os.path.exists(var_path):
                 raise ValueError(
@@ -138,7 +137,7 @@ class ParameterServerRuntime(RuntimeBase):
     def _load_distributed_params(self, dirname, varnames):
         from paddle.distributed.communicator import LargeScaleKV
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_varname_parts,
         )
@@ -154,7 +153,7 @@ class ParameterServerRuntime(RuntimeBase):
             if var.name in exclude_var_names:
                 return False
-            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+            from paddle.incubate.fleet.parameter_server.ir.public import (
                 _get_varname_parts,
             )
@@ -185,7 +184,7 @@ class ParameterServerRuntime(RuntimeBase):
             return kwargs
         def geo_strategy_envs():
-            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+            from paddle.incubate.fleet.parameter_server.ir.public import (
                 get_sparse_tablenames,
             )
@@ -239,14 +238,14 @@ class ParameterServerRuntime(RuntimeBase):
             kwargs["sparse_attrs"] = get_sparse_attrs()
             return kwargs
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
-            _get_lr_ops,
-            _has_global_step,
-        )
         from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
             GeoStrategy,
             SyncStrategy,
         )
+        from paddle.incubate.fleet.parameter_server.ir.public import (
+            _get_lr_ops,
+            _has_global_step,
+        )
         trainer_config = self.async_strategy.get_trainer_runtime_config()
         print(trainer_config)
@@ -475,7 +474,7 @@ class ParameterServerRuntime(RuntimeBase):
         return reshaped_names, origin_names
     def _get_optimizer_op(self, param_name):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_optimize_ops,
         )
...
@@ -36,7 +36,7 @@ PSERVER_SAVE_SUFFIX = ".shard"
 def parse_table_class(varname, o_main_program):
-    from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+    from paddle.incubate.fleet.parameter_server.ir.public import (
         is_distributed_sparse_op,
         is_sparse_op,
     )
@@ -247,7 +247,7 @@ class CommonAccessor:
         self.opt_init_map = opt_init_map
     def parse_entry(self, varname, o_main_program):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
            is_distributed_sparse_op,
            is_sparse_op,
        )
@@ -304,7 +304,7 @@ class CommonAccessor:
         compiled_strategy,
         adam_d2sum,
     ):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             _get_optimize_ops,
         )
@@ -716,7 +716,7 @@ class TheOnePSRuntime(RuntimeBase):
         return strategy
     def build_compiled_startegy(self):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             CompileTimeStrategy,
         )
@@ -1191,7 +1191,7 @@ class TheOnePSRuntime(RuntimeBase):
             proto_txt, string_hosts, role_id, trainers, self._server_sub_program
         )
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             get_sparse_tablenames,
         )
@@ -1252,7 +1252,7 @@ class TheOnePSRuntime(RuntimeBase):
             if var.name in exclude_var_names:
                 return False
-            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+            from paddle.incubate.fleet.parameter_server.ir.public import (
                 _get_varname_parts,
             )
@@ -1283,7 +1283,7 @@ class TheOnePSRuntime(RuntimeBase):
     def _save_sparse_params(
         self, executor, dirname, context, main_program, mode
     ):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             get_sparse_tablenames,
         )
@@ -1479,7 +1479,7 @@ class TheOnePSRuntime(RuntimeBase):
         self._ps_inference_save_persistables(*args, **kwargs)
     def _load_sparse_params(self, dirname, context, main_program, mode):
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+        from paddle.incubate.fleet.parameter_server.ir.public import (
             get_sparse_tablenames,
         )
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PSDispatcher:
"""
PSDispatcher is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` interface.
"""
def __init__(self, pserver_endpoints):
self._eps = pserver_endpoints
self._step = 0
@property
def eps(self):
return self._eps
def reset(self):
"""
reset the step counter, set it zero.
"""
self._step = 0
def dispatch(self, varlist):
"""
Args:
varlist(list): a list of Variables
Returns:
            a list of endpoints, one for each variable in `varlist`
"""
raise NotImplementedError("Interface has not been implemented.")
class HashName(PSDispatcher):
"""
    Hash variable names to several endpoints using Python's
    built-in "hash()" function.

    Args:
        pserver_endpoints (list): list of endpoints (ip:port).

    Examples:
        .. code-block:: python

            pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
            vars = ["var1", "var2", "var3", "var4", "var5"]
            hash_dispatcher = HashName(pserver_endpoints)
            hash_dispatcher.dispatch(vars)
"""
def __init__(self, pserver_endpoints):
super().__init__(pserver_endpoints)
def _hash_block(self, block_str, total):
return hash(block_str) % total
def dispatch(self, varlist):
"""
        Use the `HashName` strategy to dispatch each variable in `varlist` to a parameter server endpoint.
Args:
varlist (list): a list of Variables
"""
eplist = []
for var in varlist:
server_id = self._hash_block(var.name(), len(self._eps))
server_for_param = self._eps[server_id]
eplist.append(server_for_param)
return eplist
class RoundRobin(PSDispatcher):
"""
    Distribute variables to several endpoints using the
    round-robin method (https://en.wikipedia.org/wiki/Round-robin_scheduling).
Args:
pserver_endpoints (list): list of endpoint(ip:port).
Examples:
.. code-block:: python
pserver_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
vars = ["var1","var2","var3","var4","var5"]
rr = RoundRobin(pserver_endpoints)
rr.dispatch(vars)
"""
def __init__(self, pserver_endpoints):
super().__init__(pserver_endpoints)
def dispatch(self, varlist):
"""
        Use the `RoundRobin` strategy to dispatch each variable in `varlist` to a parameter server endpoint.
Args:
varlist (list): a list of Variables
"""
eplist = []
for var in varlist:
server_for_param = self._eps[self._step]
eplist.append(server_for_param)
self._step += 1
if self._step >= len(self._eps):
self._step = 0
return eplist
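# A minimal, self-contained usage sketch of the two dispatchers defined above.
# ``_FakeVar`` is a hypothetical stand-in for the variables handed to
# ``dispatch()`` (HashName calls ``var.name()``; RoundRobin only counts items).
class _FakeVar:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


_endpoints = ["127.0.0.1:6007", "127.0.0.1:6008"]
_vars = [_FakeVar(n) for n in ["var1", "var2", "var3", "var4", "var5"]]

# RoundRobin cycles through endpoints in order: 6007, 6008, 6007, 6008, 6007.
print(RoundRobin(_endpoints).dispatch(_vars))
# HashName picks the endpoint for each variable via hash(name) % len(endpoints),
# so the assignment is per-name (and depends on Python's hash seed).
print(HashName(_endpoints).dispatch(_vars))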
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections

from paddle.framework import core, Block
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    _get_lr_ops,
    _get_optimize_ops,
    _get_varname_parts,
    _orig_varname,
    get_sparse_tablename,
    get_sparse_tablenames,
    is_distributed_sparse_op,
)
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
def _is_optimizer_op(op):
if "Param" in op.input_names and "LearningRate" in op.input_names:
return True
return False
def _same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
def _get_optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
"""
Returns the shape for optimizer inputs that need to be reshaped when
Param and Grad is split to multiple servers.
"""
# HACK(typhoonzero) : Should use functions of corresponding optimizer in
# optimizer.py to get the shape, do not bind this in the transpiler.
if op_type == "adam":
if varkey in ["Moment1", "Moment2"]:
return param_shape
elif op_type == "adagrad":
if varkey == "Moment":
return param_shape
elif op_type == "adamax":
if varkey in ["Moment", "InfNorm"]:
return param_shape
elif op_type in ["momentum", "lars_momentum"]:
if varkey == "Velocity":
return param_shape
elif op_type == "rmsprop":
if varkey in ["Moment", "MeanSquare"]:
return param_shape
elif op_type == "decayed_adagrad":
if varkey == "Moment":
return param_shape
elif op_type == "ftrl":
if varkey in ["SquaredAccumulator", "LinearAccumulator"]:
return param_shape
elif op_type == "sgd":
pass
else:
raise ValueError(
"Not supported optimizer for distributed training: %s" % op_type
)
return orig_shape
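# A small illustrative check (hypothetical shapes, not from the transpiler) of
# the rule implemented above: per-parameter accumulators such as adam's moments
# follow the shape of the split parameter slice, while inputs the function does
# not recognise keep their original shape.
_param_shape = (512, 10)  # the parameter slice stored on this pserver
_orig_shape = (2048, 10)  # the accumulator's shape in the original program
assert _get_optimizer_input_shape("adam", "Moment1", _orig_shape, _param_shape) == _param_shape
assert _get_optimizer_input_shape("adam", "Beta1Pow", _orig_shape, _param_shape) == _orig_shape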
def _append_pserver_non_opt_ops(optimize_block, opt_op, origin_program, config):
def _get_pserver_grad_param_var(var, var_dict):
"""
Return pserver side grad/param variable, return None
if the variable is not grad/param, e.g.
a@GRAD -> a@GRAD.block0
a@GRAD -> a@GRAD (a is not split)
fc_0.w_0 -> fc_0.w_0.block_0
fc_0.w_0 -> fc_0.w_0 (weight is not split)
_generated_var_123 -> None
"""
grad_block = None
for _, g in var_dict.items():
if _orig_varname(g.name) == _orig_varname(var.name):
# skip per trainer vars
if g.name.find(".trainer_") == -1:
# only param or grads have split blocks
ovar_name = _orig_varname(g.name)
if ovar_name in config.param_grad_ep_mapping:
grad_block = g
break
elif ovar_name in config.grad_param_mapping:
grad_block = g
break
return grad_block
program = optimize_block.program
# Append the ops for parameters that do not need to be optimized / updated
inputs = _get_input_map_from_op(origin_program.global_block().vars, opt_op)
for key, varlist in inputs.items():
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
# for ops like clipping and weight decay, get the split var(xxx.block0)
# for inputs / outputs
grad_block = _get_pserver_grad_param_var(
var, program.global_block().vars
)
if grad_block:
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
inputs[key] = varlist
outputs = _get_output_map_from_op(
origin_program.global_block().vars, opt_op
)
for key, varlist in outputs.items():
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
grad_block = _get_pserver_grad_param_var(
var, program.global_block().vars
)
if grad_block:
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
outputs[key] = varlist
return optimize_block.append_op(
type=opt_op.type,
inputs=inputs,
outputs=outputs,
attrs=opt_op.all_attrs(),
)
def _append_pserver_ops(
optimize_block,
opt_op,
endpoint,
grad_to_block_id,
origin_program,
merged_var,
sparse_grad_to_param,
config,
):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = collections.OrderedDict()
def _get_param_block(opt_op):
# param is already created on global program
unmerged_vars = []
merged_vars = []
merged_ordervars = []
param_vars = [
p for p in config.param_grad_ep_mapping[endpoint]["params"]
]
for var in param_vars:
name = var.name
orig_varname = _orig_varname(name)
for pairs in config.merged_variables_pairs:
merged_p = pairs[0]
if merged_p.merged_var.name == orig_varname:
if (
merged_p.merged_var.name
== merged_p.ordered_vars[0].name
):
unmerged_vars.append(merged_p.ordered_vars[0])
else:
merged_vars.append(merged_p.merged_var)
merged_ordervars.append(merged_p.ordered_vars[0])
break
param_name = opt_op.input("Param")[0]
for i in range(len(unmerged_vars)):
if _same_or_split_var(param_name, unmerged_vars[i].name):
for var in param_vars:
if _same_or_split_var(var.name, unmerged_vars[i].name):
return var
for i in range(len(merged_ordervars)):
if _same_or_split_var(param_name, merged_ordervars[i].name):
for var in param_vars:
if _same_or_split_var(var.name, merged_vars[i].name):
return var
return None
for key in opt_op.input_names:
if key == "Grad":
            # NOTE: this handles L2Decay on sparse gradients, which creates a
            # new tensor for the decayed gradient instead of modifying the
            # original one in place.
origin_grad_name = opt_op.input(key)[0]
if (
core.kNewGradSuffix() in origin_grad_name
and pserver_block.has_var(origin_grad_name)
):
new_grad = pserver_block.var(origin_grad_name)
new_inputs[key] = new_grad
else:
new_inputs[key] = merged_var
elif key == "Param":
param_block = _get_param_block(opt_op)
if not param_block:
return
tmpvar = pserver_block.create_var(
name=param_block.name,
persistable=True,
dtype=param_block.dtype,
shape=param_block.shape,
)
new_inputs[key] = tmpvar
elif key == "LearningRate":
            # The learning rate variable has already been created by a
            # non-optimize op, so don't create it again.
lr_varname = opt_op.input(key)[0]
if lr_varname in pserver_block.vars:
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
else:
origin_var = origin_program.global_block().vars[lr_varname]
tmpvar = pserver_block.create_var(
name=origin_var.name,
persistable=origin_var.persistable,
dtype=origin_var.dtype,
shape=origin_var.shape,
)
new_inputs[key] = tmpvar
for key in opt_op.input_names:
new_shape = None
if key in [
"Param",
"Grad",
"LearningRate",
"MasterParam",
"Beta1Tensor",
"Beta2Tensor",
]:
continue
var = origin_program.global_block().vars[opt_op.input(key)[0]]
param_var = new_inputs["Param"]
# update accumulator variable shape
new_shape = _get_optimizer_input_shape(
opt_op.type, key, var.shape, param_var.shape
)
tmpvar = pserver_block.create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=new_shape,
)
new_inputs[key] = tmpvar
# change output's ParamOut variable
outputs = _get_output_map_from_op(
origin_program.global_block().vars, opt_op
)
outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=outputs,
attrs=opt_op.all_attrs(),
)
# record sparse grad to param name
if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_grad_to_param.append(
str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"].name)
)
def _get_input_map_from_op(varmap, op):
"""Returns a dict from op input name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.input_names:
vars = []
for varname in op.input(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def _get_output_map_from_op(varmap, op):
"""Returns a dict from op output name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.output_names:
vars = []
for varname in op.output(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
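# A tiny illustration of the mapping rule used by the two helpers above: an
# input name bound to a single variable maps to that variable directly, while a
# name bound to several maps to a list. ``_FakeOp`` is a hypothetical stub, not
# a real Paddle operator.
class _FakeOp:
    input_names = ["X", "Y"]

    def input(self, key):
        return {"X": ["a"], "Y": ["b", "c"]}[key]


_varmap = {"a": "var_a", "b": "var_b", "c": "var_c"}
_iomap = _get_input_map_from_op(_varmap, _FakeOp())
assert _iomap["X"] == "var_a"
assert _iomap["Y"] == ["var_b", "var_c"]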
def get_op_by_type(block, op_type):
for op in block.ops:
if op.type == op_type:
return op
raise ValueError("add_listen_and_serv_pass must at first")
def add_listen_and_serv_pass(program, config):
attrs = {
"grad_to_block_id": None,
"sparse_grad_to_param": None,
"lr_decay_block_id": None,
"dense_optimize_blocks": None,
"sparse_optimize_blocks": None,
# runtime attribute
"endpoint": config.get_ps_endpoint(),
"pserver_id": config.get_role_id(),
"Fanin": config.get_trainers(),
"distributed_mode": config.get_distributed_mode(),
"rpc_get_thread_num": -1,
"rpc_send_thread_num": -1,
"rpc_prefetch_thread_num": -1,
}
# step5 append the listen_and_serv op
program.global_block().append_op(
type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs
)
return program
def add_rpc_global_flags_pass(program, config):
server_runtime = config.get_server_runtime_config()
send_threads = server_runtime._rpc_send_thread_num
get_threads = server_runtime._rpc_get_thread_num
pull_threads = server_runtime._rpc_prefetch_thread_num
op = get_op_by_type(program.global_block(), "listen_and_serv")
if get_threads < 1 or send_threads < 1 or pull_threads < 1:
raise ValueError(
"error arguments in get_threads/send_threads/pull_threads"
)
op._set_attr("rpc_get_thread_num", get_threads)
op._set_attr("rpc_send_thread_num", send_threads)
op._set_attr("rpc_prefetch_thread_num", pull_threads)
return program
def _clone_var(block, var, persistable=True):
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=persistable,
)
def add_optimizer_pass(program, config):
def _append_pserver_grad_merge_ops(
optimize_block, grad_varname_for_block, endpoint, grad_to_block_id
):
trainers = config.get_trainers()
program = optimize_block.program
pserver_block = program.global_block()
grad_block = None
for g in config.param_grad_ep_mapping[endpoint]["grads"]:
if _orig_varname(g.name) == _orig_varname(grad_varname_for_block):
grad_block = g
break
if not grad_block:
# do not append this op if current endpoint
# is not dealing with this grad block
return None
orig_varname, block_name, trainer_name = _get_varname_parts(
grad_block.name
)
if block_name:
merged_var_name = '.'.join([orig_varname, block_name])
else:
merged_var_name = orig_varname
merged_var = pserver_block.create_var(
name=grad_block.name,
persistable=True,
type=grad_block.type,
dtype=grad_block.dtype,
shape=grad_block.shape,
)
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if config.is_sync_mode() and trainers > 1:
vars2merge = []
for i in range(trainers):
per_trainer_name = "%s.trainer_%d" % (merged_var_name, i)
per_trainer_var = pserver_block.create_var(
name=per_trainer_name,
persistable=False,
type=grad_block.type,
dtype=grad_block.dtype,
shape=grad_block.shape,
)
vars2merge.append(per_trainer_var)
optimize_block.append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var},
attrs={"use_mkldnn": False},
)
optimize_block.append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(trainers)},
)
return merged_var
origin_program = config.get_origin_main_program()
origin_program = origin_program.clone()
ps_endpoint = config.get_ps_endpoint()
opt_op_on_pserver = []
    # Iterate through the ops; if an op belongs to the set of optimize ops
    # located on the current pserver, append it to the sub-program.
global_ops = []
# sparse grad name to param name
sparse_grad_to_param = []
def _is_opt_op_on_pserver(endpoint, op):
param_names = [
p.name for p in config.param_grad_ep_mapping[endpoint]["params"]
]
unmerged_varnames = []
merged_varnames = []
merged_ordernames = []
for name in param_names:
orig_varname = _orig_varname(name)
for pairs in config.merged_variables_pairs:
merged_p = pairs[0]
if merged_p.merged_var.name == orig_varname:
if (
merged_p.merged_var.name
== merged_p.ordered_vars[0].name
):
unmerged_varnames.append(merged_p.ordered_vars[0].name)
else:
merged_varnames.append(merged_p.merged_var.name)
merged_ordernames.append(merged_p.ordered_vars[0].name)
break
param = op.input("Param")[0]
if param in unmerged_varnames:
return True
for i in range(len(merged_ordernames)):
if param == merged_ordernames[i]:
merged_p = merged_varnames[i]
merged_g = "{}@GRAD".format(merged_varnames[i])
op._set_attr(OP_ROLE_VAR_ATTR_NAME, [merged_p, merged_g])
return True
return False
def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops):
if _is_optimizer_op(op):
_append_pserver_ops(
block,
op,
ps_endpoint,
grad_to_block_id,
origin_program,
merged_var,
sparse_grad_to_param,
config,
)
elif op not in lr_ops:
_append_pserver_non_opt_ops(block, op, origin_program, config)
optimize_ops = _get_optimize_ops(origin_program)
for _, op in enumerate(optimize_ops):
if _is_optimizer_op(op) and _is_opt_op_on_pserver(ps_endpoint, op):
opt_op_on_pserver.append(op)
# append lr decay ops to the child block if exists
lr_ops = _get_lr_ops(origin_program)
    has_lr_decay = len(lr_ops) > 0
    lr_decay_block_id = -1
    optimize_blocks = []
    if has_lr_decay:
counter_increment_idx = -1
for idx, op in enumerate(lr_ops):
if op.type != 'increment':
continue
counter = op.input("X")[0]
if counter == LEARNING_RATE_DECAY_COUNTER:
counter_increment_idx = idx
break
if counter_increment_idx != -1:
lr_ops.pop(counter_increment_idx)
lr_decay_block = program._create_block(program.num_blocks - 1)
optimize_blocks.append(lr_decay_block)
for op in lr_ops:
cloned_op = _append_pserver_non_opt_ops(
lr_decay_block, op, origin_program, config
)
# append sub blocks to pserver_program in lr_decay_op
# todo(tangwei12): __clone_lr_op_sub_block__
lr_decay_block_id = lr_decay_block.idx
# append op to the current block
grad_to_block_id = []
pre_block_idx = program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = program._create_block(pre_block_idx)
optimize_blocks.append(per_opt_block)
optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
        # append grad merging ops before clip and weight decay,
        # e.g. merge grad -> L2Decay op -> clip op -> optimize
merged_var = None
for _, op in enumerate(optimize_ops):
# find the origin grad var before clipping / L2Decay,
# merged_var should be the input var name of L2Decay
grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name:
merged_var = _append_pserver_grad_merge_ops(
per_opt_block,
grad_varname_for_block,
ps_endpoint,
grad_to_block_id,
)
if merged_var:
break # append optimize op once then append other ops.
if merged_var:
for _, op in enumerate(optimize_ops):
# optimizer is connected to itself
if (
op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
== optimize_target_param_name
and op not in global_ops
):
__append_optimize_op__(
op, per_opt_block, grad_to_block_id, merged_var, lr_ops
)
# dedup grad to ids list
grad_to_block_id = list(set(grad_to_block_id))
# append global ops
if global_ops:
opt_state_block = program._create_block(program.num_blocks - 1)
optimize_blocks.append(opt_state_block)
for glb_op in global_ops:
__append_optimize_op__(
glb_op, opt_state_block, grad_to_block_id, None, lr_ops
)
if len(optimize_blocks) == 0:
pre_block_idx = program.num_blocks - 1
empty_block = program._create_block(pre_block_idx)
optimize_blocks.append(empty_block)
op = get_op_by_type(program.global_block(), "listen_and_serv")
op._set_attr("optimize_blocks", optimize_blocks)
op._set_attr("grad_to_block_id", grad_to_block_id)
op._set_attr("sparse_grad_to_param", sparse_grad_to_param)
op._set_attr("lr_decay_block_id", lr_decay_block_id)
return program
def large_scale_sparse_pass(program, main_program, config, is_startup=False):
opt_value_map = {}
opt_value_map["sgd"] = ["Param"]
opt_value_map["adam"] = ["Param", "Moment1", "Moment2"]
opt_value_map["adagrad"] = ["Param", "Moment"]
opt_value_map["adamax"] = ["Param", "Moment", "InfNorm"]
opt_value_map["momentum"] = ["Param", "Velocity"]
opt_value_map["lars_momentum"] = ["Param", "Velocity"]
opt_value_map["rmsprop"] = ["Param", "Moment", "MeanSquare"]
opt_value_map["decayed_adagrad"] = ["Param", "Moment"]
opt_value_map["ftrl"] = ["Param", "SquaredAccumulator", "LinearAccumulator"]
geo_value_map = {}
geo_value_map["sum"] = "Param"
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
opt_init_map["fill_constant"] = ["value"]
opt_init_map["uniform_random"] = ["seed", "min", "max"]
opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"]
def get_entry_attr(param_name):
origin_name = _orig_varname(param_name)
o_main_program = config.get_origin_main_program()
for op in o_main_program.global_block().ops:
if (
is_distributed_sparse_op(op)
and get_sparse_tablename(op) == origin_name
):
entry = op.attr("entry")
return entry
def get_initializer_attrs(acture_value_names):
l_sep = ","
l_in = "&"
init_attrs = []
o_startup_program = config.get_origin_startup_program()
for value_name in acture_value_names:
origin_var_name = _orig_varname(value_name)
for op in o_startup_program.global_block().ops:
if (
op.type in opt_init_map.keys()
and origin_var_name == op.output("Out")[0]
):
init_attr = [op.type]
for attr in opt_init_map[op.type]:
init_attr.append(str(op.attr(attr)))
init_attrs.append(l_in.join(init_attr))
break
return l_sep.join(init_attrs)
def get_optimizer_values(block):
value_names = []
acture_names = []
value_dims = []
grad = None
opt_idx = -1
fuse = False
for op in block.ops:
opt_idx += 1
if op.type not in opt_value_map.keys():
continue
if op.type in ["sgd", "adam"]:
fuse = True
grad = main_program.global_block().vars[op.input("Grad")[0]]
for value in opt_value_map[op.type]:
var = main_program.global_block().vars[op.input(value)[0]]
if len(var.shape) != 2:
raise ValueError("sparse param's dimension must be 2")
value_names.append(value)
value_dims.append(var.shape[1])
acture_names.append(var.name)
if value_names:
break
return grad, opt_idx, value_names, value_dims, acture_names, fuse
def add_fuse_large_scale_op(
block,
global_block,
table_name,
value_names,
acture_names,
grad,
is_entry,
opt_idx,
):
op = block.ops[opt_idx]
if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]]
block._insert_op(
opt_idx,
type="lookup_sparse_table_fuse_sgd",
inputs={"Grad": grad, "LearningRate": lr},
attrs={
"is_entry": is_entry,
"tablename": table_name,
"value_names": value_names,
},
)
elif op.type == "adam":
grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]]
beta1_pow = main_program.global_block().vars[
op.input("Beta1Pow")[0]
]
beta2_pow = main_program.global_block().vars[
op.input("Beta2Pow")[0]
]
beta1_pow_o = main_program.global_block().vars[
op.output("Beta1PowOut")[0]
]
beta2_pow_o = main_program.global_block().vars[
op.output("Beta2PowOut")[0]
]
beta1 = op.attr('beta1')
beta2 = op.attr('beta2')
epsilon = op.attr('epsilon')
block._insert_op(
opt_idx,
type="lookup_sparse_table_fuse_adam",
inputs={
"Grad": grad,
"LearningRate": lr,
"Beta1Pow": beta1_pow,
"Beta2Pow": beta2_pow,
},
outputs={
"Beta1PowOut": beta1_pow_o,
"Beta2PowOut": beta2_pow_o,
},
attrs={
"beta1": beta1,
"beta2": beta2,
"epsilon": epsilon,
"is_entry": is_entry,
"tablename": table_name,
"value_names": value_names,
},
)
else:
raise ValueError("only support sgd/adam optimizer now")
def add_large_scale_op(
block,
global_block,
table_name,
value_names,
acture_names,
grad,
is_entry,
opt_idx,
):
ids = global_block.create_var(
name="kSparseIDs@{}".format(table_name),
persistable=False,
dtype="int64",
shape=[1, 1],
lod_level=0,
)
# insert grad split to ids and tensor op
block._insert_op(
opt_idx,
type="lookup_sparse_table_grad_split",
inputs={"Grad": grad},
outputs={"Row": ids, "Value": grad},
attrs={"tablename": table_name, "is_entry": is_entry},
)
# insert read at first
vars = [global_block.vars[acture_name] for acture_name in acture_names]
block._insert_op(
opt_idx + 1,
type="lookup_sparse_table_read",
inputs={"Ids": ids},
outputs={"Out": vars},
attrs={"tablename": table_name, "value_names": value_names},
)
# append write at last
inputs = {"Ids": ids, "In": vars}
block.append_op(
type="lookup_sparse_table_write",
inputs=inputs,
outputs={},
attrs={"tablename": table_name, "value_names": value_names},
)
op = get_op_by_type(main_program.global_block(), "listen_and_serv")
param_blockid_map = {}
grad_blockid_map = {}
grad_to_params = op.attr('sparse_grad_to_param')
grad_to_block_ids = op.attr('grad_to_block_id')
origin_program = config.get_origin_main_program()
sparse_varnames = get_sparse_tablenames(origin_program, False)
for grad_to_block_id in grad_to_block_ids:
grad, blockid = grad_to_block_id.split(":")
grad_blockid_map[grad] = int(blockid)
for grad_to_param in grad_to_params:
grad, param = grad_to_param.split(":")
if _orig_varname(param) in sparse_varnames:
continue
param_blockid_map[param] = grad_blockid_map[grad]
if not is_startup:
for param, blockid in param_blockid_map.items():
opt_block = program.block(blockid)
(
grad,
opt_idx,
value_names,
value_dims,
acture_names,
fuse,
) = get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param)
is_entry = False if entry_attr == "none" else True
if fuse:
add_fuse_large_scale_op(
opt_block,
program.global_block(),
param,
value_names,
acture_names,
grad,
is_entry,
opt_idx,
)
else:
add_large_scale_op(
opt_block,
program.global_block(),
param,
value_names,
acture_names,
grad,
is_entry,
opt_idx,
)
else:
large_scale_kv_metas = []
for param, blockid in param_blockid_map.items():
opt_block = main_program.block(blockid)
(
grad,
opt_idx,
value_names,
value_dims,
acture_names,
fuse,
) = get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param)
if fuse:
                # remove the origin optimizer op
opt_block._remove_op(opt_idx)
# training/infer
mode = "0"
names_str = ",".join(value_names)
dims_str = ",".join([str(dim) for dim in value_dims])
ids_name = "kSparseIDs@{}".format(param)
cached_str = ",".join(acture_names + [ids_name])
init_attr_str = get_initializer_attrs(acture_names)
meta_str = ":".join(
[
param,
names_str,
dims_str,
mode,
grad.name,
cached_str,
init_attr_str,
entry_attr,
]
)
print("large_scale_metas: {}".format(meta_str))
large_scale_kv_metas.append(meta_str)
program.global_block().append_op(
type="lookup_sparse_table_init",
inputs=None,
outputs=None,
attrs={"large_scale_metas": large_scale_kv_metas},
)
    # TODO: delete unused vars.
return program
def get_distributed_from_listen_and_serv(program, origin_program):
op = get_op_by_type(program.global_block(), "listen_and_serv")
sparse_varnames = get_sparse_tablenames(origin_program, True)
sparse_params = []
grad_to_params = op.attr('sparse_grad_to_param')
for grad_to_param in grad_to_params:
_, param = grad_to_param.split(":")
if _orig_varname(param) in sparse_varnames:
sparse_params.append(param)
return sparse_params
def delete_unused_in_main_pass(program, config):
origin_program = config.get_origin_main_program()
sparse_params = get_distributed_from_listen_and_serv(
program, origin_program
)
for var in sparse_params:
if program.global_block().has_var(var):
program.global_block()._remove_var(var)
return program
def delete_unused_in_startup_pass(program, main_program, config):
origin_program = config.get_origin_main_program()
sparse_params = get_distributed_from_listen_and_serv(
main_program, origin_program
)
remove_ops = []
for op in program.global_block().ops:
if op.type in ["recv", "fetch_barrier", "concat"]:
continue
for key in op.output_names:
if op.output(key)[0] in sparse_params:
remove_ops.append(op)
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in remove_ops]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
for var in sparse_params:
if program.global_block().has_var(var):
program.global_block()._remove_var(var)
return program
def build_pserver_startup_program_pass(program, p_main_program, config):
ps_endpoint = config.get_ps_endpoint()
o_startup_program = config.get_origin_startup_program()
program.random_seed = o_startup_program.random_seed
params = config.param_grad_ep_mapping[ps_endpoint]["params"]
merged_ordervars = []
for var in params:
name = var.name
orig_varname = _orig_varname(name)
for pairs in config.merged_variables_pairs:
merged_p = pairs[0]
if merged_p.merged_var.name == orig_varname:
if merged_p.merged_var.name != merged_p.ordered_vars[0].name:
merged_ordervars.append(merged_p.ordered_vars[0])
break
def _get_splited_name_and_shape(varname):
for splited_param in params:
pname = splited_param.name
if _same_or_split_var(pname, varname) and varname != pname:
return pname, splited_param.shape
for idx, ordered in enumerate(merged_ordervars):
if _same_or_split_var(varname, ordered.name):
return pname, splited_param.shape
return "", []
    # 1. clone vars from the pserver main program into the startup program
pserver_vars = p_main_program.global_block().vars
created_var_map = collections.OrderedDict()
for _, var in pserver_vars.items():
tmpvar = program.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar
# 2. rename op outputs
for op in o_startup_program.global_block().ops:
new_outputs = collections.OrderedDict()
# do not append startup op if var is not on this pserver
op_on_pserver = False
# TODO(gongwb) : remove this line.
if op.type not in ["recv", "fetch_barrier", "concat"]:
for key in op.output_names:
newname, _ = _get_splited_name_and_shape(op.output(key)[0])
if newname:
op_on_pserver = True
new_outputs[key] = created_var_map[newname]
elif op.output(key)[0] in pserver_vars:
op_on_pserver = True
new_outputs[key] = pserver_vars[op.output(key)[0]]
if op_on_pserver:
# most startup program ops have no inputs
new_inputs = _get_input_map_from_op(pserver_vars, op)
if op.type in [
"gaussian_random",
"fill_constant",
"uniform_random",
"truncated_gaussian_random",
]:
op._set_attr("shape", list(new_outputs["Out"].shape))
program.global_block().append_op(
type=op.type,
inputs=new_inputs,
outputs=new_outputs,
attrs=op.all_attrs(),
)
return program
def add_geo_optimizer_pass(program, config):
endpoint = config.get_ps_endpoint()
params = [p for p in config.param_grad_ep_mapping[endpoint]["params"]]
sparse_tablenames = get_sparse_tablenames(
config.get_origin_main_program(), False
)
for param in params:
_clone_var(program.global_block(), param)
optimize_block = []
sparse_grad_to_param = []
param_to_block_id = []
pre_block_idx = program.num_blocks - 1
for param in params:
per_opt_block = program._create_block(pre_block_idx)
optimize_block.append(per_opt_block)
var_name = param.name
pserver_block = per_opt_block.program.global_block()
param = pserver_block.vars[var_name]
delta_var_name = "%s.delta" % (param.name)
origin_varname = _orig_varname(param.name)
if origin_varname in sparse_tablenames:
sparse_grad_to_param.append(":".join([delta_var_name, param.name]))
delta_var = pserver_block.create_var(
name=delta_var_name,
persistable=False,
type=param.type,
dtype=param.dtype,
shape=param.shape,
)
per_opt_block.append_op(
type="sum", inputs={"X": [param, delta_var]}, outputs={"Out": param}
)
param_to_block_id.append(delta_var_name + ":" + str(per_opt_block.idx))
op = get_op_by_type(program.global_block(), "listen_and_serv")
op._set_attr("optimize_blocks", optimize_block)
op._set_attr("grad_to_block_id", param_to_block_id)
op._set_attr("sparse_grad_to_param", sparse_grad_to_param)
return program
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import paddle
import collections
import math
import os
import warnings
import logging
from paddle.framework import core
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import (
RoundRobin,
PSDispatcher,
)
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
def _get_lr_ops(program):
lr_ops = []
for index, op in enumerate(program.global_block().ops):
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or role_id == int(
LR_SCHED_OP_ROLE_ATTR_VALUE
) | int(OPT_OP_ROLE_ATTR_VALUE):
lr_ops.append(op)
return lr_ops
def _has_global_step(lr_ops):
if len(lr_ops) > 0:
for idx, op in enumerate(lr_ops):
if op.type != 'increment':
continue
counter = op.input("X")[0]
if counter == LEARNING_RATE_DECAY_COUNTER:
return True
return False
def is_sparse_op(op):
if (
op.type in SPARSE_OP_LIST
and op.attr('is_sparse') is True
and op.attr('is_distributed') is False
):
return True
if (
op.type == "distributed_lookup_table"
and op.attr('is_distributed') is False
):
return True
return False
def is_distributed_sparse_op(op):
if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True:
return True
if (
op.type == "distributed_lookup_table"
and op.attr('is_distributed') is True
):
return True
return False
def get_sparse_tablename(op):
return op.input("W")[0]
def get_sparse_tablenames(program, is_distributed):
tablenames = set()
if is_distributed:
for op in program.global_block().ops:
if is_distributed_sparse_op(op):
tablenames.add(get_sparse_tablename(op))
else:
for op in program.global_block().ops:
if is_sparse_op(op):
tablenames.add(get_sparse_tablename(op))
return list(tablenames)
class MergedVariable:
def __init__(self, merged, ordered, offsets):
self.merged_var = merged
self.ordered_vars = ordered
self.offsets = offsets
def Singleton(cls):
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
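# A minimal sketch of what the Singleton decorator above guarantees: every
# construction after the first returns the cached instance. ``_Demo`` is a
# hypothetical class used only for illustration.
@Singleton
class _Demo:
    def __init__(self, value):
        self.value = value


_first = _Demo(1)
_second = _Demo(2)  # constructor arguments are ignored after the first call
assert _first is _second and _first.value == 1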
@Singleton
class CompileTimeStrategy:
def __init__(self, main_program, startup_program, strategy, role_maker):
self.min_block_size = 81920
self.origin_main_program = main_program
self.origin_startup_program = startup_program
self.origin_ps_main_program = main_program
self.origin_ps_startup_program = startup_program
self.strategy = strategy
self.role_maker = role_maker
self.use_ps_gpu = False
try:
self.is_heter_ps_mode = role_maker._is_heter_parameter_server_mode
except:
warnings.warn(
"Using paddle.distributed.fleet instead of paddle.fluid.incubate.fleet"
)
self.is_heter_ps_mode = False
self.origin_sparse_pairs = []
self.origin_dense_pairs = []
self.merged_variables_pairs = []
self.merged_dense_pairs = []
self.merged_sparse_pairs = []
self.merged_variable_map = {}
self.param_name_to_grad_name = {}
self.grad_name_to_param_name = {}
self.param_grad_ep_mapping = collections.OrderedDict()
self.grad_param_mapping = collections.OrderedDict()
self._build_var_distributed()
self.tensor_table_dict = {}
# for heter-ps save variables
self.origin_merged_variables_pairs = list(self.merged_variables_pairs)
self.origin_merged_dense_pairs = list(self.merged_dense_pairs)
self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs)
def get_distributed_mode(self):
trainer = self.strategy.get_trainer_runtime_config()
return trainer.mode
def is_sync_mode(self):
trainer = self.strategy.get_trainer_runtime_config()
return trainer.mode == DistributedMode.SYNC
def is_geo_mode(self):
trainer = self.strategy.get_trainer_runtime_config()
return trainer.mode == DistributedMode.GEO
def is_async_mode(self):
trainer = self.strategy.get_trainer_runtime_config()
return trainer.mode == DistributedMode.ASYNC
def get_role_id(self):
try:
return self.role_maker._role_id()
except Exception:
return self.role_maker.role_id()
def get_trainers(self):
try:
return self.role_maker._worker_num()
except Exception:
return self.role_maker.worker_num()
def get_ps_endpoint(self):
try:
return self.role_maker._get_pserver_endpoints()[self.get_role_id()]
except Exception:
return self.role_maker.get_pserver_endpoints()[self.get_role_id()]
def get_ps_endpoints(self):
try:
return self.role_maker._get_pserver_endpoints()
except Exception:
return self.role_maker.get_pserver_endpoints()
def get_heter_worker_endpoints(self):
try:
return self.role_maker._get_heter_worker_endpoints()
except Exception:
return self.role_maker.get_heter_worker_endpoints()
def get_next_stage_trainers(self):
try:
return self.role_maker._get_next_trainers()
except Exception:
return self.role_maker.get_next_trainers()
def get_heter_worker_endpoint(self):
try:
return self.role_maker._get_heter_worker_endpoint()
except Exception:
return self.role_maker.get_heter_worker_endpoint()
def get_trainer_endpoints(self):
try:
return self.role_maker._get_trainer_endpoints()
except Exception:
return self.role_maker.get_trainer_endpoints()
def get_trainer_endpoint(self):
try:
return self.role_maker._get_trainer_endpoint()
except Exception:
return self.role_maker.get_trainer_endpoint()
def get_previous_stage_trainers(self):
try:
return self.role_maker._get_previous_trainers()
except Exception:
return self.role_maker.get_previous_trainers()
def get_origin_programs(self):
return self.origin_main_program, self.origin_startup_program
def get_origin_main_program(self):
return self.origin_main_program
def get_origin_startup_program(self):
return self.origin_startup_program
def set_origin_ps_main_program(self, program):
self.origin_ps_main_program = program
def set_origin_ps_startup_program(self, program):
self.origin_ps_startup_program = program
def get_origin_ps_main_program(self):
return self.origin_ps_main_program
def get_origin_ps_startup_program(self):
return self.origin_ps_startup_program
def add_tensor_table(
self,
feed_var_name,
fetch_var_name="",
startup_program=None,
main_program=None,
tensor_table_class="",
):
self.tensor_table_dict[feed_var_name] = {}
self.tensor_table_dict[feed_var_name]["feed_var_name"] = feed_var_name
self.tensor_table_dict[feed_var_name]["fetch_var_name"] = fetch_var_name
self.tensor_table_dict[feed_var_name][
"startup_program"
] = startup_program
self.tensor_table_dict[feed_var_name]["main_program"] = main_program
self.tensor_table_dict[feed_var_name][
"tensor_table_class"
] = tensor_table_class
def get_tensor_table_dict(self):
return self.tensor_table_dict
def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
if not endpoint:
endpoint = self.get_ps_endpoint()
varnames = get_sparse_tablenames(
self.get_origin_main_program(), is_distributed
)
ps_sparse_varnames = []
for varname in varnames:
tables = self.get_var_distributed(varname, True)
for i in range(len(tables)):
table, ep, _ = tables[i]
if ep == endpoint:
ps_sparse_varnames.append(table)
return ps_sparse_varnames
def get_optimize_varname_on_ps(self, param_name):
origin_param_name, _, _ = _get_varname_parts(param_name)
optimize_var_names = []
for op in self.get_origin_main_program().global_block().ops:
# check all optimizer op
if int(op.all_attrs()["op_role"]) == 2:
# check param name
if op.input("Param")[0] != origin_param_name:
continue
# check all input
for key in op.input_names:
if key in [
"Param",
"Grad",
"LearningRate",
"Beta1Tensor",
"Beta2Tensor",
]:
continue
                    # check variable-shape-related params, e.g. Moment1
optimize_var_names += (
self._get_optimizer_param_related_var_name(
op, op.type, key
)
)
return optimize_var_names
def _get_optimizer_param_related_var_name(self, op, op_type, varkey):
"""
        Returns the names of the optimizer inputs that need to be loaded.
"""
related_var_names = []
if op_type == "adam":
if varkey in ["Moment1", "Moment2"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "adagrad":
if varkey == "Moment":
related_var_names.append(op.input(varkey)[0])
elif op_type in ["momentum", "lars_momentum"]:
if varkey == "Velocity":
related_var_names.append(op.input(varkey)[0])
elif op_type == "rmsprop":
if varkey in ["Moment", "MeanSquare"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "ftrl":
if varkey in ["SquaredAccumulator", "LinearAccumulator"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "sgd":
pass
else:
raise ValueError(
"Not supported optimizer for distributed training: %s" % op_type
)
return related_var_names
def build_ctx(
self, vars, mapping, is_grad, is_sparse, is_send, is_distributed=False
):
def get_grad_var_ep(slices):
names = []
eps = []
sections = []
for slice in slices:
if self.is_geo_mode():
if is_send:
names.append("{}.delta".format(slice.name))
else:
names.append(slice.name)
elif (
is_grad and self.is_sync_mode() and self.get_trainers() > 1
):
names.append(
"{}.trainer_{}".format(slice.name, self.get_role_id())
)
else:
names.append(slice.name)
sections.append(slice.shape[0])
for ep, pairs in self.param_grad_ep_mapping.items():
params, grads = pairs["params"], pairs["grads"]
for var in params + grads:
if slice.name == var.name:
eps.append(ep)
break
return names, eps, sections
if isinstance(vars, MergedVariable):
name = vars.merged_var.name
slices = mapping[name]
names, eps, sections = get_grad_var_ep(slices)
origin_varnames = [var.name for var in vars.ordered_vars]
else:
name = vars.name
slices = mapping[name]
names, eps, sections = get_grad_var_ep(slices)
origin_varnames = [vars.name]
trainer_id = self.get_role_id()
aggregate = True
ctx = core.CommContext(
name,
names,
eps,
sections,
origin_varnames,
trainer_id,
aggregate,
is_sparse,
is_distributed,
[],
)
return ctx
def get_trainer_send_context(self):
send_ctx = {}
distibuted_varnames = get_sparse_tablenames(
self.origin_main_program, True
)
idx = 0
if not self.is_geo_mode():
for merged in self.merged_dense_pairs:
grad = merged[1]
ctx = self.build_ctx(
grad, self.grad_var_mapping, True, False, True
)
send_ctx[ctx.var_name()] = ctx
for merged in self.merged_sparse_pairs:
param = merged[0]
grad = merged[1]
param_name = param.merged_var.name
is_distributed = (
True if param_name in distibuted_varnames else False
)
ctx = self.build_ctx(
grad,
self.grad_var_mapping,
True,
True,
True,
is_distributed,
)
send_ctx[ctx.var_name()] = ctx
idx += 1
if self.is_async_mode():
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
else:
for pairs in self.origin_sparse_pairs:
param, grad = pairs
param_name = param.name
is_distributed = (
True if param_name in distibuted_varnames else False
)
param_ctx = self.build_ctx(
param,
self.param_var_mapping,
False,
True,
True,
is_distributed,
)
grad_ctx = self.build_ctx(
grad,
self.grad_var_mapping,
True,
True,
True,
is_distributed,
)
ctx = core.CommContext(
param_ctx.var_name(),
param_ctx.split_varnames(),
param_ctx.split_endpoints(),
param_ctx.sections(),
grad_ctx.origin_varnames(),
param_ctx.trainer_id(),
param_ctx.aggregate(),
param_ctx.is_sparse(),
param_ctx.is_distributed(),
[],
)
send_ctx[ctx.var_name()] = ctx
idx += 1
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
return send_ctx
def get_communicator_send_context(self):
send_ctx = {}
distibuted_varnames = get_sparse_tablenames(
self.origin_main_program, True
)
idx = 0
if self.is_geo_mode():
for pairs in self.merged_dense_pairs:
param = pairs[0]
ctx = self.build_ctx(
param, self.param_var_mapping, False, False, True
)
send_ctx[ctx.var_name()] = ctx
for pairs in self.merged_sparse_pairs:
param = pairs[0]
param_name = param.merged_var.name
is_distributed = (
True if param_name in distibuted_varnames else False
)
ctx = self.build_ctx(
param,
self.param_var_mapping,
False,
True,
True,
is_distributed,
)
send_ctx[ctx.var_name()] = ctx
idx += 1
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
else:
for merged in self.merged_dense_pairs:
grad = merged[1]
ctx = self.build_ctx(
grad, self.grad_var_mapping, True, False, True
)
send_ctx[ctx.var_name()] = ctx
for merged in self.merged_sparse_pairs:
param, grad = merged
param_name = param.merged_var.name
is_distributed = (
True if param_name in distibuted_varnames else False
)
ctx = self.build_ctx(
grad,
self.grad_var_mapping,
True,
True,
True,
is_distributed,
)
send_ctx[ctx.var_name()] = ctx
idx += 1
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
return send_ctx
def get_communicator_recv_context(
self, recv_type=1, use_origin_program=False
):
        # recv_type:
        #   1: DENSE, 2: SPARSE, 3: DISTRIBUTED, 4: ALL
distibuted_varnames = get_sparse_tablenames(
self.origin_main_program, True
)
sparse_varnames = []
for pairs in self.origin_sparse_pairs:
param, grad = pairs
sparse_varnames.append(param.name)
dense_recv_ctx = {}
sparse_recv_ctx = {}
distributed_recv_ctx = {}
variables_pairs = (
self.merged_variables_pairs
if not use_origin_program
else self.origin_merged_variables_pairs
)
for merged in variables_pairs:
params = merged[0]
if params.merged_var.name in sparse_varnames:
continue
ctx = self.build_ctx(
params, self.param_var_mapping, False, False, False, False
)
dense_recv_ctx[ctx.var_name()] = ctx
for pairs in self.origin_sparse_pairs:
param, grad = pairs
if param.name in distibuted_varnames:
ctx = self.build_ctx(
param, self.param_var_mapping, False, True, False, True
)
distributed_recv_ctx[ctx.var_name()] = ctx
else:
ctx = self.build_ctx(
param, self.param_var_mapping, False, True, False, False
)
sparse_recv_ctx[ctx.var_name()] = ctx
if recv_type == 1:
return dense_recv_ctx
if recv_type == 2:
return sparse_recv_ctx
if recv_type == 3:
return distributed_recv_ctx
if recv_type == 4:
dense_recv_ctx.update(sparse_recv_ctx)
dense_recv_ctx.update(distributed_recv_ctx)
return dense_recv_ctx
        raise ValueError(
            "recv_type can only be 1/2/3/4: 1 DENSE, 2 SPARSE, 3 DISTRIBUTED, 4 ALL"
        )
def get_the_one_trainer_send_context(self, split_dense_table):
if self.is_geo_mode():
send_ctx = {}
trainer_id = self.get_role_id()
idx = 0
distibuted_varnames = get_sparse_tablenames(
self.origin_main_program, True
)
for merged in self.merged_sparse_pairs:
param, grad = merged
grad_name = grad.merged_var.name
param_name = param.merged_var.name
is_distributed = (
True if param_name in distibuted_varnames else False
)
var = self.origin_main_program.global_block().vars[
grad.merged_var.name
]
var_numel = reduce(lambda x, y: x * y, var.shape[1:])
sparse_ctx = core.CommContext(
grad_name,
[grad_name],
["127.0.0.1:6071"],
[var_numel],
[grad_name],
trainer_id,
True,
True,
is_distributed,
idx,
False,
False,
-1,
[],
)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
if len(send_ctx) == 0:
raise ValueError(
"GeoSGD require sparse parameters in your net."
)
if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
return send_ctx
else:
return self.get_the_one_send_context(split_dense_table)
def get_dense_send_context(
self,
send_ctx,
idx,
merged_dense_pairs,
trainer_id,
split_dense_table=False,
):
if len(merged_dense_pairs) < 1:
return idx
if not split_dense_table:
origin_varnames = []
var_numel = 0
for merged in merged_dense_pairs:
grad = merged[1]
origin_varnames.append(grad.merged_var.name)
var = self.origin_main_program.global_block().vars[
grad.merged_var.name
]
var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "Dense@Grad"
trainer_id = self.get_role_id()
aggregate = True
dense_ctx = core.CommContext(
grad_name,
[grad_name],
["127.0.0.1:6071"],
[var_numel],
origin_varnames,
trainer_id,
aggregate,
False,
False,
idx,
False,
False,
-1,
[],
)
send_ctx[grad_name] = dense_ctx
idx += 1
else:
for merged in merged_dense_pairs:
grad = merged[1]
origin_varname = grad.merged_var.name
var = self.origin_main_program.global_block().vars[
origin_varname
]
var_numel = reduce(lambda x, y: x * y, var.shape)
grad_name = origin_varname
aggregate = True
dense_ctx = core.CommContext(
grad_name,
[grad_name],
["127.0.0.1:6071"],
[var_numel],
[origin_varname],
trainer_id,
aggregate,
False,
False,
idx,
False,
False,
-1,
[],
)
send_ctx[grad_name] = dense_ctx
idx += 1
return idx
def get_the_one_send_context(
self, split_dense_table=False, use_origin_program=False, ep_list=None
):
if ep_list is None:
ep_list = ["127.0.0.1:6071"]
send_ctx = {}
trainer_id = self.get_role_id()
idx = 0
merged_dense_pairs = (
self.origin_merged_dense_pairs
if use_origin_program
else self.merged_dense_pairs
)
merged_sparse_pairs = (
self.origin_merged_sparse_pairs
if use_origin_program
else self.merged_sparse_pairs
)
idx += self.get_dense_send_context(
send_ctx, idx, merged_dense_pairs, trainer_id, split_dense_table
)
distibuted_varnames = get_sparse_tablenames(
self.origin_main_program, True
)
for merged in merged_sparse_pairs:
param, grad = merged
grad_name = grad.merged_var.name
param_name = param.merged_var.name
splited_varname = []
for i in range(len(ep_list)):
splited_varname.append("{}.block{}".format(param_name, i))
            is_distributed = param_name in distibuted_varnames
var = self.origin_main_program.global_block().vars[
grad.merged_var.name
]
shape = list(var.shape)
shape[0] = 0 if is_distributed else shape[0]
sparse_ctx = core.CommContext(
grad_name,
splited_varname,
ep_list,
shape,
[grad_name],
trainer_id,
True,
True,
is_distributed,
idx,
False,
False,
-1,
[],
)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
if len(self.tensor_table_dict) > 0 and self.role_maker._is_worker():
name, ctx = self._step_ctx(idx)
send_ctx[name] = ctx
return send_ctx
def get_the_one_recv_context(
self, is_dense=True, split_dense_table=False, use_origin_program=False
):
recv_id_maps = {}
if is_dense:
send_ctx = self.get_the_one_send_context(
split_dense_table=split_dense_table,
use_origin_program=use_origin_program,
)
for idx, (name, ctx) in enumerate(send_ctx.items()):
if ctx.is_sparse():
continue
if ctx.is_tensor_table():
continue
origin_grad_varnames = ctx.origin_varnames()
param_names = []
for grad_varname in origin_grad_varnames:
param_name = self.grad_name_to_param_name[grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
else:
send_ctx = self.get_the_one_send_context()
for idx, (name, ctx) in enumerate(send_ctx.items()):
if not ctx.is_sparse():
continue
origin_grad_varnames = ctx.origin_varnames()
param_names = []
for grad_varname in origin_grad_varnames:
param_name = self.grad_name_to_param_name[grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
return recv_id_maps
def get_server_runtime_config(self):
return self.strategy.get_server_runtime_config()
def get_var_distributed(self, varname, is_param):
var_distributed = []
offset = 0
if is_param:
params = self.param_var_mapping[varname]
param_varnames = [var.name for var in params]
for ep, pairs in self.param_grad_ep_mapping.items():
for p in pairs["params"]:
if p.name in param_varnames:
offset += p.shape[0]
var_distributed.append((p.name, ep, p.shape[0]))
else:
grads = self.grad_var_mapping[varname]
grad_varnames = [var.name for var in grads]
for ep, pairs in self.param_grad_ep_mapping.items():
for g in pairs["grads"]:
if g.name in grad_varnames:
var_distributed.append((g.name, ep, g.shape[0]))
return var_distributed
def _step_ctx(self, idx):
name = STEP_COUNTER
trainer_id = self.get_role_id()
endpoints = self.get_ps_endpoints()
sections = [1] * len(endpoints)
names = [name] * len(endpoints)
ctx = core.CommContext(
name,
names,
endpoints,
sections,
[name],
trainer_id,
True,
False,
False,
idx,
True,
False,
-1,
[],
)
return name, ctx
def _create_vars_from_blocklist(self, block_list):
"""
Create vars for each split.
NOTE: only grads need to be named for different trainers, use
add_trainer_suffix to rename the grad vars.
Args:
block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
Returns:
var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping
from original var name to each var split.
"""
# varname->[(block_id, current_block_size)]
block_map = collections.OrderedDict()
var_mapping = collections.OrderedDict()
for block_str in block_list:
varname, offset, size = block_str.split(":")
if varname not in block_map:
block_map[varname] = []
block_map[varname].append((int(offset), int(size)))
for varname, split in block_map.items():
orig_var = self.merged_variable_map[varname]
if len(split) == 1:
var_mapping[varname] = [orig_var]
self.var_distributed.add_distributed_var(
origin_var=orig_var,
slice_var=orig_var,
block_id=0,
offset=0,
is_slice=False,
vtype="Param",
)
else:
var_mapping[varname] = []
orig_shape = orig_var.shape
orig_dim1_flatten = 1
if len(orig_shape) >= 2:
orig_dim1_flatten = reduce(
lambda x, y: x * y, orig_shape[1:]
)
for i, block in enumerate(split):
size = block[1]
rows = size // orig_dim1_flatten
splited_shape = [rows]
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
new_var_name = "%s.block%d" % (varname, i)
slice_var = vars_metatools.VarStruct(
name=new_var_name,
shape=splited_shape,
dtype=orig_var.dtype,
type=orig_var.type,
lod_level=orig_var.lod_level,
persistable=False,
)
var_mapping[varname].append(slice_var)
self.var_distributed.add_distributed_var(
origin_var=orig_var,
slice_var=slice_var,
block_id=i,
offset=-1,
is_slice=False,
vtype="Param",
)
return var_mapping
def _dispatcher(self):
ps_dispatcher = RoundRobin(self.get_ps_endpoints())
ps_dispatcher.reset()
grad_var_mapping_items = list(self.grad_var_mapping.items())
sparse_gradnames = [grad.name for _, grad in self.origin_sparse_pairs]
for grad_varname, splited_vars in grad_var_mapping_items:
if grad_varname in sparse_gradnames:
continue
send_vars = []
for _, var in enumerate(splited_vars):
send_vars.append(var)
recv_vars = []
for _, var in enumerate(send_vars):
recv_vars.append(self.grad_param_mapping[var])
eps = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eps):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
for grad_varname, splited_vars in grad_var_mapping_items:
if grad_varname not in sparse_gradnames:
continue
ps_dispatcher.reset()
send_vars = []
for _, var in enumerate(splited_vars):
send_vars.append(var)
recv_vars = []
for _, var in enumerate(send_vars):
recv_vars.append(self.grad_param_mapping[var])
eps = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eps):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
def _slice_variable(
self, var_list, slice_count, min_block_size, uniform=False
):
"""
        We may need to split a dense tensor into one or more blocks and place
        them evenly onto the parameter servers. One block is a sub-tensor
        aligned by dim[0] of the tensor.
        We need a minimal block size so that the calculations on the
        parameter server side can gain better performance. By default the
        minimum block size is 8K elements (each element may be 16-bit, 32-bit or 64-bit).
Args:
var_list (list): List of variables.
            slice_count (int): Number of slices each variable will be split into,
                typically the number of pserver services.
min_block_size (int): Minimum split block size.
Returns:
blocks (list[(varname, block_id, current_block_size)]): A list
of VarBlocks. Each VarBlock specifies a shard of the var.
"""
blocks = []
for var in var_list:
if not uniform:
var_numel = reduce(lambda x, y: x * y, var.shape)
split_count = 1
if min_block_size == -1:
split_count = 1
else:
split_count = slice_count
max_pserver_count = int(
math.floor(var_numel / float(min_block_size))
)
if max_pserver_count == 0:
max_pserver_count = 1
if max_pserver_count < slice_count:
split_count = max_pserver_count
block_size = int(math.ceil(var_numel / float(split_count)))
if len(var.shape) >= 2:
# align by dim1(width)
dim1 = reduce(lambda x, y: x * y, var.shape[1:])
remains = block_size % dim1
if remains != 0:
block_size += dim1 - remains
# update split_count after aligning
split_count = int(math.ceil(var_numel / float(block_size)))
for block_id in range(split_count):
curr_block_size = min(
block_size, var_numel - ((block_id) * block_size)
)
block = vars_metatools.VarBlock(
var.name, block_id, curr_block_size
)
blocks.append(str(block))
else:
                block_size = var.shape[0] // slice_count
remainder = var.shape[0] % slice_count
if block_size == 0:
dim0s = [block_size] * remainder
else:
dim0s = [block_size] * slice_count
for i in range(remainder):
dim0s[i] = dim0s[i] + 1
dim1 = reduce(lambda x, y: x * y, var.shape[1:])
for block_id in range(len(dim0s)):
numel = dim0s[block_id] * dim1
block = vars_metatools.VarBlock(var.name, block_id, numel)
blocks.append(str(block))
return blocks
def _get_param_grad_blocks(self, pairs, min_block_size, uniform=False):
param_list = []
grad_list = []
param_grad_set = set()
for p, g in pairs:
# todo(tangwei12) skip parameter marked not trainable
# if type(p) == Parameter and p.trainable == False:
# continue
p = p.merged_var
g = g.merged_var
if p.name not in param_grad_set:
param_list.append(p)
param_grad_set.add(p.name)
if g.name not in param_grad_set:
grad_list.append(g)
param_grad_set.add(g.name)
# when we slice var up into blocks, we will slice the var according to
# pserver services' count. A pserver may have two or more listening ports.
grad_blocks = self._slice_variable(
grad_list, len(self.get_ps_endpoints()), min_block_size, uniform
)
param_blocks = self._slice_variable(
param_list, len(self.get_ps_endpoints()), min_block_size, uniform
)
return param_blocks, grad_blocks
def _var_slice_and_distribute(self):
# update these mappings for further transpile:
# 1. param_var_mapping : param var name->[split params vars]
# 2. grad_var_mapping : grad var name->[split grads vars]
# 3. grad_param_mapping : grad.blockx->param.blockx
# 4. param_grad_ep_mapping : ep->{"params" : [], "grads" : [] }
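        # For instance (hypothetical names), a grad "emb@GRAD" split across two
        # pservers appears as grad_var_mapping["emb@GRAD"] =
        # [emb@GRAD.block0, emb@GRAD.block1], grad_param_mapping maps
        # emb@GRAD.blockx -> emb.blockx, and param_grad_ep_mapping records
        # which endpoint holds each slice.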
dps, dgs = self._get_param_grad_blocks(
self.merged_dense_pairs, self.min_block_size, False
)
sps, sgs = self._get_param_grad_blocks(
self.merged_sparse_pairs, self.min_block_size, True
)
param_blocks = dps + sps
grad_blocks = dgs + sgs
assert len(grad_blocks) == len(param_blocks)
# origin_param_name->[splited_param_vars]
self.param_var_mapping = self._create_vars_from_blocklist(param_blocks)
self.grad_var_mapping = self._create_vars_from_blocklist(grad_blocks)
# dict(grad_splited_var->param_splited_var)
self.grad_param_mapping = collections.OrderedDict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":")
self.grad_param_mapping[
self.grad_var_mapping[g_name][int(g_bid)]
] = self.param_var_mapping[p_name][int(p_bid)]
print_maps = {}
for k, v in self.grad_param_mapping.items():
print_maps[str(k)] = str(v)
# create mapping of endpoint->split var to create pserver side program
self.param_grad_ep_mapping = collections.OrderedDict()
[
self.param_grad_ep_mapping.update({ep: {"params": [], "grads": []}})
for ep in self.get_ps_endpoints()
]
def _build_var_distributed(self):
self.var_distributed = vars_metatools.VarsDistributed()
sparse_pairs, dense_pairs = self.get_param_grads()
origin_for_sparse = []
origin_for_dense = []
param_name_grad_name = dict()
grad_name_to_param_name = dict()
for param, grad in sparse_pairs:
param = vars_metatools.create_var_struct(param)
grad = vars_metatools.create_var_struct(grad)
origin_for_sparse.append((param, grad))
for param, grad in dense_pairs:
param = vars_metatools.create_var_struct(param)
grad = vars_metatools.create_var_struct(grad)
origin_for_dense.append((param, grad))
for dense_pair in origin_for_dense:
param, grad = dense_pair
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
self.merged_variables_pairs.append((m_param, m_grad))
self.merged_dense_pairs.append((m_param, m_grad))
for sparse_pair in origin_for_sparse:
param, grad = sparse_pair
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
self.merged_variables_pairs.append((m_param, m_grad))
self.merged_sparse_pairs.append((m_param, m_grad))
for merged in self.merged_variables_pairs:
m_param, m_grad = merged
self.merged_variable_map[
m_param.merged_var.name
] = m_param.merged_var
self.merged_variable_map[m_grad.merged_var.name] = m_grad.merged_var
param_merges = []
param_merges.extend(origin_for_sparse)
param_merges.extend(origin_for_dense)
for param, grad in param_merges:
param_name_grad_name[param.name] = grad.name
grad_name_to_param_name[grad.name] = param.name
self.origin_sparse_pairs = origin_for_sparse
self.origin_dense_pairs = origin_for_dense
self.param_name_to_grad_name = param_name_grad_name
self.grad_name_to_param_name = grad_name_to_param_name
sparse_pair_map = collections.OrderedDict()
for pair in self.origin_sparse_pairs + self.origin_dense_pairs:
param, grad = pair
sparse_pair_map[param.name] = str(param)
sparse_pair_map[grad.name] = str(grad)
self._var_slice_and_distribute()
self._dispatcher()
def get_param_grads(self):
origin_program = self.origin_main_program
def _get_params_grads(sparse_varnames):
block = origin_program.global_block()
dense_param_grads = []
sparse_param_grads = []
optimize_params = set()
origin_var_dict = origin_program.global_block().vars
role_id = int(core.op_proto_and_checker_maker.OpRole.Backward)
for op in block.ops:
if _is_opt_role_op(op):
# delete clip op from opt_ops when run in Parameter Server mode
if (
OP_NAME_SCOPE in op.all_attrs()
and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE)
):
op._set_attr("op_role", role_id)
continue
if op.attr(OP_ROLE_VAR_ATTR_NAME):
param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if param_name not in optimize_params:
optimize_params.add(param_name)
param_grad = (
origin_var_dict[param_name],
origin_var_dict[grad_name],
)
if param_name in sparse_varnames:
sparse_param_grads.append(param_grad)
else:
dense_param_grads.append(param_grad)
return sparse_param_grads, dense_param_grads
def _get_sparse_varnames():
varnames = []
for op in origin_program.global_block().ops:
if (
op.type in SPARSE_OP_TYPE_DICT.keys()
and op.attr('remote_prefetch') is True
):
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
varnames.append(param_name)
return list(set(varnames))
sparse_varnames = _get_sparse_varnames()
sparse_param_grads, dense_param_grads = _get_params_grads(
sparse_varnames
)
return sparse_param_grads, dense_param_grads
def remove_var_pair_by_grad(self, var_name):
for index, pair in enumerate(self.merged_variables_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_variables_pairs[index]
for index, pair in enumerate(self.merged_dense_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_dense_pairs[index]
return
for index, pair in enumerate(self.merged_sparse_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_sparse_pairs[index]
return
print("Not find {} in self.merge_pairs".format(var_name))
def _is_opt_role_op(op):
    # NOTE: rely on the op role attribute to determine whether this op is an
    # optimize op
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attr_names and int(
op.all_attrs()[op_maker.kOpRoleAttrName()]
) == int(optimize_role):
return True
return False
def _get_optimize_ops(_program):
block = _program.global_block()
opt_ops = []
for op in block.ops:
if _is_opt_role_op(op):
# delete clip op from opt_ops when run in Parameter Server mode
if (
OP_NAME_SCOPE in op.all_attrs()
and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE)
):
op._set_attr(
"op_role",
int(core.op_proto_and_checker_maker.OpRole.Backward),
)
continue
opt_ops.append(op)
return opt_ops
def _add_lr_decay_table_pass(main_program, compiled_config, lr_decay_steps):
if hasattr(compiled_config.origin_main_program, 'lr_sheduler'):
from paddle.optimizer.lr import LRScheduler
assert isinstance(
compiled_config.origin_main_program.lr_sheduler, LRScheduler
), "must be LRScheduler"
ops = _get_optimize_ops(compiled_config.origin_main_program)
lr_param_dict = _get_lr_param_dict(ops)
(
lr_decay_main_program,
lr_decay_startup_program,
lr_name,
) = _get_lr_sheduler_program(
compiled_config.origin_main_program.lr_sheduler,
lr_param_dict,
lr_decay_steps,
)
compiled_config.add_tensor_table(
"@LR_DECAY_COUNTER@",
lr_name,
lr_decay_startup_program,
lr_decay_main_program,
"GlobalStepTable",
)
def _get_lr_param_dict(opt_ops):
lr_param_dict = {}
for op in opt_ops:
lr_name = op.input("LearningRate")[0]
param_name = op.input("Param")[0]
if lr_name not in lr_param_dict:
lr_param_dict[lr_name] = []
lr_param_dict[lr_name].append(param_name)
return lr_param_dict
def _get_lr_sheduler_program(lr_sheduler, lr_param_dict, lr_decay_steps):
schedler_decay = [
'NoamDecay',
'NaturalExpDecay',
'InverseTimeDecay',
'ExponentialDecay',
]
from paddle.optimizer.lr import (
ExponentialDecay,
NoamDecay,
PiecewiseDecay,
NaturalExpDecay,
InverseTimeDecay,
)
from paddle.static.learning_rate_scheduler import (
exponential_decay,
noam_decay,
piecewise_decay,
natural_exp_decay,
inverse_time_decay,
)
decay_main_program = paddle.static.Program()
decay_startup_program = paddle.static.Program()
lr_name = ""
if isinstance(lr_sheduler, ExponentialDecay):
with paddle.static.program_guard(
decay_main_program, decay_startup_program
):
lr = exponential_decay(1.0, lr_decay_steps, lr_sheduler.gamma, True)
lr_name = lr.name
logging.warn(
"ExponentialDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps
)
elif isinstance(lr_sheduler, NoamDecay):
with paddle.static.program_guard(
decay_main_program, decay_startup_program
):
lr = noam_decay(lr_sheduler.d_model, lr_sheduler.warmup_steps, 1.0)
lr_name = lr.name
logging.warn(
"NoamDecay is set, warmup steps is [ %d ]"
% lr_sheduler.warmup_steps
)
elif isinstance(lr_sheduler, NaturalExpDecay):
with paddle.static.program_guard(
decay_main_program, decay_startup_program
):
lr = natural_exp_decay(1.0, lr_decay_steps, lr_sheduler.gamma, True)
lr_name = lr.name
logging.warn(
"NaturalExpDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps
)
elif isinstance(lr_sheduler, InverseTimeDecay):
with paddle.static.program_guard(
decay_main_program, decay_startup_program
):
lr = inverse_time_decay(
1.0, lr_decay_steps, lr_sheduler.gamma, True
)
lr_name = lr.name
logging.warn(
"InverseTimeDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps
)
else:
raise ValueError(
"Not supported current LearningRate strategy, please use follow decay strategy: {}".format(
schedler_decay
)
)
return decay_main_program, decay_startup_program, lr_name
def _get_varname_parts(varname):
# returns origin, blockid, trainerid
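    # e.g. "emb.block0.trainer_1" -> ("emb", "block0", "trainer_1"),
    #      "emb.block0"           -> ("emb", "block0", ""),
    #      "emb"                  -> ("emb", "", "")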
orig_var_name = ""
trainer_part = ""
block_part = ""
trainer_idx = varname.find(".trainer_")
if trainer_idx >= 0:
trainer_part = varname[trainer_idx + 1 :]
else:
trainer_idx = len(varname)
block_index = varname.find(".block")
if block_index >= 0:
block_part = varname[block_index + 1 : trainer_idx]
else:
block_index = len(varname)
orig_var_name = varname[0 : min(block_index, trainer_idx)]
return orig_var_name, block_part, trainer_part
def _orig_varname(varname):
orig, _, _ = _get_varname_parts(varname)
return orig
# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import collections
import warnings
import math
from functools import reduce
import paddle
from paddle.framework import core
import paddle.framework as framework
from paddle.distributed.transpiler.details.program_utils import delete_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
_get_optimize_ops,
)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
get_sparse_tablenames,
)
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
"lookup_table_grad": "W",
"lookup_table_v2_grad": "W",
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
def delete_optimizer_pass(program, config):
def _delete_optimizer_op_and_vars(_program, optimize_ops):
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
for op in optimize_ops:
optimize_vars.extend(op.input_arg_names)
optimize_op_role_vars.extend(op.attr("op_role_var"))
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
delete_ops(_program.global_block(), optimize_ops)
for var in need_delete_optimize_vars:
if _program.global_block().has_var(var):
_program.global_block()._remove_var(var)
def _add_lr_var(main_program, compiled_config):
# Todo: hard code for pe
lr_var = compiled_config.origin_main_program.global_block().vars[
"learning_rate_0"
]
main_program.global_block().create_var(
name=lr_var.name,
shape=lr_var.shape,
dtype=lr_var.dtype,
type=lr_var.type,
lod_level=lr_var.lod_level,
persistable=True,
)
optimizer_ops = _get_optimize_ops(program)
lr_ops = _get_lr_ops(program)
optimizer_ops.extend(lr_ops)
_delete_optimizer_op_and_vars(program, optimizer_ops)
if hasattr(config.origin_main_program, 'lr_sheduler'):
_add_lr_var(program, config)
return program
def distributed_ops_pass(program, config, use_ps_gpu=False):
trainer_id = config.get_role_id()
send_ctx = config.get_the_one_send_context(
split_dense_table=config.is_heter_ps_mode
)
w_2_table_id = {}
emb_size = {}
def _get_pull_sparse_ops(_program):
pull_sparse_ops = {}
pull_sparse_ids = {}
push_sparse_ops = {}
ops = {}
for op in _program.global_block().ops:
if (
op.type in SPARSE_OP_TYPE_DICT.keys()
and op.attr('remote_prefetch') is True
):
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
if config.is_heter_ps_mode:
# trick for matchnet, need to modify
param_name += op.input("Ids")[0][0]
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
ids = pull_sparse_ids.get(param_name, [])
ids.append(op.input("Ids")[0])
pull_sparse_ids[param_name] = ids
for op in _program.global_block().ops:
if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
if (
param_name in pull_sparse_ids
and op.input("Ids")[0] in pull_sparse_ids[param_name]
):
ops = push_sparse_ops.get(param_name, [])
ops.append(op)
push_sparse_ops[param_name] = ops
return pull_sparse_ops, push_sparse_ops
def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu):
def dag_check_up_and_reorder(program, inputs, outputs):
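            # Descriptive note: input_indexes marks ops that produce one of the
            # lookup `inputs` (the Ids vars), output_indexes marks ops that
            # consume one of the lookup `outputs`. If a producer appears after
            # the first consumer, the producer and the ops it depends on are
            # moved up so a single fused lookup op can be inserted between them.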
global_block = program.global_block()
min_output_index = len(global_block.ops)
max_input_index = -1
input_indexes = [0] * len(global_block.ops)
output_indexes = [0] * len(global_block.ops)
for idx, op in enumerate(global_block.ops):
for i in range(0, len(op.output_names)):
if input_indexes[idx] == 1:
break
outs = op.output(op.output_names[i])
for in_id, in_var in enumerate(inputs):
if in_var.name in outs:
input_indexes[idx] = 1
max_input_index = max(max_input_index, idx)
break
for i in range(0, len(op.input_names)):
if output_indexes[idx] == 1:
break
ins = op.input(op.input_names[i])
for out_id, out_var in enumerate(outputs):
if out_var.name in ins:
output_indexes[idx] = 1
min_output_index = min(min_output_index, idx)
for i in range(len(global_block.ops)):
if input_indexes[i] == 1 and output_indexes[i] == 1:
warnings.warn(
"unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
)
return
if min_output_index < max_input_index:
move_ops = []
for i in range(min_output_index + 1, len(input_indexes)):
if input_indexes[i] == 1:
move_ops.append((global_block.ops[i], i))
for i, op in enumerate(move_ops):
queue = list()
visited = set()
queue.append(op[1])
visited.add(op[0])
start = 0
while start < len(queue):
pos = queue[start]
op = global_block.ops[pos]
op_inputs = []
for k in range(0, len(op.input_names)):
ins = op.input(op.input_names[k])
op_inputs.append(ins)
for j in range(pos - 1, min_output_index - 1, -1):
op1 = global_block.ops[j]
if op1 in visited:
continue
found = False
for k in range(0, len(op1.output_names)):
outs = op1.output(op1.output_names[k])
for t in range(len(op_inputs)):
for y in op_inputs[t]:
if y in outs:
found = True
break
if found:
break
if found:
break
if found:
if output_indexes[j] == True:
warnings.warn(
"unable to re-arrange dags order to combine distributed embedding ops"
)
return
queue.append(j)
visited.add(global_block.ops[j])
start = start + 1
queue.sort()
for index in queue:
desc = global_block.desc._insert_op(min_output_index)
desc.copy_from(global_block.ops[index].desc)
global_block.desc._remove_op(index + 1, index + 2)
global_block.ops[index].desc = desc
insert_op = global_block.ops.pop(index)
input_state = input_indexes.pop(index)
output_state = output_indexes.pop(index)
global_block.ops.insert(min_output_index, insert_op)
input_indexes.insert(min_output_index, input_state)
output_indexes.insert(min_output_index, output_state)
min_output_index = min_output_index + 1
assert global_block.desc.op_size() == len(global_block.ops)
for i in range(len(global_block.ops)):
assert global_block.desc.op(i) == global_block.ops[i].desc
for param, ops in pull_sparse_ops.items():
all_ops = program.global_block().ops
op_device = ""
if config.is_heter_ps_mode:
op_device = ops[0].attr("op_device")
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
emb_size[param] = w.shape[1]
grad_name = config.param_name_to_grad_name[w.name]
table_id = -1
for name, ctx in send_ctx.items():
if grad_name in ctx.origin_varnames():
table_id = ctx.table_id()
if table_id == -1:
raise ValueError(
"can not find suitable sparse table, please check"
)
w_2_table_id[param] = table_id
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
]
dag_check_up_and_reorder(program, inputs, outputs)
op_idxs = [all_ops.index(op) for op in ops]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
inputs_idxs = [-1] * len(inputs)
outputs_idxs = [len(program.global_block().ops) + 1] * len(outputs)
for idx, op in enumerate(program.global_block().ops):
for i in range(0, len(op.output_names)):
outs = op.output(op.output_names[i])
for in_id, in_var in enumerate(inputs):
if in_var.name in outs:
inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
for i in range(0, len(op.input_names)):
ins = op.input(op.input_names[i])
for out_id, out_var in enumerate(outputs):
if out_var.name in ins:
outputs_idxs[out_id] = min(
idx, outputs_idxs[out_id]
)
if min(outputs_idxs) - max(inputs_idxs) >= 1:
if max(inputs_idxs) == -1:
distributed_idx = min(op_idxs)
else:
distributed_idx = max(inputs_idxs) + 1
if use_ps_gpu:
program.global_block()._insert_op(
index=distributed_idx,
type="pull_gpups_sparse",
inputs={"Ids": inputs, 'W': w},
outputs={"Out": outputs},
attrs={
"size": [w.shape[1] for i in inputs],
"is_distributed": True,
"is_sparse": True,
},
)
else:
program.global_block()._insert_op(
index=distributed_idx,
type="distributed_lookup_table",
inputs={"Ids": inputs, 'W': w},
outputs={"Outputs": outputs},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"lookup_table_version": op_type,
"op_device": op_device,
},
)
else:
for i in range(len(inputs_idxs)):
distributed_idx = op_idxs[i]
program.global_block()._insert_op(
index=distributed_idx,
type="distributed_lookup_table",
inputs={"Ids": [inputs[i]], 'W': w},
outputs={"Outputs": [outputs[i]]},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"lookup_table_version": op_type,
"op_device": op_device,
},
)
def _push_sparse_fuse(_program, push_sparse_ops, use_ps_gpu):
if use_ps_gpu:
# in ps_gpu_pass
return
if len(push_sparse_ops) == 0:
return
show = None
clk = None
use_entry = False
for param, ops in push_sparse_ops.items():
op_first = ops[0]
break
print(op_first)
if op_first.has_attr("entry"):
entry = op_first.attr("entry")
entry = entry.split(':')
if len(entry) == 3 and entry[0] == 'show_click_entry':
show_var_name = entry[1]
click_var_name = entry[2]
if (
show_var_name in program.global_block().vars
and click_var_name in program.global_block().vars
):
show = program.global_block().vars[show_var_name]
clk = program.global_block().vars[click_var_name]
use_entry = True
else:
warnings.warn(
'ShowClickEntry configured, but cannot find show/click var, will not use'
)
if not use_entry:
print('ShowClickEntry not configured, will not use')
show = program.global_block().create_var(
name="show",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True,
)
program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': show},
attrs={
'shape': [1],
'dtype': show.dtype,
'value': 1,
# OP_ROLE_KEY: OpRole.Forward
},
)
clk = program.global_block().create_var(
name="clk",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True,
)
program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': clk},
attrs={
'shape': [1],
'dtype': clk.dtype,
'value': 0,
# OP_ROLE_KEY: OpRole.Forward
},
)
for param, ops in push_sparse_ops.items():
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].output("W@GRAD")[0]]
table_id = w_2_table_id[param]
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
program.global_block().vars[op.input("Out@GRAD")[0]]
for op in ops
]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
# if use_ps_gpu:
# program.global_block().append_op(
# type="push_box_sparse",
# inputs={"Ids": inputs,
# 'Out': outputs},
# outputs={"Out": outputs},
# attrs={
# "size": w.shape[1],
# "is_distributed": True,
# "is_sparse": True
# })
# else:
program.global_block().append_op(
type="distributed_push_sparse",
inputs={
"Ids": inputs,
'W': w,
"Outputs": outputs,
"Shows": show,
"Clicks": clk,
},
outputs={"Outputs": outputs},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"size": emb_size[param],
},
)
pull_sparse_ops, push_sparse_ops = _get_pull_sparse_ops(program)
_pull_sparse_fuse(program, pull_sparse_ops, use_ps_gpu)
_push_sparse_fuse(program, push_sparse_ops, use_ps_gpu)
return program
def append_send_ops_pass(program, config):
mode = config.get_distributed_mode()
trainer_id = config.get_role_id()
def _append_send_op(union_vars, queue, is_sparse, table_id):
if queue == STEP_COUNTER:
send_input_vars = []
else:
send_input_vars = [
program.global_block().vars[union_var]
for union_var in union_vars
]
dummy_output = []
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name()
)
program.global_block().append_op(
type="send",
inputs={"X": send_input_vars},
outputs={"Out": dummy_output},
attrs={
"send_varnames": [queue],
"is_sparse": is_sparse,
"table_id": table_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
return dummy_output
def _append_barrier_op(dummys):
program.global_block().append_op(
type="send_barrier",
inputs={"X": dummys},
outputs={"Out": []},
attrs={
"trainer_id": trainer_id,
"half_async": True,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
dummys = []
sends = config.get_the_one_trainer_send_context(
split_dense_table=config.is_heter_ps_mode
)
for merged_name, send in sends.items():
if send.is_sparse() and not config.is_geo_mode():
continue
is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse
dummys.append(
_append_send_op(
send.origin_varnames(), merged_name, is_sparse, send.table_id()
)
)
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
_append_barrier_op(dummys)
return program
def init_from_server_pass(program, config):
    # the 0th trainer does not need a barrier; it will call barrier at the end of init_worker
if config.role_maker._is_first_worker():
return program
fetch_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name()
)
program.global_block().append_op(
type="fetch_barrier",
inputs={},
outputs={"Out": fetch_barrier_out},
attrs={
"endpoints": config.get_ps_endpoints(),
"trainer_id": config.get_role_id(),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
return program
def fake_init_ops_pass(program, config):
origin_program = config.get_origin_main_program()
def _get_sparse_table_names():
dist_varnames = get_sparse_tablenames(origin_program, True)
sparse_varnames = get_sparse_tablenames(origin_program, False)
return list(set(dist_varnames + sparse_varnames))
def _fake_init_sparsetable(sparse_table_names):
# delete table init op
for table_name in sparse_table_names:
table_var = program.global_block().vars[table_name]
table_param_init_op = []
for op in program.global_block().ops:
if table_name in op.output_arg_names:
table_param_init_op.append(op)
init_op_num = len(table_param_init_op)
if init_op_num != 1:
raise ValueError(
"table init op num should be 1, now is " + str(init_op_num)
)
table_init_op = table_param_init_op[0]
program.global_block().append_op(
type="fake_init",
inputs={},
outputs={"Out": table_var},
attrs={"shape": table_init_op.attr('shape')},
)
delete_ops(program.global_block(), table_param_init_op)
sparse_tables = _get_sparse_table_names()
_fake_init_sparsetable(sparse_tables)
return program
def ps_gpu_pass(program):
def _add_push_box_sparse_op(program):
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for op in program.global_block().ops:
if op.type != "pull_box_sparse" and op.type != "pull_gpups_sparse":
continue
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, set(), []
)
for op_desc in grad_op_desc:
new_op_desc = program.global_block().desc.append_op()
new_op_desc.copy_from(op_desc)
new_op_desc._set_attr(op_role_attr_name, backward)
def _remove_lookup_table_grad_op_and_var(program):
lookup_table_grad_var = {}
remove_op_index = []
remove_var = []
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "lookup_table_grad":
for name in op.output("W@GRAD"):
lookup_table_grad_var[name] = 1
remove_op_index.append(idx)
remove_var.append(name)
for name in op.input("W"):
lookup_table_grad_var[name] = 1
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "pull_box_sparse" or op.type == "pull_gpups_sparse":
continue
for key_name in op.input_names:
for var in op.input(key_name):
if var in lookup_table_grad_var:
remove_op_index.append(idx)
break
remove_op_index = list(set(remove_op_index))
remove_op_index.sort(reverse=True)
for idx in remove_op_index:
program.global_block()._remove_op(idx)
for name in remove_var:
program.global_block()._remove_var(name)
def _remove_optimizer_var(program):
embedding_w = {}
for idx, op in list(enumerate(program.global_block().ops)):
if op.type == "lookup_table_grad":
for name in op.input("W"):
embedding_w[name] = 1
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
for op in _get_optimize_ops(program):
for name in op.input("Param"):
if name in embedding_w:
optimize_op_role_vars.extend(op.attr("op_role_var"))
for key_name in op.input_names:
if key_name == "LearningRate":
continue
for var in op.input(key_name):
optimize_vars.append(var)
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
for name in need_delete_optimize_vars:
if program.global_block().has_var(name):
program.global_block()._remove_var(name)
_add_push_box_sparse_op(program)
_remove_optimizer_var(program)
_remove_lookup_table_grad_op_and_var(program)
return program
def delete_extra_optimizes_pass(program, config):
optimize_vars = []
optimize_op_role_vars = []
optimize_need_delete_vars = []
origin_program = config.get_origin_main_program()
for op in _get_optimize_ops(origin_program):
optimize_vars.extend(op.input_arg_names)
optimize_op_role_vars.extend(op.attr("op_role_var"))
optimize_vars = list(set(optimize_vars))
optimize_op_role_vars = list(set(optimize_op_role_vars))
for var in optimize_vars:
if var not in optimize_op_role_vars:
optimize_need_delete_vars.append(var)
need_delete_optimize_vars = list(set(optimize_need_delete_vars))
init_ops = []
for var in need_delete_optimize_vars:
param_init_op = []
for op in program.global_block().ops:
if var in op.output_arg_names:
param_init_op.append(op)
init_ops.extend(param_init_op)
delete_ops(program.global_block(), init_ops)
for var in need_delete_optimize_vars:
if program.global_block().has_var(var):
program.global_block()._remove_var(var)
return program
def find_heter_ops(program, default_device="cpu"):
if default_device not in DEVICE_LIST:
raise ValueError(
"Given device {} is not in device list {}".format(
default_device, DEVICE_LIST
)
)
def _is_heter_op(op, current_heter_device, default_device="cpu"):
heter_devices = list(DEVICE_LIST)
heter_devices.remove(default_device)
op_device = op.attr("op_device")
op_type = op.type
if op_device in heter_devices:
return True
elif (
op_type in COMMUNICATE_OPS_TYPE
and current_heter_device != default_device
):
            # for distributed communicate ops: send & recv & barrier etc.
# Todo: need update this method
# op._set_attr('op_device', current_heter_device)
return True
elif op_device is None or op_device == default_device:
op._set_attr('op_device', default_device)
return False
return False
def _is_same_device(op, pre_device, default_device="cpu"):
op_device = op.attr("op_device")
if op_device == pre_device:
return True
if pre_device == default_device:
return True
return False
def _append_heter_op(op, current_heter_block_ops, heter_ops):
op_device = op.attr("op_device")
if op_device not in heter_ops:
heter_ops[op_device] = {}
current_heter_block_ops.append(op)
origin_porgram = program.clone()
block = program.global_block()
'''
    re-place the sum op to fix a bug when forward and backward ops are combined in one block
'''
var2idx = {}
op_list = list(block.ops)
op_size = len(op_list)
for i in range(op_size - 1, -1, -1):
op_list = list(block.ops)
op = op_list[i]
if "_grad" in op.type:
forward_op_type = op.type.split("_grad")[0]
if (
forward_op_type in SPARSE_OP_TYPE_DICT.keys()
and op.attr('remote_prefetch') is True
):
param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
if param_name in var2idx:
## insert sum op & remove sum op from var2idx and origin place
op_list = list(block.ops)
sum_op = op_list[var2idx[param_name]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs(),
)
block._remove_op(var2idx[param_name] + 1)
var2idx.pop(param_name)
for var_ in var2idx:
var2idx[var_] += 1
elif forward_op_type == "elementwise_mul":
"""
get output varname of pre op
"""
output_vars_no_grad = []
for key in op.output_names:
for varname in op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars_no_grad.append(varname.split("@GRAD")[0])
for no_grad_var in output_vars_no_grad:
if no_grad_var in var2idx:
"""
insert sum op & remove sum op from var2idx and origin place
"""
op_list = list(block.ops)
sum_op = op_list[var2idx[no_grad_var]]
sum_op_inputs = {
sum_op.input_names[0]: [
block.vars[input]
for input in sum_op.input_arg_names
]
}
sum_op_outputs = {
sum_op.output_names[0]: [
block.vars[output]
for output in sum_op.output_arg_names
]
}
block._insert_op(
index=i + 1,
type=sum_op.type,
inputs=sum_op_inputs,
outputs=sum_op_outputs,
attrs=sum_op.all_attrs(),
)
block._remove_op(var2idx[no_grad_var] + 1)
var2idx.pop(no_grad_var)
for var_ in var2idx:
var2idx[var_] += 1
else:
if op.type == "sum":
var = op.output("Out")[0]
if "@GRAD" in var:
origin_var = var.split("@GRAD")[0]
pre_op = op_list[i - 1]
if "_grad" in pre_op.type:
forward_op_type = pre_op.type.split("_grad")[0]
if (
forward_op_type in SPARSE_OP_TYPE_DICT.keys()
and pre_op.attr('remote_prefetch') is True
):
param_name = pre_op.input(
SPARSE_OP_TYPE_DICT[forward_op_type]
)[0]
if param_name == origin_var and op.attr(
"op_device"
) == pre_op.attr("op_device"):
continue
else:
var2idx[origin_var] = i
elif forward_op_type == "elementwise_mul":
output_vars = []
for key in pre_op.output_names:
for varname in pre_op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
output_vars.append(varname)
input_vars = []
for key in op.input_names:
for varname in op.input(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
input_vars.append(varname)
is_match = False
for varname in output_vars:
if varname in input_vars:
is_match = True
break
if is_match:
continue
else:
var2idx[origin_var] = i
else:
var2idx[origin_var] = i
origin_porgram = program.clone()
block = program.global_block()
program_block_ops = []
default_ops = {default_device: {}}
heter_ops = {}
block_index = 0
current_heter_block_ops = []
current_default_block_ops = []
current_heter_device = default_device
is_heter = False
for op in block.ops:
if _is_heter_op(op, current_heter_device, default_device):
# for gpu/xpu-op
is_heter = True
# for cpu-op block append
if len(current_default_block_ops) > 1:
default_ops[default_device][
block_index
] = current_default_block_ops
program_block_ops.append(current_default_block_ops)
current_default_block_ops = []
block_index += 1
if _is_same_device(op, current_heter_device, default_device):
# for gpu-op, gpu-op -> gpu-op,...
current_heter_device = op.attr("op_device")
_append_heter_op(op, current_heter_block_ops, heter_ops)
else:
# for gpu-op -> xpu-op, ...
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
block_index += 1
current_heter_block_ops = []
current_heter_device = op.attr("op_device")
_append_heter_op(op, current_heter_block_ops, heter_ops)
elif is_heter:
# for gpu/xpu-op -> cpu-op
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
block_index += 1
current_heter_block_ops = []
current_heter_device = default_device
is_heter = False
current_default_block_ops.append(op)
else:
# for cpu-op
current_default_block_ops.append(op)
if current_default_block_ops != []:
default_ops[default_device][block_index] = current_default_block_ops
program_block_ops.append(current_default_block_ops)
if current_heter_block_ops != []:
op_device = current_heter_block_ops[0].attr("op_device")
heter_ops[op_device][block_index] = current_heter_block_ops
program_block_ops.append(current_heter_block_ops)
if len(heter_ops) == 0:
warnings.warn(
"No heterogeneous OP was found in your program , "
" please using paddle.static.device_guard() to run OPs on different device."
)
total_heter_ops = 0
heter_blocks = 0
for device in heter_ops.keys():
heter_block_dict = heter_ops[device]
heter_blocks += len(heter_block_dict)
for _, heter_block in heter_block_dict.items():
total_heter_ops += len(heter_block)
print(
"There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.".format(
len(block.ops), total_heter_ops, heter_blocks
)
)
return origin_porgram, heter_ops, default_ops, program_block_ops
def create_heter_program(
program,
config,
heter_program,
program_block_ops_list,
heter_ops,
block_var_detail,
current_device,
stage_id,
):
# This function mainly includes the following contents:
# 1. For every heter block:
# a) copy heter device op from origin program
# b) create variables which belong to heter op:
# -> if variable is persistable, clone it in global_scope
# -> if variable is temp, create it in heter block
# c) create communicate related op as follow:
# joint_var.0_1 -> slice -> reshape -> origin_var
# origin_var -> origin_program
# reshape -> concat -> joint_var.1_2
    #     d) copy send op from origin program for var@grad which is located in current heter block
    #     e) re-check every op in current block if its device is not the current heter device
# 2. Create send op for step counter in last heter-block
# 3. Create Listen&Serv OP and Send&Recv OP for distributed training
# 4. update CompileTimeStrategy for heter_program
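    # Illustrative example (hypothetical 3-stage pipeline, stage_id=2): one heter
    # block holds the stage-2 forward ops and another its backward ops, and
    # message_to_block_id gets entries like
    #     "forward_joint_1_2@Heter:<fwd_block_idx>" and
    #     "backward_joint_3_2@Heter:<bp_block_idx>"
    # so heter_listen_and_serv can route incoming messages to the right block.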
optimizer_block = []
grad_to_block_id = []
send_grad_var_list = []
pre_block_idx = heter_program.num_blocks - 1
stage_id = int(stage_id)
print("stage id", stage_id)
heter_block_ops_forward = program_block_ops_list[stage_id - 1]["forward"]
heter_block_ops_backward = program_block_ops_list[stage_id - 1]["backward"]
heter_block = heter_program._create_block(pre_block_idx)
optimizer_block.append(heter_block)
for _, op in enumerate(heter_block_ops_forward):
block_append_op(heter_program, program, heter_block, op)
entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
add_vars_by_var_list(exit_vars, program, heter_program, heter_block)
first_op_index_fp = len(heter_block.ops)
if stage_id < len(program_block_ops_list):
heter_block_bp = heter_program._create_block(pre_block_idx)
optimizer_block.append(heter_block_bp)
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block_bp, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"
]
add_vars_by_var_list(
bp_entrance_vars, program, heter_program, heter_block_bp
)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(
bp_exit_vars, program, heter_program, heter_block_bp
)
backward_comm_info = get_communicate_var_info(
program, stage_id, bp_entrance_vars, type="backward"
)
grad_to_block_id.append(
backward_comm_info["block_input_var_name"]
+ ":"
+ str(heter_block_bp.idx)
)
else:
for _, op in enumerate(heter_block_ops_backward):
block_append_op(heter_program, program, heter_block, op)
bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
"entrance"
]
add_vars_by_var_list(
bp_entrance_vars, program, heter_program, heter_block
)
bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block)
heter_block_bp = heter_block
forward_comm_info = get_communicate_var_info(
program, stage_id, entrance_vars, type="forward"
)
grad_to_block_id.append(
forward_comm_info["block_input_var_name"] + ":" + str(heter_block.idx)
)
first_op_index_bp = len(heter_block_bp.ops)
if stage_id <= len(block_var_detail) - 1:
static_var = insert_communicate_op(
program,
config,
heter_block,
stage_id,
first_op_index_fp,
block_var_detail,
current_device,
)
static_var_bp = insert_communicate_op(
program,
config,
heter_block_bp,
stage_id,
first_op_index_bp,
block_var_detail,
current_device,
False,
)
# add send op
send_grad_var_list = add_heter_send_op(
program, heter_program, heter_block_bp, block_var_detail[stage_id - 1]
)
# ---------------
    # add step counter
send_input_vars = []
dummy_output = []
pserver_endpoints = config.get_ps_endpoints()
# optimizer_block[-1].append_op(
# type="send",
# inputs={"X": send_input_vars},
# outputs={"Out": dummy_output},
# attrs={
# "send_varnames": [STEP_COUNTER],
# "merge_add": True,
# "use_send_handler": False,
# "endpoints": pserver_endpoints
# })
# add info in listen&serv
attrs = {
# "mode": "sync",
# "trainers": config.get_trainers(),
# "trainer_id": config.get_role_id() + config.get_trainers(),
"message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block,
# runtime attribute
"endpoint": config.get_heter_worker_endpoint(),
"fanin": len(config.get_previous_stage_trainers()),
"pserver_id": config.get_role_id(),
"distributed_mode": config.get_distributed_mode(),
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
}
# append the listen_and_serv op
heter_program.global_block().append_op(
type="heter_listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs
)
check_heter_compile_time_strategy(program, config, send_grad_var_list)
def check_heter_compile_time_strategy(program, config, send_grad_var_list):
origin_grad_var_list = []
for _, var_grad in config.merged_variables_pairs:
origin_grad_var_list.append(var_grad.merged_var.name)
origin_grad_var_list = list(set(origin_grad_var_list))
send_grad_var_list = list(set(send_grad_var_list))
useless_grad_var_list = list(
set(origin_grad_var_list) - set(send_grad_var_list)
)
for useless_grad_var in useless_grad_var_list:
config.remove_var_pair_by_grad(useless_grad_var)
def create_trainer_program(
program, origin_program, config, program_block_ops_list, block_var_detail
):
# This function mainly includes the following contents:
# 1. For every heter block in origin program
# a) delete heter op and related variables
# b) add send&recv op
# c) add communicate ops as follows:
# origin_var -> reshape -> concat -> joint_var.0_1
# send&recv op(send joint_var.0_1; recv joint_var.1_2)
# joint_var.1_2 -> slice -> reshape -> origin_var
# d) remove send op which related var@grad is not in trainer program
# 2. check every op's device
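    # Illustrative example (hypothetical 2-stage pipeline): the trainer keeps the
    # stage-1 forward/backward ops, replaces the heter-stage ops with a
    # send_and_recv op on "forward_joint_1_2@Heter", and registers
    # "backward_joint_2_1@Heter:<bp_block_idx>" in message_to_block_id for the
    # heter_listen_and_serv op inserted at index 0.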
static_var = []
for heter_block_index in range(1, len(program_block_ops_list)):
ops_list = (
program_block_ops_list[heter_block_index]["forward"]
+ program_block_ops_list[heter_block_index]["backward"]
)
static_var += replace_ops_by_communicate_op(
program, config, heter_block_index, ops_list, block_var_detail
)
remove_trainer_send_op(
program, config, heter_block_index, block_var_detail
)
optimizer_block = []
grad_to_block_id = []
bp_ops_list = program_block_ops_list[0]["backward"]
delete_same_ops(program.global_block(), bp_ops_list)
delete_trainer_useless_var(config, program, static_var)
backward_block = create_backward_block(
program, origin_program, config, bp_ops_list, block_var_detail
)
bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
backward_comm_info = get_communicate_var_info(
origin_program, 1, bp_entrance_vars, type="backward"
)
grad_to_block_id.append(
backward_comm_info["block_input_var_name"]
+ ":"
+ str(backward_block.idx)
)
optimizer_block.append(backward_block)
attrs = {
# "mode": "sync",
# "trainers": config.get_trainers(),
# "trainer_id": config.get_role_id(),
"message_to_block_id": grad_to_block_id,
"optimize_blocks": optimizer_block,
# runtime attribute
"endpoint": config.get_trainer_endpoint(), ## get trainer endpoint
"fanin": 0, ## get heter worker
"pserver_id": config.get_role_id(),
"distributed_mode": config.get_distributed_mode(),
"rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
}
# append the listen_and_serv op
program.global_block()._insert_op(
index=0,
type="heter_listen_and_serv",
inputs={'X': []},
outputs={},
attrs=attrs,
)
## TODO add check for bp block
check_op_device(program.global_block(), DEFAULT_DEVICE)
def insert_communicate_op(
orign_program,
config,
heter_block,
stage_id,
first_op_index,
block_var_detail,
device,
is_forward=True,
):
if is_forward:
next_heter_worker_endpoints = config.get_next_stage_trainers()
previous_heter_worker_endpoints = config.get_previous_stage_trainers()
entrance_var = block_var_detail[stage_id]["forward"]["entrance"]
comm_info = get_communicate_var_info(
orign_program, stage_id + 1, entrance_var
)
else:
next_heter_worker_endpoints = config.get_next_stage_trainers()
# if next_heter_worker_endpoints == "":
# next_heter_worker_endpoints = []
previous_heter_worker_endpoints = config.get_previous_stage_trainers()
entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"]
comm_info = get_communicate_var_info(
orign_program, stage_id - 1, entrance_var, "backward"
)
heter_block._insert_op(
index=first_op_index,
type="send_and_recv",
inputs={"X": heter_block.vars[entrance_var[0]]},
outputs={"Out": []},
attrs={
"mode": "forward" if is_forward else "backward",
"send_var_name": entrance_var + ["microbatch_id"],
"recv_var_name": [],
"message_name": comm_info["block_input_var_name"],
"next_endpoints": next_heter_worker_endpoints,
"previous_endpoints": previous_heter_worker_endpoints,
"trainer_id": config.get_role_id(),
"op_device": device,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
return entrance_var
def create_backward_block(
program, origin_program, config, bp_ops_list, block_var_detail
):
pre_block_idx = program.num_blocks - 1
heter_block = program._create_block(pre_block_idx)
for _, op in enumerate(bp_ops_list):
if op.type == "send":
send_varnames = op.attr('send_varnames')
is_skip = False
for varname in send_varnames:
if (
varname not in program.global_block().vars
and varname not in heter_block.vars
):
is_skip = True
break
if is_skip == True:
continue
block_append_op(program, origin_program, heter_block, op)
entrance_vars = block_var_detail[0]["backward"]["entrance"]
add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
exit_vars = block_var_detail[0]["backward"]["exit"]
add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
return heter_block
def replace_ops_by_communicate_op(
program, config, heter_block_index, ops_list, block_var_detail
):
all_op = program.global_block().ops
start_op = ops_list[0]
first_op_idx = -1
for op in all_op:
if is_same_op(op, start_op):
first_op_idx = all_op.index(op)
break
assert first_op_idx != -1
delete_same_ops(program.global_block(), ops_list)
entrance_var = []
if heter_block_index == 1:
mode = config.get_distributed_mode()
next_heter_worker_endpoints = config.get_next_stage_trainers()
entrance_var = block_var_detail[heter_block_index]["forward"][
"entrance"
]
comm_info = get_communicate_var_info(
program, heter_block_index + 1, entrance_var
)
program.global_block()._insert_op(
index=first_op_idx,
type="send_and_recv",
inputs={"X": program.global_block().vars[entrance_var[0]]},
outputs={"Out": []},
attrs={
"mode": "forward",
"send_var_name": entrance_var + ["microbatch_id"],
"recv_var_name": [],
"message_name": comm_info["block_input_var_name"],
"next_endpoints": next_heter_worker_endpoints,
"previous_endpoints": [],
"trainer_id": config.get_role_id(),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
return entrance_var
def remove_trainer_send_op(
program, config, heter_block_index, block_var_detail
):
# If the trainer does FF -> BP -> SEND, it has the following vars: var, var@GRAD.
# If the trainer only does SEND, it has a single var: var@GRAD.
# Delete the send op if the trainer does not have the paired vars (var <-> var@GRAD).
persistables = (
block_var_detail[heter_block_index]["forward"]["persistables"]
+ block_var_detail[heter_block_index]["backward"]["persistables"]
)
need_remove_send_op = []
need_remove_grad_var = []
for op in find_send_op(program):
input_list, _ = find_op_input_output(
program, program.global_block(), op
)
for var_name in input_list:
origin_var_name = var_name.split("@GRAD")[0]
if origin_var_name in persistables:
need_remove_send_op.append(op)
need_remove_grad_var.append(var_name)
need_remove_send_op = list(set(need_remove_send_op))
delete_ops(program.global_block(), need_remove_send_op)
for grad_var_name in need_remove_grad_var:
config.remove_var_pair_by_grad(grad_var_name)
def add_heter_send_op(program, heter_program, block, block_var_detail):
def _get_send_op_dict():
send_op_dict = {}
send_op_list = find_send_op(program)
for op in send_op_list:
input_list, _ = find_op_input_output(
program, program.global_block(), op
)
for var in input_list:
send_op_dict[var] = op
return send_op_dict
# structure of the send op being reconstructed:
#   inputs:  {'X': []}
#   outputs: {'Out': dummy_output}
#   attrs:   {'send_varnames': [], 'is_sparse': int, 'table_id': int}
send_grad_var_list = []
send_op_dict = _get_send_op_dict()
table_dict = {}
for persistable_var in block_var_detail["backward"]["persistables"]:
# check var_name == var@GRAD
if "@GRAD" not in persistable_var:
continue
if "GRAD" != persistable_var.split("@")[-1]:
continue
if persistable_var not in send_op_dict:
continue
send_op = send_op_dict[persistable_var]
is_sparse = send_op.attr('is_sparse')
table_id = send_op.attr('table_id')
send_varnames = send_op.attr('send_varnames')
send_grad_var_list.append(persistable_var)
if table_id not in table_dict:
table_dict[table_id] = {}
table_dict[table_id]['var_list'] = []
table_dict[table_id]['is_sparse'] = is_sparse
table_dict[table_id]['send_varnames'] = send_varnames
table_dict[table_id]['var_list'].append(persistable_var)
for table_id in table_dict:
dummy_output = block.create_var(
name=framework.generate_control_dev_var_name()
)
send_input_vars = [
block.vars[union_var]
for union_var in table_dict[table_id]['var_list']
]
block.append_op(
type="send",
inputs={"X": send_input_vars},
outputs={"Out": dummy_output},
attrs={
"send_varnames": table_dict[table_id]['send_varnames'],
"is_sparse": is_sparse,
"table_id": table_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
},
)
return send_grad_var_list
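# Shape of the table_dict built above (illustrative):
#   {table_id: {"var_list": ["w@GRAD", ...],
#               "is_sparse": bool,
#               "send_varnames": [...]}}
# One fused "send" op is appended per table_id.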
def find_send_op(program):
send_op_list = []
for op in program.global_block().ops:
if op.type == "send":
send_op_list.append(op)
return send_op_list
def get_communicate_var_info(
program, block_index, entrance_var_list, type="forward"
):
input_var_reshape_dim = []
input_var_reshape_name = []
if type == "forward":
block_input_var_name = "forward_joint_{}_{}@Heter".format(
block_index - 1, block_index
)
else:
block_input_var_name = "backward_joint_{}_{}@Heter".format(
block_index + 1, block_index
)
entrance_var_list.sort()
# input
# Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var
for name in entrance_var_list:
var = program.global_block().vars[name]
shape = var.shape
# if len(shape) < 2 or shape[0] != -1:
# raise ValueError(
# "Variable {} not support heter training. its shape is {}".
# format(name, shape))
recv_var_dim = -1 * reduce(lambda x, y: x * y, shape)
input_var_reshape_dim.append(recv_var_dim)
input_var_reshape_name.append("{}.input_reshape@Heter".format(name))
# output
# var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR
# for var_name in exit_var_list:
# var = program.global_block().vars[var_name]
# shape = var.shape
# # if len(shape) < 2 or shape[0] != -1:
# # raise ValueError(
# # "Variable {} not support heter training. its shape is {}".
# # format(var_name, shape))
# send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape)
# output_var_reshape_dim.append(send_reshape_dim)
# output_var_reshape_name.append("{}.output_reshape@Heter".format(
# var_name))
info = {
"input_var_reshape_dim": input_var_reshape_dim,
"input_var_reshape_name": input_var_reshape_name,
"block_input_var_name": block_input_var_name,
# "output_var_reshape_dim": output_var_reshape_dim,
# "output_var_reshape_name": output_var_reshape_name,
# "block_output_var_name": block_output_var_name
}
return info
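# Worked example (assumed shapes): for entrance_var_list = ["emb"], where
# program.global_block().vars["emb"].shape == (-1, 8) and block_index == 2,
# the returned dict is roughly:
#   {"input_var_reshape_dim": [8],            # -1 * (-1 * 8)
#    "input_var_reshape_name": ["emb.input_reshape@Heter"],
#    "block_input_var_name": "forward_joint_1_2@Heter"}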
def union_forward_gradient_op(program_block_ops_list):
"""
Before analyzing the input & output of each block in program_block_ops_list,
union each forward op block with its corresponding gradient op block to
eliminate unnecessary variable transmission.
"""
"""
fix for 2emb model, re-place sum op
"""
block_length = len(program_block_ops_list)
'''
## get the final part
final_part_idx = -1
for i in range(block_length):
op_list = program_block_ops_list[i]
for op in op_list:
if "_grad" in op.type:
final_part_idx = i
break
if final_part_idx != -1:
break
## eliminate wrong partition because of sum op
## lookup_table_v2_grad
## every lookup_table_v2_grad op block should follow a sum op
var2idx = {}
for i in range(final_part_idx, block_length):
op_list = program_block_ops_list[i]
for j in range(len(op_list) - 1, -1, -1):
op = op_list[j]
#if op.type == "lookup_table_v2_grad":
# if j < len(op_list) - 1):
# else:
# ## get var and record place
if _grad in op.type:
forward_op_type = op.type.split("_grad")[0]
if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
var2idx[] = [i,j] ##
'''
union_program_block_ops_list = []
assert (
block_length % 2 != 0
), "the length of program_block_ops_list should be odd"
for i in range(0, block_length // 2):
block_op_list = {"forward": program_block_ops_list[i]}
block_op_list.update(
{"backward": program_block_ops_list[block_length - 1 - i]}
)
union_program_block_ops_list.append(block_op_list)
block_op_list = {"forward": [], "backward": []}
for op in program_block_ops_list[block_length // 2]:
if not "_grad" in op.type and not (op.type == "sum"):
block_op_list["forward"].append(op)
else:
block_op_list["backward"].append(op)
union_program_block_ops_list.append(block_op_list)
return union_program_block_ops_list
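# Pairing sketch: given 5 op blocks [f0, f1, m, b1, b0] (the list length must
# be odd), the result is
#   [{"forward": f0, "backward": b0},
#    {"forward": f1, "backward": b1},
#    {"forward": non-grad ops of m, "backward": grad/sum ops of m}]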
def find_block_joints(program, program_block_ops_list, heter_ops):
block_var_detail = find_entrance_exit_private(
program, program_block_ops_list
)
block_var_detail = entrance_exit_check(
program, program_block_ops_list, block_var_detail, heter_ops
)
block_var_detail = delete_block_useless_exit(
program, program_block_ops_list, block_var_detail
)
return block_var_detail
def find_entrance_exit_private(program, program_block_ops_list):
block_var_detail = []
persistables = []
for index, block_op_list in enumerate(program_block_ops_list):
## forward
block_input, block_output = find_ops_list_input_output(
program, block_op_list["forward"]
)
persistables = screen_persistables(
program, block_input
) + screen_persistables(program, block_output)
# find entrance & exit
block_private_vars = list(set(block_input) & set(block_output))
block_entrance = list(set(block_input) - set(block_private_vars))
block_exit = list(set(block_output) - set(block_private_vars))
detail = {
"forward": {
"entrance": block_entrance,
"exit": block_exit,
"private": block_private_vars,
"persistables": persistables,
}
}
## backward
bp_block_input, bp_block_output = find_ops_list_input_output(
program, block_op_list["backward"]
)
bp_persistables = screen_persistables(
program, bp_block_input
) + screen_persistables(program, bp_block_output)
# find entrance & exit
bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output))
bp_block_entrance = list(
set(bp_block_input) - set(bp_block_private_vars)
)
bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars))
detail.update(
{
"backward": {
"entrance": bp_block_entrance,
"exit": bp_block_exit,
"private": bp_block_private_vars,
"persistables": bp_persistables,
}
}
)
block_var_detail.append(detail)
return block_var_detail
def entrance_exit_check(
program, program_block_ops_list, block_var_detail, heter_ops
):
for index in range(len(block_var_detail) - 1, -1, -1):
if index - 1 < 0:
break
previous_block_exit = block_var_detail[index - 1]["forward"]["exit"]
previous_block_exit.sort()
current_block_entrance = block_var_detail[index]["forward"]["entrance"]
backward_entrance = block_var_detail[index]["backward"]["entrance"]
forward_all = (
block_var_detail[index]["forward"]["entrance"]
+ block_var_detail[index]["forward"]["private"]
+ block_var_detail[index]["forward"]["exit"]
)
for var in backward_entrance:
if not ("@GRAD" in var) and not (var in forward_all):
current_block_entrance.append(var)
current_block_entrance.sort()
if previous_block_exit == current_block_entrance:
continue
exist_vars = list(
set(previous_block_exit) & set(current_block_entrance)
)
need_add_vars = list(set(current_block_entrance) - set(exist_vars))
# var in different stage should not be ignored, since they are not placed in the same program & device
# need_add_vars = find_need_var_from_previous_block(
# need_add_vars, block_var_detail, index, heter_ops)
previous_block_private = block_var_detail[index - 1]["forward"][
"private"
]
previous_block_entrance = block_var_detail[index - 1]["forward"][
"entrance"
]
for var in need_add_vars:
if (
var not in previous_block_private
and var not in previous_block_entrance
):
previous_block_entrance.append(var)
previous_block_exit.append(var)
if var not in current_block_entrance:
current_block_entrance.append(var)
for index in range(0, len(block_var_detail) - 1, 1):
previous_block_exit = block_var_detail[index + 1]["backward"]["exit"]
previous_block_exit.sort()
current_block_entrance = block_var_detail[index]["backward"]["entrance"]
current_block_entrance.sort()
if previous_block_exit == current_block_entrance:
continue
exist_vars = list(
set(previous_block_exit) & set(current_block_entrance)
)
need_add_vars = list(set(current_block_entrance) - set(exist_vars))
need_ignore_vars = []
for var in need_add_vars:
if not "@GRAD" in var:
need_ignore_vars.append(var)
need_add_vars = list(
set(need_add_vars).difference(set(need_ignore_vars))
)
previous_block_private = block_var_detail[index + 1]["backward"][
"private"
]
previous_block_entrance = block_var_detail[index + 1]["backward"][
"entrance"
]
for var in need_add_vars:
if (
var not in previous_block_private
and var not in previous_block_entrance
):
previous_block_entrance.append(var)
previous_block_exit.append(var)
return block_var_detail
def find_need_var_from_previous_block(
need_add_vars, block_var_detail, current_index, heter_ops
):
# create index_device_map
index_device_map = {}
for index in range(len(block_var_detail)):
index_device_map[index] = DEFAULT_DEVICE
for device in heter_ops:
for index in heter_ops[device].keys():
if index < len(block_var_detail):
index_device_map[index] = device
pre_index = current_index - 1
need_ignore_var = []
# if need_add_var in current device, no need communicate
for var in need_add_vars:
while pre_index >= 0:
previous_block_private = block_var_detail[pre_index]["private"]
previous_block_exit = block_var_detail[pre_index]["exit"]
previous_block_entrance = block_var_detail[pre_index]["entrance"]
total_var = (
previous_block_private
+ previous_block_exit
+ previous_block_entrance
)
if var in total_var:
if (
index_device_map[current_index]
== index_device_map[pre_index]
and index_device_map[current_index] == DEFAULT_DEVICE
):
need_ignore_var.append(var)
break
pre_index -= 1
need_add_vars = list(set(need_add_vars).difference(set(need_ignore_var)))
return need_add_vars
def delete_block_useless_exit(
program, program_block_ops_list, block_var_detail
):
## forward
for index in range(len(block_var_detail)):
if index == len(block_var_detail) - 1:
break
current_block_exit = block_var_detail[index]["forward"]["exit"]
next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"]
need_delete_var = []
for var in current_block_exit:
if var not in next_block_entrance:
need_delete_var.append(var)
for var in need_delete_var:
current_block_exit.remove(var)
## backward
for index in range(len(block_var_detail) - 1, -1, -1):
if index - 1 < 0:
break
current_block_exit = block_var_detail[index]["backward"]["exit"]
next_block_entrance = block_var_detail[index - 1]["backward"][
"entrance"
]
need_delete_var = []
for var in current_block_exit:
if var not in next_block_entrance:
need_delete_var.append(var)
for var in need_delete_var:
current_block_exit.remove(var)
return block_var_detail
def check_op_device(block, device):
for op in block.ops:
op._set_attr('op_device', device)
def screen_persistables(program, var_list):
need_remove = []
for var_name in var_list:
if "@GRAD" in var_name:
if "GRAD" != var_name.split("@")[-1]:
continue
origin_var_name = var_name.split("@GRAD")[0]
var = program.global_block().vars[origin_var_name]
else:
var = program.global_block().vars[var_name]
if paddle.static.is_persistable(var):
need_remove.append(var_name)
for var_name in need_remove:
var_list.remove(var_name)
return need_remove
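# Note: screen_persistables mutates var_list in place, removing persistable
# vars (or @GRAD names whose base var is persistable), and returns the list of
# removed names.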
def insert_reshape_op(
program, block, index, var_name, new_var_name, new_var_shape=None
):
input_var = block.vars[var_name]
if new_var_name not in block.vars:
out = block.create_var(
name=new_var_name,
shape=new_var_shape,
dtype=input_var.dtype,
type=input_var.type,
)
else:
out = block.vars[new_var_name]
new_var_shape = out.shape
x_shape = block.create_var(
name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype
)
block._insert_op(
index=index,
type="reshape2",
inputs={"X": input_var},
attrs={'shape': new_var_shape},
outputs={"Out": out, "XShape": x_shape},
)
def insert_send_concat_op(
program, block, index, var_name_list, new_var_name, new_var_shape
):
input_var_list = [block.vars[var_name] for var_name in var_name_list]
out = program.global_block().create_var(
name=new_var_name,
shape=new_var_shape,
dtype=input_var_list[0].dtype,
type=input_var_list[0].type,
)
block._insert_op(
index=index,
type='concat',
inputs={"X": input_var_list},
outputs={'Out': [out]},
attrs={'axis': -1, 'use_stack': False},
)
def insert_recv_slice_op(
program,
block,
index,
var_name,
var_shape,
dtype,
type,
new_var_name_list,
new_var_shape_list,
):
if var_name not in program.global_block().vars:
input_var = program.global_block().create_var(
name=var_name, shape=var_shape, dtype=dtype, type=type
)
else:
input_var = program.global_block().vars[var_name]
out_list = []
for i in range(len(new_var_name_list)):
if new_var_name_list[i] not in block.vars:
out = block.create_var(
name=new_var_name_list[i],
shape=new_var_shape_list[i],
dtype=input_var.dtype,
type=input_var.type,
)
else:
out = block.vars[new_var_name_list[i]]
out_list.append(out)
start_index = 0
end_index = 0
for i in range(len(new_var_name_list)):
starts = []
ends = []
attrs = {'axes': [1]}
end_index += new_var_shape_list[i][1]
starts.append(start_index)
ends.append(end_index)
attrs['starts'] = starts
attrs['ends'] = ends
block._insert_op(
index=index,
type='slice',
inputs={'Input': input_var},
attrs=attrs,
outputs={'Out': out_list[i]},
)
start_index = end_index
index += 1
def add_heter_trainer_useful_vars(
config, program, heter_program, heter_block, static_var
):
static_var = list(set(static_var))
for var_name in static_var:
if (
var_name not in heter_program.global_block().vars
and var_name not in heter_block.vars
):
var = program.global_block().vars[var_name]
if var.persistable:
heter_program.global_block()._clone_variable(
var, force_persistable=False
)
else:
heter_block._clone_variable(var, force_persistable=False)
def delete_trainer_useless_var(config, program, static_var):
static_var = list(set(static_var))
program_useful_var_list = []
for op in program.global_block().ops:
input_var_list, output_var_list = find_op_input_output(
program, program.global_block(), op
)
op_var_list = list(set(input_var_list).union(set(output_var_list)))
program_useful_var_list = list(
set(program_useful_var_list).union(set(op_var_list))
)
program_useful_var_list += static_var
program_useless_var_list = list(
set(get_vars_name_in_block(program.global_block())).difference(
set(program_useful_var_list)
)
)
for var in program_useless_var_list:
program.global_block()._remove_var(var)
return program_useless_var_list
def block_append_op(program, origin_program, block, op):
merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars)
inputs = _get_input_map_from_op(merge_ordereddict, op)
for key, varlist in inputs.items():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if (
var.name not in program.global_block().vars
and var.name not in block.vars
):
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False
)
else:
block._clone_variable(var, force_persistable=False)
outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
for key, varlist in outputs.items():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if (
var.name not in program.global_block().vars
and var.name not in block.vars
):
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False
)
else:
block._clone_variable(var, force_persistable=False)
if "_grad" not in op.type:
# for forward op
return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()
)
else:
# for grad op
op_desc = op.desc
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
# append grad op
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op_desc)
new_op_desc._set_attr(op_role_attr_name, backward)
# set device guard
if op.desc.has_attr(device_attr_name):
op_device = op_desc.attr(device_attr_name)
new_op_desc._set_attr(device_attr_name, op_device)
block._sync_with_cpp()
def add_vars_by_var_list(var_name_list, origin_program, program, block):
for var_name in var_name_list:
if (
var_name not in program.global_block().vars
and var_name not in block.vars
):
var = origin_program.global_block().vars[var_name]
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False
)
else:
block._clone_variable(var, force_persistable=False)
def get_varlist_from_op_map(var_map):
var_list = []
for key, varlist in var_map.items():
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
var_list.append(var.name)
return var_list
def find_ops_list_input_output(program, ops_list):
input_var_list = []
output_var_list = []
for op in ops_list:
inputs = _get_input_map_from_op(program.global_block().vars, op)
input_var_list += get_varlist_from_op_map(inputs)
outputs = _get_output_map_from_op(program.global_block().vars, op)
output_var_list += get_varlist_from_op_map(outputs)
input_var_list = list(set(input_var_list))
output_var_list = list(set(output_var_list))
return input_var_list, output_var_list
def find_op_input_output(program, block, op):
input_var_list = []
output_var_list = []
inputs = _get_input_map_from_op(block.vars, op)
input_var_list += get_varlist_from_op_map(inputs)
outputs = _get_output_map_from_op(block.vars, op)
output_var_list += get_varlist_from_op_map(outputs)
input_var_list = list(set(input_var_list))
output_var_list = list(set(output_var_list))
return input_var_list, output_var_list
def get_vars_name_in_block(block):
vars_list = block.vars.keys()
vars_name_list = [var_name for var_name in vars_list]
return vars_name_list
def is_same_op(op1, op2):
return str(op1) == str(op2)
def _get_input_map_from_op(varmap, op):
"""Returns a dict from op input name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.input_names:
vars = []
for varname in op.input(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
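# Example (hypothetical op): for a concat op with input X -> ["a", "b"] and
# AxisTensor -> ["axis"], the returned map is roughly
#   OrderedDict([("X", [var_a, var_b]), ("AxisTensor", var_axis)])
# i.e. single-element inputs are unwrapped to the bare Variable.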
def _get_output_map_from_op(varmap, op):
"""Returns a dict from op output name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.output_names:
vars = []
for varname in op.output(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def delete_same_ops(block, ops):
for op in ops:
try:
for origin_op in block.ops:
if is_same_op(origin_op, op):
idx = list(block.ops).index(origin_op)
block._remove_op(idx)
break
except Exception as e:
print(e)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind:
"""Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): the initial element list.
"""
def __init__(self, elements=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elements:
elements = []
for ele in elements:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of the given element x,
# compressing the path while finding the root index.
if x not in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union the two given elements.
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
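# Minimal usage sketch (hypothetical elements):
#   uf = UnionFind(["a", "b", "c"])
#   uf.union("a", "b")
#   uf.is_connected("a", "b")  # True
#   uf.is_connected("a", "c")  # False
#   uf.find("missing")         # -1 for unknown elements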
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
from paddle.framework.io import Variable
from paddle.framework import core
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
class VarBlock:
def __init__(self, varname, offset, size):
self.varname = varname
# NOTE: real offset is offset * size
self.offset = offset
self.size = size
def __str__(self):
return "%s:%d:%d" % (self.varname, self.offset, self.size)
def create_var_struct(var):
if var.type == core.VarDesc.VarType.SELECTED_ROWS:
lod_level = None
elif var.type == core.VarDesc.VarType.LOD_TENSOR:
lod_level = var.lod_level
else:
raise ValueError("can only support SELECTED_ROWS/LOD_TENSOR now")
return VarStruct(
var.name, var.shape, var.dtype, var.type, lod_level, var.persistable
)
class VarStruct:
"""
Record a subset of a Variable's properties in Python.
"""
def __init__(self, name, shape, dtype, type, lod_level, persistable):
self.name = name
self.shape = shape
self.dtype = dtype
self.type = type
self.lod_level = lod_level
self.persistable = persistable
self.m_size = 1
self.m_size = reduce(lambda x, y: x * y, shape)
self.m_size *= dtype_to_size[dtype]
def __str__(self):
return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}, M: {}".format(
self.name,
self.shape,
self.dtype,
self.type,
self.lod_level,
self.persistable,
self.m_size,
)
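# Example (assumed dtype): a VarStruct with shape (128, 64) and dtype FP32 has
# m_size = 128 * 64 * 4 = 32768 bytes.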
class VarDistributed:
"""
A class to record how a var is distributed on parameter servers.
It records the relationship between the origin var and its slice var,
together with the slice var's properties such as type/shape/offset/endpoint.
"""
def __init__(
self,
origin_var,
slice_var,
is_slice=None,
block_id=None,
offset=None,
vtype=None,
endpoint=None,
):
"""
Args:
origin_var(Variable|VarStruct): origin var properties
slice_var(Variable|VarStruct): slice var properties
is_slice(bool|None): whether the var is sliced; a var is treated as a slice when it differs from the origin var and its block size is larger than 8192.
block_id(int|None): the block index of the slice var.
offset(int|None): if the var is sliced, the numel offset of this slice within the origin var.
vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
endpoint(str|None): which parameter server the slice var is placed on, such as "127.0.0.1:1001".
"""
if isinstance(origin_var, Variable):
self.origin = create_var_struct(origin_var)
else:
self.origin = origin_var
if isinstance(slice_var, Variable):
self.slice = create_var_struct(slice_var)
else:
self.slice = slice_var
if self.equal(self.origin, self.slice):
self.is_slice = False
self.block_id = 0
self.offset = 0
else:
self.is_slice = True
self.block_id = 0
self.offset = 0
if is_slice is not None:
self.is_slice = is_slice
if block_id is not None:
self.block_id = block_id
if offset is not None:
self.offset = offset
self.vtype = vtype
self.endpoint = endpoint
@staticmethod
def equal(var1, var2):
"""
whether the two vars are equal.
Returns:
bool: True if they are equal, otherwise False.
"""
assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct)
return (
var1.name == var2.name
and var1.type == var2.type
and var1.shape == var2.shape
and var1.dtype == var2.dtype
and var1.lod_level == var2.lod_level
and var1.persistable == var2.persistable
)
def __str__(self):
origin_var_str = (
"{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
i="{",
e="}",
name=self.origin.name,
type=self.origin.type,
shape=self.origin.shape,
dtype=self.origin.dtype,
)
)
slice_var_str = (
"{name} : fluid.{type}.shape{shape}.astype({dtype})"
".slice({is_slice}).block({block_id}).offset({offset})".format(
i="{",
e="}",
name=self.slice.name,
type=self.slice.type,
shape=self.slice.shape,
dtype=self.slice.dtype,
is_slice=self.is_slice,
block_id=self.block_id,
offset=self.offset,
)
)
return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format(
self.vtype, origin_var_str, slice_var_str, self.endpoint
)
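# Note: when origin and slice describe the same variable (VarDistributed.equal
# returns True), the var is treated as unsliced (is_slice=False, block_id=0,
# offset=0) unless the constructor arguments explicitly override these fields.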
class VarsDistributed:
"""
A collection of VarDistributed objects with methods to find distributed vars.
Through this class, we can get an overview of the distributed parameters on parameter servers.
It centralizes the information and makes it convenient for developers to manage and query variable distribution.
Other modules, such as io.py, can also use it to look up variables.
"""
def __init__(self):
self.distributed_vars = []
def add_distributed_var(
self,
origin_var,
slice_var,
is_slice=None,
block_id=None,
offset=None,
vtype=None,
endpoint=None,
):
"""
Add a distributed var to this collection.
Args:
origin_var(Variable|VarStruct): origin var properties
slice_var(Variable|VarStruct): slice var properties
is_slice(bool|None): whether the var is sliced; a var is treated as a slice when it differs from the origin var and its block size is larger than 8192.
block_id(int|None): the block index of the slice var.
offset(int|None): if the var is sliced, the numel offset of this slice within the origin var.
vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
endpoint(str|None): which parameter server the slice var is placed on, such as "127.0.0.1:1001".
Returns:
None
"""
self.distributed_vars.append(
VarDistributed(
origin_var,
slice_var,
is_slice,
block_id,
offset,
vtype,
endpoint,
)
)
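# Usage sketch (hypothetical vars): record how a parameter is split across
# parameter servers.
#   dist_vars = VarsDistributed()
#   dist_vars.add_distributed_var(origin_var=w, slice_var=w_block0,
#                                 is_slice=True, block_id=0, offset=0,
#                                 vtype="Param", endpoint="127.0.0.1:6170")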
@@ -444,7 +444,7 @@ class DnnTrainer:
 print(
 "entering run {} - old".format(str(config["applied_pass_name"]))
 )
-from paddle.fluid.incubate.fleet.parameter_server.ir import (
+from paddle.incubate.fleet.parameter_server.ir import (
 public as public,
 )
@@ -458,7 +458,7 @@ class DnnTrainer:
 _main = compiled_config.origin_main_program.clone()
 _startup = compiled_config.origin_startup_program.clone()
-from paddle.fluid.incubate.fleet.parameter_server.ir import (
+from paddle.incubate.fleet.parameter_server.ir import (
 trainer_pass as worker,
 )
......
@@ -15,7 +15,7 @@
 import unittest
 from paddle.fluid.framework import default_main_program
-from paddle.fluid.incubate.fleet.parameter_server.ir.pserver_pass import (
+from paddle.incubate.fleet.parameter_server.ir.pserver_pass import (
 _get_optimizer_input_shape,
 )
......
@@ -14,7 +14,7 @@
 import unittest
-from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import (
+from paddle.incubate.fleet.parameter_server.ir.ps_dispatcher import (
 HashName,
 PSDispatcher,
 RoundRobin,
......
@@ -14,7 +14,8 @@
 import collections
-from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir.public import (
 _get_lr_ops,
 _get_optimize_ops,
 _get_varname_parts,
@@ -23,7 +24,6 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
 get_sparse_tablenames,
 is_distributed_sparse_op,
 )
-from paddle.framework import core
 LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
......
@@ -19,12 +19,10 @@ import warnings
 from functools import reduce
 import paddle
-from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
-from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import (
-RoundRobin,
-)
 from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
 from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir import vars_metatools
+from paddle.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin
 OP_NAME_SCOPE = "op_namescope"
 CLIP_OP_NAME_SCOPE = "gradient_clip"
......
@@ -21,12 +21,12 @@ from functools import reduce
 import paddle
 import paddle.framework as framework
 from paddle.distributed.transpiler.details.program_utils import delete_ops
-from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
+from paddle.framework import core
+from paddle.incubate.fleet.parameter_server.ir.public import (
 _get_lr_ops,
 _get_optimize_ops,
 get_sparse_tablenames,
 )
-from paddle.framework import core
 from paddle.incubate.fleet.parameter_server.mode import DistributedMode
 OP_NAME_SCOPE = "op_namescope"
......
@@ -407,7 +407,6 @@ packages=['paddle',
 'paddle.fluid.incubate.fleet.base',
 'paddle.fluid.incubate.fleet.collective',
 'paddle.fluid.incubate.fleet.utils',
-'paddle.fluid.incubate.fleet.parameter_server.ir',
 'paddle.fluid.incubate.fleet.parameter_server',
 'paddle.amp',
 'paddle.cost_model',
......
@@ -1301,7 +1301,6 @@ def get_setup_parameters():
 'paddle.fluid.incubate.fleet.collective',
 'paddle.fluid.incubate.fleet.utils',
 'paddle.fluid.incubate.fleet.parameter_server',
-'paddle.fluid.incubate.fleet.parameter_server.ir',
 'paddle.amp',
 'paddle.cost_model',
 'paddle.hapi',
......