未验证 提交 29c63d18 编写于 作者: W Wu Yi 提交者: GitHub

[Feature] dist op role and lr op role, to support memory optimize with dist training (#13220)

* wip

* clean up

* should fix running with memopt

* add ut

* mark lr schedule op role

* hide lr_schedule_guard

* use op_role_var instead of ufind

* unify dist test name

* wip for py3 support

* fix var deref

* fix python3 mem_opt order

* remove comments
上级 2d97903a
......@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
return recv_vars;
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
}
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
......@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
}
}
is_dist_train = true;
} else if (IsDistTrainOp(node, send_vars, recv_vars)) {
} else if (boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node);
if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
......@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.emplace(varname, op_dev_id);
}
} else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW(
"the distribute training related op should be in [split_byref, "
"concat].");
......
......@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const;
......
......@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward),
......
......@@ -26,7 +26,13 @@ enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
// RPC role is for send/recv releated op
kRPC = 0x0003,
// Dist role is for split_byref/split_selected_rows/concat
// used for distributed training.
kDist = 0x0004,
// Tag all learning rate scheduler operators.
kLRSched = 0x0005,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
......@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto server_var = GetVar();
if (!server_var) {
LOG(ERROR) << "recved var should not on current server: "
<< meta_.varname();
return false;
}
auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
framework::LoD lod;
for (int i = 0; i < meta_.lod_level(); ++i) {
framework::Vector<size_t> v;
......@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(
void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
......
......@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
.value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
.value("RPC", framework::OpRole::kRPC)
.value("Dist", framework::OpRole::kDist)
.value("LRSched", framework::OpRole::kLRSched);
op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
......@@ -1509,6 +1509,30 @@ class Program(object):
self._op_role_var = []
self._current_role = OpRole.Forward
@contextlib.contextmanager
def _lr_schedule_guard(self):
"""
A with guard to set :code:`LRSched` :code:`OpRole` and
:code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
set to the target learning rate.
Notes: This is a very low level API. Users should not use it directly.
Examples:
>>> p, g = backward(...)
>>> with program.lr_schedule_guard():
>>> lr = lr * decay
"""
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.LRSched
# TODO(typhoonzero): how to set target learning rate var
self._op_role_var = []
yield
self._op_role_var = []
self._current_role = OpRole.Forward
def __str__(self):
"""
Get the protobuf debug string of this Program.
......
......@@ -27,7 +27,7 @@ from . import nn
from . import ops
from . import tensor
from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter
from ..framework import default_main_program, Parameter, unique_name
__all__ = [
'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
......@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps):
Returns:
The decayed learning rate.
"""
global_step = _decay_step_counter(1)
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter(1)
a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
return lr_value
......@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost)
"""
global_step = _decay_step_counter()
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res)
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res)
return decayed_lr
return decayed_lr
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
......@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
Returns:
The decayed learning rate
"""
global_step = _decay_step_counter()
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
return decayed_lr
return decayed_lr
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
......@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
staircase=True))
sgd_optimizer.minimize(avg_cost)
"""
global_step = _decay_step_counter()
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate / (1 + decay_rate * div_res)
decayed_lr = learning_rate / (1 + decay_rate * div_res)
return decayed_lr
return decayed_lr
def polynomial_decay(learning_rate,
......@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate,
Returns:
Variable: The decayed learning rate
"""
global_step = _decay_step_counter()
if cycle:
div_res = ops.ceil(global_step / decay_steps)
zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
with control_flow.Switch() as switch:
with switch.case(global_step == zero_var):
tensor.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
decay_steps_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(decay_steps))
global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
if cycle:
div_res = ops.ceil(global_step / decay_steps)
zero_var = tensor.fill_constant(
shape=[1], dtype='float32', value=0.0)
one_var = tensor.fill_constant(
shape=[1], dtype='float32', value=1.0)
with control_flow.Switch() as switch:
with switch.case(global_step == zero_var):
tensor.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
decay_steps_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(decay_steps))
global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr
decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr
def piecewise_decay(boundaries, values):
......@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values):
"""
with default_main_program()._lr_schedule_guard():
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
global_step = _decay_step_counter()
global_step = _decay_step_counter()
lr = tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
lr = tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
with control_flow.Switch() as switch:
for i in range(len(boundaries)):
boundary_val = tensor.fill_constant(
with control_flow.Switch() as switch:
for i in range(len(boundaries)):
boundary_val = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[len(values) - 1]))
with switch.default():
tensor.assign(last_value_var, lr)
value=float(values[len(values) - 1]))
with switch.default():
tensor.assign(last_value_var, lr)
return lr
......
......@@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase):
self._sync_mode = True
self._use_reduce = False
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7)
......@@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
self._sync_mode = True
self._mem_opt = True
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7)
......@@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False
self._use_reduce = False
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200)
......
......@@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7)
class TestDistseResnXt2x2WithMemopt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7)
......@@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100)
......
......@@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_transformer(self):
def test_dist_train(self):
download_files()
self.check_with_place("dist_transformer.py", delta=1e-5)
......@@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_transformer(self):
def test_dist_train(self):
download_files()
self.check_with_place("dist_transformer.py", delta=1.0)
......
......@@ -17,19 +17,28 @@ import unittest
from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase):
class TestDistW2V2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4)
class TestDistSeResneXt2x2Async(TestDistBase):
class TestDistW2V2x2WithMemOpt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4)
class TestDistW2V2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_se_resnext(self):
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1)
......
......@@ -21,13 +21,12 @@ import paddle
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block._remove_op(start) for _ in six.moves.range(end - start + 1)]
except Exception as e:
raise e
block.program._sync_with_cpp()
for op in ops:
try:
idx = list(block.ops).index(op)
block._remove_op(idx)
except Exception as e:
print(e)
def find_op_by_input_arg(block, arg_name):
......@@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name):
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
def find_op_by_output_arg(block, arg_name, reverse=False):
if reverse:
pos = len(block.ops) - 1
while pos >= 0:
op = block.ops[pos]
if arg_name in op.output_arg_names:
return pos
pos -= 1
else:
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
......
......@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG = False
def log(*args):
if PRINT_LOG:
print(args)
class VarBlock:
......@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object):
slice_var_up = True
split_method = None
min_block_size = 8192
print_log = False
class DistributeTranspiler(object):
......@@ -174,6 +184,9 @@ class DistributeTranspiler(object):
if self.config.split_method is None:
self.config.split_method = RoundRobin
global PRINT_LOG
if self.config.print_log:
PRINT_LOG = True
assert (self.config.min_block_size >= 8192)
assert (self.config.split_method.__bases__[0] == PSDispatcher)
......@@ -257,12 +270,12 @@ class DistributeTranspiler(object):
splited_grad_varname = grad_varname
if len(splited_vars) == 1:
splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(),
splited_grad_varname)
index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True)
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg(program.global_block(),
splited_grad_varname)
index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True)
self._insert_split_op(program, orig_var, index, splited_vars)
index += 1
else:
......@@ -301,7 +314,7 @@ class DistributeTranspiler(object):
self.grad_name_to_send_dummy_out[
self.table_name] = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = self.grad_name_to_send_dummy_out.values()
input_deps = list(self.grad_name_to_send_dummy_out.values())
program.global_block().append_op(
type="send_barrier",
......@@ -377,7 +390,10 @@ class DistributeTranspiler(object):
type="concat",
inputs={"X": splited_var},
outputs={"Out": [orig_param]},
attrs={"axis": 0})
attrs={
"axis": 0,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
......@@ -496,9 +512,9 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable is not distributed
# on the same pserver, only change param/grad varnames for
# trainers to fetch.
sys.stderr.write("get_pserver_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
sys.stderr.write("get_pserver_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup \
in a single call.")
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
......@@ -615,22 +631,31 @@ class DistributeTranspiler(object):
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program._create_block(pre_block_idx)
optimize_blocks.append(per_opt_block)
optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
# append grad merging ops before clip and weight decay
# cases may like:
# L2Decay op -> clip op -> optimize
# e.g. merge grad -> L2Decay op -> clip op -> optimize
merged_var = None
for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping
grad_varname_for_block = __op_have_grad_input__(op)
if ufind.is_connected(op, opt_op) and grad_varname_for_block:
# find the origin grad var before clipping/L2Decay,
# merged_var should be the input var name of L2Decaybuil
grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if op.attr(OP_ROLE_VAR_ATTR_NAME)[
0] == optimize_target_param_name:
merged_var = self._append_pserver_grad_merge_ops(
per_opt_block, grad_varname_for_block, endpoint,
grad_to_block_id, self.origin_program)
break # append optimize op once then append other ops.
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var, lr_ops)
if merged_var:
break # append optimize op once then append other ops.
if merged_var:
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
op not in global_ops:
log("append opt op: ", op.type, op.input_arg_names,
merged_var)
__append_optimize_op__(op, per_opt_block,
grad_to_block_id, merged_var,
lr_ops)
# dedup grad to ids list
grad_to_block_id = list(set(grad_to_block_id))
......@@ -726,17 +751,17 @@ class DistributeTranspiler(object):
Returns:
Program: parameter server side startup program.
"""
sys.stderr.write("get_startup_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
sys.stderr.write("get_startup_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup \
in a single call.")
if pserver_program != None:
sys.stderr.write("passing pserver_program to get_startup_program()\
is deprecated, you can use new API get_pserver_programs() to\
get both pserver main program and startup program.")
sys.stderr.write("passing pserver_program to get_startup_program() \
is deprecated, you can use new API get_pserver_programs() to \
get both pserver main program and startup program.")
if startup_program != None:
sys.stderr.write("passing startup_program to get_startup_program()\
is deprecated, use fluid.program_guard() or pass this argument\
to transpile() call.")
sys.stderr.write("passing startup_program to get_startup_program() \
is deprecated, use fluid.program_guard() or pass this argument \
to transpile() call.")
s_prog = Program()
orig_s_prog = self.startup_program
......@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object):
type="split_selected_rows",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"height_sections": height_sections})
attrs={
"height_sections": height_sections,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = []
for v in splited_vars:
......@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object):
type="split_byref",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"sections": sections} # assume split evenly
)
attrs={
"sections": sections,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
else:
AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]")
......@@ -1381,15 +1411,15 @@ class DistributeTranspiler(object):
if not grad_block:
# do not append this op if current endpoint
# is not dealing with this grad block
return
return None
orig_varname, block_name, trainer_name = self._get_varname_parts(
grad_block.name)
if block_name:
merged_var_name = '.'.join([orig_varname, block_name])
else:
merged_var_name = orig_varname
merged_var = \
pserver_block.vars[merged_var_name]
merged_var = pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
......@@ -1473,7 +1503,6 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op(
type=opt_op.type,
inputs=new_inputs,
......@@ -1618,6 +1647,16 @@ class DistributeTranspiler(object):
return iomap
def _get_lr_ops(self):
lr_ops = []
block = self.origin_program.global_block()
for op in block.ops:
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
LR_SCHED_OP_ROLE_ATTR_VALUE):
lr_ops.append(op)
log("append lr op: ", op.type)
return lr_ops
def _get_lr_ops_deprecated(self):
lr_ops = []
# find learning rate variables by optimize op
lr_vars = set()
......@@ -1670,20 +1709,21 @@ class DistributeTranspiler(object):
block = self.origin_program.global_block()
opt_ops = []
params_grads = []
# tmp set to dedup
optimize_params = set()
origin_var_dict = self.origin_program.global_block().vars
for op in block.ops:
if self._is_opt_role_op(op):
opt_ops.append(op)
# HACK(wuyi): if we find grad vars from input of optimize
# ops, we may get the output of clip op. Use syntax "@GRAD"
# and op_role_var to get the pair.
for input_name in op.input_arg_names:
if input_name.find("@GRAD") != -1 and \
op.attr(RPC_OP_ROLE_ATTR_NAME):
param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
if op.attr(OP_ROLE_VAR_ATTR_NAME):
param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if not param_name in optimize_params:
optimize_params.add(param_name)
log("adding param_grad pair: ", param_name, grad_name)
params_grads.append([
origin_var_dict[param_name],
origin_var_dict[input_name]
origin_var_dict[grad_name]
])
else:
pass
......
......@@ -14,10 +14,10 @@
from __future__ import print_function
from collections import defaultdict
from collections import defaultdict, OrderedDict, Callable
from .. import core
from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter
from ..framework import Program, default_main_program, Parameter, Variable
from ..backward import _rename_arg_
from functools import reduce
from six.moves import range
......@@ -113,8 +113,10 @@ class ControlFlowGraph(object):
def _fill_pool(self, i, is_forward):
block_desc = self._ops[i].block()
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
# NOTE: must sort the in_diff set for cases that get different cache var.
# FIXME(typhoonzero): maybe use a "sorted set" is better than this.
can_optimize = [
x for x in in_diff
x for x in sorted(list(in_diff))
if self._check_var_validity(block_desc, x, is_forward)
]
if can_optimize:
......@@ -220,8 +222,9 @@ class ControlFlowGraph(object):
block_desc = op.block()
is_forward = i < self._forward_num
if self.pool:
# NOTE: must sort the in_diff set for cases that get different cache var.
defs_can_optimize = [
x for x in self._defs[i]
x for x in sorted(list(self._defs[i]))
if self._check_var_validity(block_desc, x, is_forward)
]
out_pair = [
......@@ -271,6 +274,8 @@ class ControlFlowGraph(object):
self._program.block(block_desc.id).var(cpt.to_text(
x)).desc = self._find_var(block_desc, cache_var,
is_forward)
self._program.block(block_desc.id).vars[cpt.to_text(x)] = \
Variable(self._program.block(block_desc.id), name=cpt.to_text(x))
self._update_graph(x, cache_var, begin_idx=i)
break
self._fill_pool(i, is_forward)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册