diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 8f319116ab80b75c624f35b0e1315e7362e88d9a..134fcee826715672a6e021e9bf694bb771ebb830 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
   return recv_vars;
 }
 
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(
-    ir::Node *node, const std::vector<std::string> &send_vars,
-    const std::vector<std::string> &recv_vars) const {
-  if (send_vars.size() == 0 || recv_vars.size() == 0) {
-    return false;
-  }
-
-  /**
-   * Check any of opvars contains `.block` and in sendvars
-   */
-  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &rpc_vars) -> bool {
-    for (auto &var : opvars) {
-      // a variable name with the suffix `.block` means it's a splited
-      // variable by (DistributeTranspiler)
-      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
-      if (var.find(".block") != std::string::npos &&
-          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
-  }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
-  }
-
-  return checker(output_var_names, send_vars) ||
-         checker(input_var_names, recv_vars);
-}
-
 size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
     const std::vector<std::string> &var_names) const {
   int64_t numel_sum = 0;
@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         }
       }
       is_dist_train = true;
-    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+    } else if (boost::get<int>(node->Op()->GetAttr(
+                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kDist)) {
       int op_dev_id = CreateDistTrainOp(&result, node);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
+    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
     PADDLE_THROW(
         "the distribute training related op should be in [split_byref, "
         "concat].");
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 47aaa80f4d66a48b729d0638badcab885a50585c..cdf9f13cde608b546d17a1e53e0f6acea9e12566 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
   int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
 
-  /**
-   * Is this operator as the end-point operator before/after send operator.
-   */
-  bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
-                     const std::vector<std::string> &recv_vars) const;
-
   std::vector<std::string> FindDistTrainSendVars(
       const std::vector<ir::Node *> &nodes) const;
diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc
index 4fa047bf3ee3d06ac4aec5d2cc6a355965836d42..df2a7a27ca4a6011b214202ac9bf4f30dc482ece 100644
--- a/paddle/fluid/framework/op_proto_maker.cc
+++ b/paddle/fluid/framework/op_proto_maker.cc
@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
      {static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kBackward),
       static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
+      static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
       static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
       static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kBackward),
diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h
index 18827385ad659922230ff68709a2926a8c9013ac..4ed3cc45d66849267ef4945a03da1db76b53e4ea 100644
--- a/paddle/fluid/framework/op_proto_maker.h
+++ b/paddle/fluid/framework/op_proto_maker.h
@@ -26,7 +26,13 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  // RPC role is for send/recv related ops
   kRPC = 0x0003,
+  // Dist role is for split_byref/split_selected_rows/concat
+  // ops used for distributed training.
+  kDist = 0x0004,
+  // Tag all learning rate scheduler operators.
+  kLRSched = 0x0005,
 
   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc
index 1617cc1b95216b118cf2c2122dbe8b6c106554c3..c4854d50b6371064003a10e18efc9e5f160d9a42 100644
--- a/paddle/fluid/operators/distributed/variable_response.cc
+++ b/paddle/fluid/operators/distributed/variable_response.cc
@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
+  auto server_var = GetVar();
+  if (!server_var) {
+    LOG(ERROR) << "received var should not be on current server: "
+               << meta_.varname();
+    return false;
+  }
   auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);
-
   framework::LoD lod;
   for (int i = 0; i < meta_.lod_level(); ++i) {
     framework::Vector<size_t> v;
@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(
   void* tensor_data =
       tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
-
   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
   }
diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc
index f577068d1f39a3083a54f106d006f9982304411e..1f61a0e289f32196ead04d71d07b513cbe4655b1 100644
--- a/paddle/fluid/pybind/const_value.cc
+++ b/paddle/fluid/pybind/const_value.cc
@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
       .value("Backward", framework::OpRole::kBackward)
       .value("Optimize", framework::OpRole::kOptimize)
       .value("Loss", framework::OpRole::kLoss)
-      .value("RPC", framework::OpRole::kRPC);
+      .value("RPC", framework::OpRole::kRPC)
+      .value("Dist", framework::OpRole::kDist)
+      .value("LRSched", framework::OpRole::kLRSched);
 
   op_proto_and_checker_maker.def(
       "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
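The C++ changes above add two operator roles, `kDist` and `kLRSched`, and expose them to Python as `OpRole.Dist` and `OpRole.LRSched`. Below is a hedged sketch (not part of the patch) of how these roles can be queried from the Python side, mirroring the attribute lookups the graph pass and the transpiler now rely on; `ops_with_role` and `main_program` are illustrative names.

```python
from paddle.fluid import core

ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole


def ops_with_role(program, role):
    """Collect ops in the global block whose op_role attribute equals `role`."""
    return [
        op for op in program.global_block().ops
        if int(op.attr(ROLE_ATTR_NAME)) == int(role)
    ]

# after transpiling, assuming `main_program` was built elsewhere:
# dist_ops = ops_with_role(main_program, OpRole.Dist)    # split/concat ops
# lr_ops = ops_with_role(main_program, OpRole.LRSched)   # LR decay ops
```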
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 0abbb6815123f8ba65b637b3f3accef91fe66ef8..d7e5e4704858c08a21a2e7505facb63a93a4c010 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1509,6 +1509,30 @@ class Program(object):
         self._op_role_var = []
         self._current_role = OpRole.Forward
 
+    @contextlib.contextmanager
+    def _lr_schedule_guard(self):
+        """
+        A with guard to set :code:`LRSched` :code:`OpRole` and
+        :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
+        set to the target learning rate.
+
+        Notes: This is a very low level API. Users should not use it directly.
+
+
+        Examples:
+
+            >>> p, g = backward(...)
+            >>> with program._lr_schedule_guard():
+            >>>     lr = lr * decay
+        """
+        OpRole = core.op_proto_and_checker_maker.OpRole
+        self._current_role = OpRole.LRSched
+        # TODO(typhoonzero): how to set target learning rate var
+        self._op_role_var = []
+        yield
+        self._op_role_var = []
+        self._current_role = OpRole.Forward
+
     def __str__(self):
         """
         Get the protobuf debug string of this Program.
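`_lr_schedule_guard` is what the learning-rate schedulers in the next file wrap their ops in. A hedged usage sketch (not part of the patch), assuming the default main program; the role check at the end is only illustrative.

```python
import paddle.fluid as fluid
from paddle.fluid import core

prog = fluid.default_main_program()
with prog._lr_schedule_guard():
    # any op appended while the guard is active carries OpRole.LRSched,
    # which the transpiler's _get_lr_ops() later keys on
    lr = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.01)

role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
last_op = prog.global_block().ops[-1]
assert int(last_op.attr(role_attr_name)) == \
    int(core.op_proto_and_checker_maker.OpRole.LRSched)
```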
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index be368007dd7061ba7fc97414dbadfce00d158776..2b947ca9e87af2a0a7b224cb55f3409e17118bed 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -27,7 +27,7 @@ from . import nn
 from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
-from ..framework import default_main_program, Parameter
+from ..framework import default_main_program, Parameter, unique_name
 
 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps):
     Returns:
         The decayed learning rate.
     """
-    global_step = _decay_step_counter(1)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter(1)
 
-    a = global_step**-0.5
-    b = (warmup_steps**-1.5) * global_step
-    lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
+        a = global_step**-0.5
+        b = (warmup_steps**-1.5) * global_step
+        lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
 
     return lr_value
 
@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
         sgd_optimizer.minimize(avg_cost)
 
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * (decay_rate**div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * (decay_rate**div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     Returns:
         The decayed learning rate
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
                     staircase=True))
         sgd_optimizer.minimize(avg_cost)
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
 
-    decayed_lr = learning_rate / (1 + decay_rate * div_res)
+        decayed_lr = learning_rate / (1 + decay_rate * div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def polynomial_decay(learning_rate,
@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate,
     Returns:
         Variable: The decayed learning rate
     """
-    global_step = _decay_step_counter()
-
-    if cycle:
-        div_res = ops.ceil(global_step / decay_steps)
-        zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
-        one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
-
-        with control_flow.Switch() as switch:
-            with switch.case(global_step == zero_var):
-                tensor.assign(input=one_var, output=div_res)
-        decay_steps = decay_steps * div_res
-    else:
-        decay_steps_var = tensor.fill_constant(
-            shape=[1], dtype='float32', value=float(decay_steps))
-        global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
+
+        if cycle:
+            div_res = ops.ceil(global_step / decay_steps)
+            zero_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=0.0)
+            one_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=1.0)
+
+            with control_flow.Switch() as switch:
+                with switch.case(global_step == zero_var):
+                    tensor.assign(input=one_var, output=div_res)
+            decay_steps = decay_steps * div_res
+        else:
+            decay_steps_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=float(decay_steps))
+            global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
 
-    decayed_lr = (learning_rate - end_learning_rate) * \
-        ((1 - global_step / decay_steps) ** power) + end_learning_rate
-    return decayed_lr
+        decayed_lr = (learning_rate - end_learning_rate) * \
+            ((1 - global_step / decay_steps) ** power) + end_learning_rate
+        return decayed_lr
 
 
 def piecewise_decay(boundaries, values):
@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values):
     """
+    with default_main_program()._lr_schedule_guard():
+        if len(values) - len(boundaries) != 1:
+            raise ValueError("len(values) - len(boundaries) should be 1")
 
-    if len(values) - len(boundaries) != 1:
-        raise ValueError("len(values) - len(boundaries) should be 1")
-
-    global_step = _decay_step_counter()
+        global_step = _decay_step_counter()
 
-    lr = tensor.create_global_var(
-        shape=[1],
-        value=0.0,
-        dtype='float32',
-        persistable=True,
-        name="learning_rate")
+        lr = tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")
 
-    with control_flow.Switch() as switch:
-        for i in range(len(boundaries)):
-            boundary_val = tensor.fill_constant(
+        with control_flow.Switch() as switch:
+            for i in range(len(boundaries)):
+                boundary_val = tensor.fill_constant(
+                    shape=[1],
+                    dtype='float32',
+                    value=float(boundaries[i]),
+                    force_cpu=True)
+                value_var = tensor.fill_constant(
+                    shape=[1], dtype='float32', value=float(values[i]))
+                with switch.case(global_step < boundary_val):
+                    tensor.assign(value_var, lr)
+            last_value_var = tensor.fill_constant(
                 shape=[1],
                 dtype='float32',
-                value=float(boundaries[i]),
-                force_cpu=True)
-            value_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(values[i]))
-            with switch.case(global_step < boundary_val):
-                tensor.assign(value_var, lr)
-            last_value_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(values[len(values) - 1]))
-            with switch.default():
-                tensor.assign(last_value_var, lr)
+                value=float(values[len(values) - 1]))
+            with switch.default():
+                tensor.assign(last_value_var, lr)
 
     return lr
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
index 59a137c18c9435ef5c5772d0cc08f197c1d86603..09b1c546e49bd02bf336f31885bf4c7339cc5a2c 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase):
         self._sync_mode = True
         self._use_reduce = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)
 
 
@@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
         self._sync_mode = True
         self._mem_opt = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)
 
 
@@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase):
         self._sync_mode = False
         self._use_reduce = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=200)
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
index c0e9fa38e7d1eadd89eff9a8ba4442f888b8120e..7c3ed0916845d0a0cc0c462ff00927b91f546b21 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
@@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
+        self.check_with_place("dist_se_resnext.py", delta=1e-7)
+
+
+class TestDistSeResneXt2x2WithMemopt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
+    def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=1e-7)
 
 
@@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=100)
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
index 47083ca7e954c85bb42fcc88639f3e757283cbea..47e8dfaf03ceb27a74f5e48d662d2b534d2d152b 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
@@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_transformer(self):
+    def test_dist_train(self):
         download_files()
         self.check_with_place("dist_transformer.py", delta=1e-5)
 
@@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_transformer(self):
+    def test_dist_train(self):
         download_files()
         self.check_with_place("dist_transformer.py", delta=1.0)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
index 9a3e92e8d775a37e0c24ee1bcc5435628d61bb91..33b39b262b95b0013e3696c3f15a288a2e801ce1 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
@@ -17,19 +17,28 @@ import unittest
 from test_dist_base import TestDistBase
 
 
-class TestDistSeResneXt2x2(TestDistBase):
+class TestDistW2V2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_word2vec.py", delta=1e-4)
 
 
-class TestDistSeResneXt2x2Async(TestDistBase):
+class TestDistW2V2x2WithMemOpt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
+    def test_dist_train(self):
+        self.check_with_place("dist_word2vec.py", delta=1e-4)
+
+
+class TestDistW2V2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_word2vec.py", delta=1)
 
 
diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py
index 200175cfe87e24a53e1e229e41d1ff2a25fd66ec..59899e7e9ab98f661699d5ac0645c92bd23a1512 100644
--- a/python/paddle/fluid/transpiler/details/program_utils.py
+++ b/python/paddle/fluid/transpiler/details/program_utils.py
@@ -21,13 +21,12 @@ import paddle
 
 
 def delete_ops(block, ops):
-    try:
-        start = list(block.ops).index(ops[0])
-        end = list(block.ops).index(ops[-1])
-        [block._remove_op(start) for _ in six.moves.range(end - start + 1)]
-    except Exception as e:
-        raise e
-    block.program._sync_with_cpp()
+    for op in ops:
+        try:
+            idx = list(block.ops).index(op)
+            block._remove_op(idx)
+        except Exception as e:
+            print(e)
 
 
 def find_op_by_input_arg(block, arg_name):
@@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name):
     return -1
 
 
-def find_op_by_output_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.output_arg_names:
-            return index
+def find_op_by_output_arg(block, arg_name, reverse=False):
+    if reverse:
+        pos = len(block.ops) - 1
+        while pos >= 0:
+            op = block.ops[pos]
+            if arg_name in op.output_arg_names:
+                return pos
+            pos -= 1
+    else:
+        for index, op in enumerate(block.ops):
+            if arg_name in op.output_arg_names:
+                return index
     return -1
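`find_op_by_output_arg` now supports a reverse scan so callers can locate the last op that writes a variable instead of the first. A hedged, self-contained sketch (not part of the patch) of why that matters for the transpiler below: a gradient may be rewritten after it is first produced (for example by clipping or regularization), and the split/send ops must be inserted after the final write. `last_writer_index` and the toy data are illustrative.

```python
def last_writer_index(ops_output_names, arg_name):
    """ops_output_names: one list of output names per op, in block order."""
    for pos in range(len(ops_output_names) - 1, -1, -1):
        if arg_name in ops_output_names[pos]:
            return pos
    return -1


# toy block: the grad is produced at index 1 and rewritten by a clip op at 3;
# a forward scan would return 1, the reverse scan returns 3
outs = [["fc_0.tmp_0"], ["fc_0.w_0@GRAD"], [], ["fc_0.w_0@GRAD"]]
assert last_writer_index(outs, "fc_0.w_0@GRAD") == 3
```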
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index f58f1883a407a3123856e19b5ec8fc01862466a7..3f8c7b844a9fdc8404560ba4c78f9d328af2852a 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
+DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
+LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
+
+PRINT_LOG = False
+
+
+def log(*args):
+    if PRINT_LOG:
+        print(args)
 
 
 class VarBlock:
@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object):
     slice_var_up = True
     split_method = None
     min_block_size = 8192
+    print_log = False
 
 
 class DistributeTranspiler(object):
@@ -174,6 +184,9 @@ class DistributeTranspiler(object):
         if self.config.split_method is None:
             self.config.split_method = RoundRobin
 
+        global PRINT_LOG
+        if self.config.print_log:
+            PRINT_LOG = True
         assert (self.config.min_block_size >= 8192)
         assert (self.config.split_method.__bases__[0] == PSDispatcher)
 
@@ -257,12 +270,12 @@ class DistributeTranspiler(object):
             splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
                 splited_grad_varname = splited_vars[0].name
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
@@ -301,7 +314,7 @@ class DistributeTranspiler(object):
             self.grad_name_to_send_dummy_out[
                 self.table_name] = program.global_block().create_var(
                     name=framework.generate_control_dev_var_name())
-        input_deps = self.grad_name_to_send_dummy_out.values()
+        input_deps = list(self.grad_name_to_send_dummy_out.values())
 
         program.global_block().append_op(
             type="send_barrier",
@@ -377,7 +390,10 @@ class DistributeTranspiler(object):
                 type="concat",
                 inputs={"X": splited_var},
                 outputs={"Out": [orig_param]},
-                attrs={"axis": 0})
+                attrs={
+                    "axis": 0,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
 
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
@@ -496,9 +512,9 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-        sys.stderr.write("get_pserver_program() is deprecated, call\
-            get_pserver_programs() to get pserver main and startup\
-            in a single call.")
+        sys.stderr.write("get_pserver_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -615,22 +631,31 @@ class DistributeTranspiler(object):
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program._create_block(pre_block_idx)
             optimize_blocks.append(per_opt_block)
+            optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
             # append grad merging ops before clip and weight decay
-            # cases may like:
-            # L2Decay op -> clip op -> optimize
+            # e.g. merge grad -> L2Decay op -> clip op -> optimize
+            merged_var = None
             for _, op in enumerate(self.optimize_ops):
-                # find the origin @GRAD var before clipping
-                grad_varname_for_block = __op_have_grad_input__(op)
-                if ufind.is_connected(op, opt_op) and grad_varname_for_block:
+                # find the origin grad var before clipping/L2Decay,
+                # merged_var should be the input var name of L2Decay
+                grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
+                if op.attr(OP_ROLE_VAR_ATTR_NAME)[
+                        0] == optimize_target_param_name:
                     merged_var = self._append_pserver_grad_merge_ops(
                         per_opt_block, grad_varname_for_block, endpoint,
                         grad_to_block_id, self.origin_program)
-                    break  # append optimize op once then append other ops.
-            for _, op in enumerate(self.optimize_ops):
-                # optimizer is connected to itself
-                if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var, lr_ops)
+                    if merged_var:
+                        break  # append optimize op once then append other ops.
+            if merged_var:
+                for _, op in enumerate(self.optimize_ops):
+                    # optimizer is connected to itself
+                    if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
+                        op not in global_ops:
+                        log("append opt op: ", op.type, op.input_arg_names,
+                            merged_var)
+                        __append_optimize_op__(op, per_opt_block,
+                                               grad_to_block_id, merged_var,
+                                               lr_ops)
 
         # dedup grad to ids list
         grad_to_block_id = list(set(grad_to_block_id))
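The pserver-side hunk above stops relying on `ufind.is_connected()` and instead matches optimize ops to their target parameter through the `(param, grad)` pair stored in the `op_role_var` attribute. A hedged sketch (not part of the patch) of that grouping idea; `group_opt_ops_by_param` is an illustrative helper.

```python
from collections import defaultdict

from paddle.fluid import core

OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()


def group_opt_ops_by_param(opt_ops):
    """Group optimizer-role ops by the parameter named in op_role_var."""
    groups = defaultdict(list)
    for op in opt_ops:
        role_var = op.attr(OP_ROLE_VAR)  # [param_name, grad_name] when set
        if role_var:
            groups[role_var[0]].append(op)
    return groups
```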
@@ -726,17 +751,17 @@ class DistributeTranspiler(object):
         Returns:
             Program: parameter server side startup program.
         """
-        sys.stderr.write("get_startup_program() is deprecated, call\
-            get_pserver_programs() to get pserver main and startup\
-            in a single call.")
+        sys.stderr.write("get_startup_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         if pserver_program != None:
-            sys.stderr.write("passing pserver_program to get_startup_program()\
-                is deprecated, you can use new API get_pserver_programs() to\
-                get both pserver main program and startup program.")
+            sys.stderr.write("passing pserver_program to get_startup_program() \
+is deprecated, you can use new API get_pserver_programs() to \
+get both pserver main program and startup program.")
         if startup_program != None:
-            sys.stderr.write("passing startup_program to get_startup_program()\
-                is deprecated, use fluid.program_guard() or pass this argument\
-                to transpile() call.")
+            sys.stderr.write("passing startup_program to get_startup_program() \
+is deprecated, use fluid.program_guard() or pass this argument \
+to transpile() call.")
 
         s_prog = Program()
         orig_s_prog = self.startup_program
@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object):
                 type="split_selected_rows",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"height_sections": height_sections})
+                attrs={
+                    "height_sections": height_sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
             sections = []
             for v in splited_vars:
@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object):
                 type="split_byref",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"sections": sections}  # assume split evenly
-            )
+                attrs={
+                    "sections": sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         else:
             AssertionError("Variable type should be in set "
                            "[LOD_TENSOR, SELECTED_ROWS]")
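The two hunks above tag the transpiler-inserted `split_selected_rows` and `split_byref` ops with the `Dist` role, which is what the reworked C++ pass in the first file keys on when routing ops to `CreateDistTrainOp`. A hedged sketch (not part of the patch) of the tagging pattern as a reusable helper; `append_dist_op` is an illustrative name.

```python
from paddle.fluid import core

ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
DIST_ROLE = core.op_proto_and_checker_maker.OpRole.Dist


def append_dist_op(block, op_type, inputs, outputs, attrs=None):
    """Append an op and mark it with the Dist role, like the hunks above."""
    attrs = dict(attrs or {})
    attrs[ROLE_ATTR_NAME] = DIST_ROLE
    return block.append_op(
        type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
```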
", op.type) + return lr_ops + + def _get_lr_ops_deprecated(self): lr_ops = [] # find learning rate variables by optimize op lr_vars = set() @@ -1670,20 +1709,21 @@ class DistributeTranspiler(object): block = self.origin_program.global_block() opt_ops = [] params_grads = [] + # tmp set to dedup + optimize_params = set() origin_var_dict = self.origin_program.global_block().vars for op in block.ops: if self._is_opt_role_op(op): opt_ops.append(op) - # HACK(wuyi): if we find grad vars from input of optimize - # ops, we may get the output of clip op. Use syntax "@GRAD" - # and op_role_var to get the pair. - for input_name in op.input_arg_names: - if input_name.find("@GRAD") != -1 and \ - op.attr(RPC_OP_ROLE_ATTR_NAME): - param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] + if op.attr(OP_ROLE_VAR_ATTR_NAME): + param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] + grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] + if not param_name in optimize_params: + optimize_params.add(param_name) + log("adding param_grad pair: ", param_name, grad_name) params_grads.append([ origin_var_dict[param_name], - origin_var_dict[input_name] + origin_var_dict[grad_name] ]) else: pass diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index d4517059a4b033eec20ef6903894426ccbd597d7..d5aa54d752305b188d292f95f05cd70d27702c35 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,10 +14,10 @@ from __future__ import print_function -from collections import defaultdict +from collections import defaultdict, OrderedDict, Callable from .. import core from ... import compat as cpt -from ..framework import Program, default_main_program, Parameter +from ..framework import Program, default_main_program, Parameter, Variable from ..backward import _rename_arg_ from functools import reduce from six.moves import range @@ -113,8 +113,10 @@ class ControlFlowGraph(object): def _fill_pool(self, i, is_forward): block_desc = self._ops[i].block() in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i]) + # NOTE: must sort the in_diff set for cases that get different cache var. + # FIXME(typhoonzero): maybe use a "sorted set" is better than this. can_optimize = [ - x for x in in_diff + x for x in sorted(list(in_diff)) if self._check_var_validity(block_desc, x, is_forward) ] if can_optimize: @@ -220,8 +222,9 @@ class ControlFlowGraph(object): block_desc = op.block() is_forward = i < self._forward_num if self.pool: + # NOTE: must sort the in_diff set for cases that get different cache var. defs_can_optimize = [ - x for x in self._defs[i] + x for x in sorted(list(self._defs[i])) if self._check_var_validity(block_desc, x, is_forward) ] out_pair = [ @@ -271,6 +274,8 @@ class ControlFlowGraph(object): self._program.block(block_desc.id).var(cpt.to_text( x)).desc = self._find_var(block_desc, cache_var, is_forward) + self._program.block(block_desc.id).vars[cpt.to_text(x)] = \ + Variable(self._program.block(block_desc.id), name=cpt.to_text(x)) self._update_graph(x, cache_var, begin_idx=i) break self._fill_pool(i, is_forward)