Unverified  Commit 29c63d18 authored by Wu Yi, committed by GitHub

[Feature] dist op role and lr op role, to support memory optimize with dist training (#13220)

* wip

* clean up

* should fix running with memopt

* add ut

* mark lr schedule op role

* hide lr_schedule_guard

* use op_role_var instead of ufind

* unify dist test name

* wip for py3 support

* fix var deref

* fix python3 mem_opt order

* remove comments
Parent 2d97903a
...@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( ...@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
return recv_vars; return recv_vars;
} }
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
}
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const { const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0; int64_t numel_sum = 0;
...@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} }
} }
is_dist_train = true; is_dist_train = true;
} else if (IsDistTrainOp(node, send_vars, recv_vars)) { } else if (boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node); int op_dev_id = CreateDistTrainOp(&result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0]; auto origin_param_name = node->Op()->OutputArgumentNames()[0];
...@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW( PADDLE_THROW(
"the distribute training related op should be in [split_byref, " "the distribute training related op should be in [split_byref, "
"concat]."); "concat].");
......
...@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const; const std::vector<ir::Node *> &nodes) const;
......
...@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
{static_cast<int>(OpRole::kForward), {static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC), static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
......
...@@ -26,7 +26,13 @@ enum class OpRole { ...@@ -26,7 +26,13 @@ enum class OpRole {
kForward = 0x0000, kForward = 0x0000,
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
// RPC role is for send/recv related ops
kRPC = 0x0003, kRPC = 0x0003,
// Dist role is for split_byref/split_selected_rows/concat
// used for distributed training.
kDist = 0x0004,
// Tag all learning rate scheduler operators.
kLRSched = 0x0005,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
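The new kDist and kLRSched values are plain role codes (like kRPC), while kLoss is still a bit flag that gets OR-ed with kForward or kBackward, as the allowed values added to the attribute checker above show. Below is a minimal Python sketch of that encoding, using the integer values from the enum; it is illustrative only and not PaddlePaddle API.

```python
# Integer role codes copied from the C++ OpRole enum above; plain Python,
# not the PaddlePaddle API.
FORWARD, BACKWARD, OPTIMIZE = 0x0000, 0x0001, 0x0002
RPC, DIST, LR_SCHED = 0x0003, 0x0004, 0x0005   # plain values, compared by equality
LOSS = 0x0100                                   # bit flag, combined with FORWARD/BACKWARD

def is_dist_op(op_role):
    # Mirrors the graph builder's check above: the whole attribute equals kDist.
    return int(op_role) == DIST

def is_loss_op(op_role):
    # kLoss is OR-ed onto kForward or kBackward, so test the bit.
    return (int(op_role) & LOSS) != 0

assert is_dist_op(DIST)
assert is_loss_op(LOSS | BACKWARD) and not is_loss_op(BACKWARD)
```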
...@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData( ...@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims, const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) { int length) {
auto server_var = GetVar();
if (!server_var) {
LOG(ERROR) << "recved var should not on current server: "
<< meta_.varname();
return false;
}
auto* tensor = GetVar()->GetMutable<framework::LoDTensor>(); auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
tensor->Resize(dims); tensor->Resize(dims);
framework::LoD lod; framework::LoD lod;
for (int i = 0; i < meta_.lod_level(); ++i) { for (int i = 0; i < meta_.lod_level(); ++i) {
framework::Vector<size_t> v; framework::Vector<size_t> v;
...@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData( ...@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(
void* tensor_data = void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
} }
......
...@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) { ...@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
.value("Backward", framework::OpRole::kBackward) .value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize) .value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss) .value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC); .value("RPC", framework::OpRole::kRPC)
.value("Dist", framework::OpRole::kDist)
.value("LRSched", framework::OpRole::kLRSched);
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
...@@ -1509,6 +1509,30 @@ class Program(object): ...@@ -1509,6 +1509,30 @@ class Program(object):
self._op_role_var = [] self._op_role_var = []
self._current_role = OpRole.Forward self._current_role = OpRole.Forward
@contextlib.contextmanager
def _lr_schedule_guard(self):
"""
A with guard to set :code:`LRSched` :code:`OpRole` and
:code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
set to the target learning rate.
Notes: This is a very low level API. Users should not use it directly.
Examples:
>>> p, g = backward(...)
>>> with program._lr_schedule_guard():
>>> lr = lr * decay
"""
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.LRSched
# TODO(typhoonzero): how to set target learning rate var
self._op_role_var = []
yield
self._op_role_var = []
self._current_role = OpRole.Forward
def __str__(self): def __str__(self):
""" """
Get the protobuf debug string of this Program. Get the protobuf debug string of this Program.
......
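_lr_schedule_guard switches the program's current op role to LRSched while learning-rate ops are being appended and restores Forward afterwards, so the transpiler can later pick the schedule ops out by role alone (see _get_lr_ops further down). Here is a self-contained sketch of the same guard pattern over a toy program object; the class and field names are stand-ins rather than the fluid API, and the sketch adds a try/finally so the role is restored even if the body raises.

```python
import contextlib

class TinyProgram(object):
    """Toy stand-in for a program that records a role per appended op."""
    def __init__(self):
        self._current_role = "Forward"
        self._op_role_var = []
        self.ops = []

    @contextlib.contextmanager
    def _lr_schedule_guard(self):
        # Mirror the guard above: set LRSched on entry, restore Forward on exit.
        self._current_role = "LRSched"
        self._op_role_var = []
        try:
            yield
        finally:
            self._op_role_var = []
            self._current_role = "Forward"

    def append_op(self, op_type):
        # Every op appended inside the guard is stamped with the active role.
        self.ops.append((op_type, self._current_role))

prog = TinyProgram()
prog.append_op("mul")
with prog._lr_schedule_guard():
    prog.append_op("scale")          # gets role "LRSched"
prog.append_op("elementwise_add")

assert [role for _, role in prog.ops] == ["Forward", "LRSched", "Forward"]
```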
...@@ -27,7 +27,7 @@ from . import nn ...@@ -27,7 +27,7 @@ from . import nn
from . import ops from . import ops
from . import tensor from . import tensor
from ..initializer import init_on_cpu from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter from ..framework import default_main_program, Parameter, unique_name
__all__ = [ __all__ = [
'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
...@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps): ...@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps):
Returns: Returns:
The decayed learning rate. The decayed learning rate.
""" """
global_step = _decay_step_counter(1) with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter(1)
a = global_step**-0.5 a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * ops.elementwise_min(a, b) lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
return lr_value return lr_value
...@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
""" """
global_step = _decay_step_counter() with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps div_res = global_step / decay_steps
if staircase: if staircase:
div_res = ops.floor(div_res) div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res) decayed_lr = learning_rate * (decay_rate**div_res)
return decayed_lr return decayed_lr
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
...@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
Returns: Returns:
The decayed learning rate The decayed learning rate
""" """
global_step = _decay_step_counter() with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps div_res = global_step / decay_steps
if staircase: if staircase:
div_res = ops.floor(div_res) div_res = ops.floor(div_res)
decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res) decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
return decayed_lr return decayed_lr
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
...@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
staircase=True)) staircase=True))
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
""" """
global_step = _decay_step_counter() with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps div_res = global_step / decay_steps
if staircase: if staircase:
div_res = ops.floor(div_res) div_res = ops.floor(div_res)
decayed_lr = learning_rate / (1 + decay_rate * div_res) decayed_lr = learning_rate / (1 + decay_rate * div_res)
return decayed_lr return decayed_lr
def polynomial_decay(learning_rate, def polynomial_decay(learning_rate,
...@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate, ...@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate,
Returns: Returns:
Variable: The decayed learning rate Variable: The decayed learning rate
""" """
global_step = _decay_step_counter() with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
if cycle:
div_res = ops.ceil(global_step / decay_steps) if cycle:
zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0) div_res = ops.ceil(global_step / decay_steps)
one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0) zero_var = tensor.fill_constant(
shape=[1], dtype='float32', value=0.0)
with control_flow.Switch() as switch: one_var = tensor.fill_constant(
with switch.case(global_step == zero_var): shape=[1], dtype='float32', value=1.0)
tensor.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res with control_flow.Switch() as switch:
else: with switch.case(global_step == zero_var):
decay_steps_var = tensor.fill_constant( tensor.assign(input=one_var, output=div_res)
shape=[1], dtype='float32', value=float(decay_steps)) decay_steps = decay_steps * div_res
global_step = ops.elementwise_min(x=global_step, y=decay_steps_var) else:
decay_steps_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(decay_steps))
global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
decayed_lr = (learning_rate - end_learning_rate) * \ decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate ((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr return decayed_lr
def piecewise_decay(boundaries, values): def piecewise_decay(boundaries, values):
...@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values): ...@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values):
""" """
with default_main_program()._lr_schedule_guard():
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
if len(values) - len(boundaries) != 1: global_step = _decay_step_counter()
raise ValueError("len(values) - len(boundaries) should be 1")
global_step = _decay_step_counter()
lr = tensor.create_global_var( lr = tensor.create_global_var(
shape=[1], shape=[1],
value=0.0, value=0.0,
dtype='float32', dtype='float32',
persistable=True, persistable=True,
name="learning_rate") name="learning_rate")
with control_flow.Switch() as switch: with control_flow.Switch() as switch:
for i in range(len(boundaries)): for i in range(len(boundaries)):
boundary_val = tensor.fill_constant( boundary_val = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
shape=[1], shape=[1],
dtype='float32', dtype='float32',
value=float(boundaries[i]), value=float(values[len(values) - 1]))
force_cpu=True) with switch.default():
value_var = tensor.fill_constant( tensor.assign(last_value_var, lr)
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[len(values) - 1]))
with switch.default():
tensor.assign(last_value_var, lr)
return lr return lr
......
...@@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase): ...@@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase):
self._sync_mode = True self._sync_mode = True
self._use_reduce = False self._use_reduce = False
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-7)
...@@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase): ...@@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
self._sync_mode = True self._sync_mode = True
self._mem_opt = True self._mem_opt = True
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-7)
...@@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase): ...@@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7)
class TestDistseResnXt2x2WithMemopt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7) self.check_with_place("dist_se_resnext.py", delta=1e-7)
...@@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase): ...@@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
def test_transformer(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1e-5) self.check_with_place("dist_transformer.py", delta=1e-5)
...@@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase): ...@@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
def test_transformer(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1.0) self.check_with_place("dist_transformer.py", delta=1.0)
......
...@@ -17,19 +17,28 @@ import unittest ...@@ -17,19 +17,28 @@ import unittest
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase): class TestDistW2V2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4) self.check_with_place("dist_word2vec.py", delta=1e-4)
class TestDistSeResneXt2x2Async(TestDistBase): class TestDistW2V2x2WithMemOpt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4)
class TestDistW2V2x2Async(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
def test_se_resnext(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1) self.check_with_place("dist_word2vec.py", delta=1)
......
...@@ -21,13 +21,12 @@ import paddle ...@@ -21,13 +21,12 @@ import paddle
def delete_ops(block, ops): def delete_ops(block, ops):
try: for op in ops:
start = list(block.ops).index(ops[0]) try:
end = list(block.ops).index(ops[-1]) idx = list(block.ops).index(op)
[block._remove_op(start) for _ in six.moves.range(end - start + 1)] block._remove_op(idx)
except Exception as e: except Exception as e:
raise e print(e)
block.program._sync_with_cpp()
def find_op_by_input_arg(block, arg_name): def find_op_by_input_arg(block, arg_name):
...@@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name): ...@@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name):
return -1 return -1
def find_op_by_output_arg(block, arg_name): def find_op_by_output_arg(block, arg_name, reverse=False):
for index, op in enumerate(block.ops): if reverse:
if arg_name in op.output_arg_names: pos = len(block.ops) - 1
return index while pos >= 0:
op = block.ops[pos]
if arg_name in op.output_arg_names:
return pos
pos -= 1
else:
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1 return -1
......
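find_op_by_output_arg now accepts reverse=True so callers can find the last op that writes a variable rather than the first; the transpiler uses this below to insert split ops after gradient clipping has rewritten the grad var. A tiny stand-alone sketch of the reverse scan, with ops modeled as (type, output_names) tuples as a stand-in for block.ops:

```python
def find_last_writer(ops, arg_name):
    """Return the index of the last op whose outputs contain arg_name, or -1."""
    for pos in range(len(ops) - 1, -1, -1):
        if arg_name in ops[pos][1]:
            return pos
    return -1

ops = [
    ("mul_grad", ["fc_0.w_0@GRAD"]),
    ("clip", ["fc_0.w_0@GRAD"]),      # clipping rewrites the same grad var
]
assert find_last_writer(ops, "fc_0.w_0@GRAD") == 1   # last writer, not first
```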
...@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() ...@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG = False
def log(*args):
if PRINT_LOG:
print(args)
class VarBlock: class VarBlock:
...@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object): ...@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object):
slice_var_up = True slice_var_up = True
split_method = None split_method = None
min_block_size = 8192 min_block_size = 8192
print_log = False
class DistributeTranspiler(object): class DistributeTranspiler(object):
...@@ -174,6 +184,9 @@ class DistributeTranspiler(object): ...@@ -174,6 +184,9 @@ class DistributeTranspiler(object):
if self.config.split_method is None: if self.config.split_method is None:
self.config.split_method = RoundRobin self.config.split_method = RoundRobin
global PRINT_LOG
if self.config.print_log:
PRINT_LOG = True
assert (self.config.min_block_size >= 8192) assert (self.config.min_block_size >= 8192)
assert (self.config.split_method.__bases__[0] == PSDispatcher) assert (self.config.split_method.__bases__[0] == PSDispatcher)
...@@ -257,12 +270,12 @@ class DistributeTranspiler(object): ...@@ -257,12 +270,12 @@ class DistributeTranspiler(object):
splited_grad_varname = grad_varname splited_grad_varname = grad_varname
if len(splited_vars) == 1: if len(splited_vars) == 1:
splited_grad_varname = splited_vars[0].name splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(), index = find_op_by_output_arg(
splited_grad_varname) program.global_block(), splited_grad_varname, reverse=True)
elif len(splited_vars) > 1: elif len(splited_vars) > 1:
orig_var = program.global_block().vars[splited_grad_varname] orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg(program.global_block(), index = find_op_by_output_arg(
splited_grad_varname) program.global_block(), splited_grad_varname, reverse=True)
self._insert_split_op(program, orig_var, index, splited_vars) self._insert_split_op(program, orig_var, index, splited_vars)
index += 1 index += 1
else: else:
...@@ -301,7 +314,7 @@ class DistributeTranspiler(object): ...@@ -301,7 +314,7 @@ class DistributeTranspiler(object):
self.grad_name_to_send_dummy_out[ self.grad_name_to_send_dummy_out[
self.table_name] = program.global_block().create_var( self.table_name] = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
input_deps = self.grad_name_to_send_dummy_out.values() input_deps = list(self.grad_name_to_send_dummy_out.values())
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
...@@ -377,7 +390,10 @@ class DistributeTranspiler(object): ...@@ -377,7 +390,10 @@ class DistributeTranspiler(object):
type="concat", type="concat",
inputs={"X": splited_var}, inputs={"X": splited_var},
outputs={"Out": [orig_param]}, outputs={"Out": [orig_param]},
attrs={"axis": 0}) attrs={
"axis": 0,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
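The concat appended here (and the split ops further down) now carry the Dist op role in their attrs, which is what lets MultiDevSSAGraphBuilder place them by role instead of by the old `.block` suffix heuristic. A minimal sketch of that stamping-and-filtering pattern over a mock block follows; MockBlock and the literal attribute key are illustrative, while the real code reads the key and value from core.op_proto_and_checker_maker as shown earlier.

```python
class MockBlock(object):
    """Toy stand-in for a fluid block that records appended op descriptions."""
    def __init__(self):
        self.ops = []

    def append_op(self, **kwargs):
        self.ops.append(kwargs)

OP_ROLE_ATTR_NAME = "op_role"   # stands in for kOpRoleAttrName()
DIST_ROLE = 0x0004              # OpRole::kDist from the enum above

block = MockBlock()
block.append_op(
    type="concat",
    inputs={"X": ["w.block0", "w.block1"]},
    outputs={"Out": ["w"]},
    attrs={"axis": 0, OP_ROLE_ATTR_NAME: DIST_ROLE})

# A graph-building pass can now recognize dist ops by role alone:
dist_ops = [op for op in block.ops
            if op["attrs"].get(OP_ROLE_ATTR_NAME) == DIST_ROLE]
assert [op["type"] for op in dist_ops] == ["concat"]
```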
...@@ -496,9 +512,9 @@ class DistributeTranspiler(object): ...@@ -496,9 +512,9 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable is not distributed # NOTE: assume blocks of the same variable is not distributed
# on the same pserver, only change param/grad varnames for # on the same pserver, only change param/grad varnames for
# trainers to fetch. # trainers to fetch.
sys.stderr.write("get_pserver_program() is deprecated, call\ sys.stderr.write("get_pserver_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup\ get_pserver_programs() to get pserver main and startup \
in a single call.") in a single call.")
# step1 # step1
pserver_program = Program() pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed pserver_program.random_seed = self.origin_program.random_seed
...@@ -615,22 +631,31 @@ class DistributeTranspiler(object): ...@@ -615,22 +631,31 @@ class DistributeTranspiler(object):
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program._create_block(pre_block_idx) per_opt_block = pserver_program._create_block(pre_block_idx)
optimize_blocks.append(per_opt_block) optimize_blocks.append(per_opt_block)
optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
# append grad merging ops before clip and weight decay # append grad merging ops before clip and weight decay
# cases may like: # e.g. merge grad -> L2Decay op -> clip op -> optimize
# L2Decay op -> clip op -> optimize merged_var = None
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping # find the origin grad var before clipping/L2Decay,
grad_varname_for_block = __op_have_grad_input__(op) # merged_var should be the input var name of L2Decay
if ufind.is_connected(op, opt_op) and grad_varname_for_block: grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
if op.attr(OP_ROLE_VAR_ATTR_NAME)[
0] == optimize_target_param_name:
merged_var = self._append_pserver_grad_merge_ops( merged_var = self._append_pserver_grad_merge_ops(
per_opt_block, grad_varname_for_block, endpoint, per_opt_block, grad_varname_for_block, endpoint,
grad_to_block_id, self.origin_program) grad_to_block_id, self.origin_program)
break # append optimize op once then append other ops. if merged_var:
for _, op in enumerate(self.optimize_ops): break # append optimize op once then append other ops.
# optimizer is connected to itself if merged_var:
if ufind.is_connected(op, opt_op) and op not in global_ops: for _, op in enumerate(self.optimize_ops):
__append_optimize_op__(op, per_opt_block, grad_to_block_id, # optimizer is connected to itself
merged_var, lr_ops) if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
op not in global_ops:
log("append opt op: ", op.type, op.input_arg_names,
merged_var)
__append_optimize_op__(op, per_opt_block,
grad_to_block_id, merged_var,
lr_ops)
# dedup grad to ids list # dedup grad to ids list
grad_to_block_id = list(set(grad_to_block_id)) grad_to_block_id = list(set(grad_to_block_id))
...@@ -726,17 +751,17 @@ class DistributeTranspiler(object): ...@@ -726,17 +751,17 @@ class DistributeTranspiler(object):
Returns: Returns:
Program: parameter server side startup program. Program: parameter server side startup program.
""" """
sys.stderr.write("get_startup_program() is deprecated, call\ sys.stderr.write("get_startup_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup\ get_pserver_programs() to get pserver main and startup \
in a single call.") in a single call.")
if pserver_program != None: if pserver_program != None:
sys.stderr.write("passing pserver_program to get_startup_program()\ sys.stderr.write("passing pserver_program to get_startup_program() \
is deprecated, you can use new API get_pserver_programs() to\ is deprecated, you can use new API get_pserver_programs() to \
get both pserver main program and startup program.") get both pserver main program and startup program.")
if startup_program != None: if startup_program != None:
sys.stderr.write("passing startup_program to get_startup_program()\ sys.stderr.write("passing startup_program to get_startup_program() \
is deprecated, use fluid.program_guard() or pass this argument\ is deprecated, use fluid.program_guard() or pass this argument \
to transpile() call.") to transpile() call.")
s_prog = Program() s_prog = Program()
orig_s_prog = self.startup_program orig_s_prog = self.startup_program
...@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object): ...@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object):
type="split_selected_rows", type="split_selected_rows",
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
attrs={"height_sections": height_sections}) attrs={
"height_sections": height_sections,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = [] sections = []
for v in splited_vars: for v in splited_vars:
...@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object): ...@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object):
type="split_byref", type="split_byref",
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
attrs={"sections": sections} # assume split evenly attrs={
) "sections": sections,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
else: else:
AssertionError("Variable type should be in set " AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]") "[LOD_TENSOR, SELECTED_ROWS]")
...@@ -1381,15 +1411,15 @@ class DistributeTranspiler(object): ...@@ -1381,15 +1411,15 @@ class DistributeTranspiler(object):
if not grad_block: if not grad_block:
# do not append this op if current endpoint # do not append this op if current endpoint
# is not dealing with this grad block # is not dealing with this grad block
return return None
orig_varname, block_name, trainer_name = self._get_varname_parts( orig_varname, block_name, trainer_name = self._get_varname_parts(
grad_block.name) grad_block.name)
if block_name: if block_name:
merged_var_name = '.'.join([orig_varname, block_name]) merged_var_name = '.'.join([orig_varname, block_name])
else: else:
merged_var_name = orig_varname merged_var_name = orig_varname
merged_var = \
pserver_block.vars[merged_var_name] merged_var = pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx)) grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if self.sync_mode and self.trainer_num > 1: if self.sync_mode and self.trainer_num > 1:
vars2merge = [] vars2merge = []
...@@ -1473,7 +1503,6 @@ class DistributeTranspiler(object): ...@@ -1473,7 +1503,6 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
outputs["ParamOut"] = new_inputs["Param"] outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op( optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
inputs=new_inputs, inputs=new_inputs,
...@@ -1618,6 +1647,16 @@ class DistributeTranspiler(object): ...@@ -1618,6 +1647,16 @@ class DistributeTranspiler(object):
return iomap return iomap
def _get_lr_ops(self): def _get_lr_ops(self):
lr_ops = []
block = self.origin_program.global_block()
for op in block.ops:
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
LR_SCHED_OP_ROLE_ATTR_VALUE):
lr_ops.append(op)
log("append lr op: ", op.type)
return lr_ops
def _get_lr_ops_deprecated(self):
lr_ops = [] lr_ops = []
# find learning rate variables by optimize op # find learning rate variables by optimize op
lr_vars = set() lr_vars = set()
...@@ -1670,20 +1709,21 @@ class DistributeTranspiler(object): ...@@ -1670,20 +1709,21 @@ class DistributeTranspiler(object):
block = self.origin_program.global_block() block = self.origin_program.global_block()
opt_ops = [] opt_ops = []
params_grads = [] params_grads = []
# tmp set to dedup
optimize_params = set()
origin_var_dict = self.origin_program.global_block().vars origin_var_dict = self.origin_program.global_block().vars
for op in block.ops: for op in block.ops:
if self._is_opt_role_op(op): if self._is_opt_role_op(op):
opt_ops.append(op) opt_ops.append(op)
# HACK(wuyi): if we find grad vars from input of optimize if op.attr(OP_ROLE_VAR_ATTR_NAME):
# ops, we may get the output of clip op. Use syntax "@GRAD" param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
# and op_role_var to get the pair. grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
for input_name in op.input_arg_names: if not param_name in optimize_params:
if input_name.find("@GRAD") != -1 and \ optimize_params.add(param_name)
op.attr(RPC_OP_ROLE_ATTR_NAME): log("adding param_grad pair: ", param_name, grad_name)
param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
params_grads.append([ params_grads.append([
origin_var_dict[param_name], origin_var_dict[param_name],
origin_var_dict[input_name] origin_var_dict[grad_name]
]) ])
else: else:
pass pass
......
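_get_optimize_pass (and the grad-merge loop above it) now reads the (param, grad) pair straight from each op's op_role_var attribute instead of matching "@GRAD" in input names, deduplicating params via a temporary set. A short sketch of that pairing logic over plain dicts; ops are modeled as {"attrs": {...}} dicts here, not fluid Operator objects:

```python
OP_ROLE_VAR = "op_role_var"   # stands in for kOpRoleVarAttrName()

def collect_param_grad_pairs(ops):
    """Collect unique (param, grad) name pairs from optimize ops' role vars."""
    optimize_params = set()
    params_grads = []
    for op in ops:
        role_var = op.get("attrs", {}).get(OP_ROLE_VAR)
        if role_var:
            param_name, grad_name = role_var[0], role_var[1]
            if param_name not in optimize_params:
                optimize_params.add(param_name)
                params_grads.append((param_name, grad_name))
    return params_grads

ops = [
    {"type": "scale", "attrs": {OP_ROLE_VAR: ["fc_0.w_0", "fc_0.w_0@GRAD"]}},
    {"type": "sgd",   "attrs": {OP_ROLE_VAR: ["fc_0.w_0", "fc_0.w_0@GRAD"]}},
]
assert collect_param_grad_pairs(ops) == [("fc_0.w_0", "fc_0.w_0@GRAD")]
```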
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
from __future__ import print_function from __future__ import print_function
from collections import defaultdict from collections import defaultdict, OrderedDict, Callable
from .. import core from .. import core
from ... import compat as cpt from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter from ..framework import Program, default_main_program, Parameter, Variable
from ..backward import _rename_arg_ from ..backward import _rename_arg_
from functools import reduce from functools import reduce
from six.moves import range from six.moves import range
...@@ -113,8 +113,10 @@ class ControlFlowGraph(object): ...@@ -113,8 +113,10 @@ class ControlFlowGraph(object):
def _fill_pool(self, i, is_forward): def _fill_pool(self, i, is_forward):
block_desc = self._ops[i].block() block_desc = self._ops[i].block()
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i]) in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
# NOTE: must sort the in_diff set for cases that get different cache var.
# FIXME(typhoonzero): maybe using a "sorted set" would be better than this.
can_optimize = [ can_optimize = [
x for x in in_diff x for x in sorted(list(in_diff))
if self._check_var_validity(block_desc, x, is_forward) if self._check_var_validity(block_desc, x, is_forward)
] ]
if can_optimize: if can_optimize:
...@@ -220,8 +222,9 @@ class ControlFlowGraph(object): ...@@ -220,8 +222,9 @@ class ControlFlowGraph(object):
block_desc = op.block() block_desc = op.block()
is_forward = i < self._forward_num is_forward = i < self._forward_num
if self.pool: if self.pool:
# NOTE: must sort the in_diff set for cases that get different cache var.
defs_can_optimize = [ defs_can_optimize = [
x for x in self._defs[i] x for x in sorted(list(self._defs[i]))
if self._check_var_validity(block_desc, x, is_forward) if self._check_var_validity(block_desc, x, is_forward)
] ]
out_pair = [ out_pair = [
...@@ -271,6 +274,8 @@ class ControlFlowGraph(object): ...@@ -271,6 +274,8 @@ class ControlFlowGraph(object):
self._program.block(block_desc.id).var(cpt.to_text( self._program.block(block_desc.id).var(cpt.to_text(
x)).desc = self._find_var(block_desc, cache_var, x)).desc = self._find_var(block_desc, cache_var,
is_forward) is_forward)
self._program.block(block_desc.id).vars[cpt.to_text(x)] = \
Variable(self._program.block(block_desc.id), name=cpt.to_text(x))
self._update_graph(x, cache_var, begin_idx=i) self._update_graph(x, cache_var, begin_idx=i)
break break
self._fill_pool(i, is_forward) self._fill_pool(i, is_forward)
......
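The sorted(list(...)) calls above exist because iterating a raw Python set is not reproducible across interpreter runs (string hash randomization on Python 3), so the pass could pick a different cache variable each run; this is what the "fix python3 mem_opt order" entry in the commit message refers to. A tiny illustration of why sorting first makes the choice deterministic:

```python
# Picking "the first eligible variable" out of a set depends on hash-based
# iteration order, which can change between interpreter runs on Python 3
# (PYTHONHASHSEED). Sorting first pins the choice down.
candidates = {"fc_0.tmp_1", "fc_0.tmp_0", "conv_2.tmp_3"}

first_unordered = next(iter(candidates))   # may differ from run to run
first_sorted = sorted(candidates)[0]       # always "conv_2.tmp_3"

assert first_sorted == "conv_2.tmp_3"
```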