From d90610624f7693d4dc9eda15ae6d6b46b0df8398 Mon Sep 17 00:00:00 2001
From: Wu Yi
Date: Thu, 31 May 2018 16:23:11 +0800
Subject: [PATCH] Cleanup transpiler and move weight decay and clip on pservers (#11039)

* WIP move weight decay
* weight decay ok
* wip
* clean up transpiler
* add details folder
* update
* fix split var test
* follow comments
---
 .../fluid/tests/unittests/test_split_var.py   |   4 +-
 .../fluid/transpiler/details/__init__.py      |  16 +
 .../fluid/transpiler/details/program_utils.py |  37 ++
 .../paddle/fluid/transpiler/details/ufind.py  |  64 +++
 .../fluid/transpiler/distribute_transpiler.py | 509 +++++++++---------
 5 files changed, 371 insertions(+), 259 deletions(-)
 create mode 100644 python/paddle/fluid/transpiler/details/__init__.py
 create mode 100644 python/paddle/fluid/transpiler/details/program_utils.py
 create mode 100644 python/paddle/fluid/transpiler/details/ufind.py

diff --git a/python/paddle/fluid/tests/unittests/test_split_var.py b/python/paddle/fluid/tests/unittests/test_split_var.py
index 0c5e8901b..157def9b5 100644
--- a/python/paddle/fluid/tests/unittests/test_split_var.py
+++ b/python/paddle/fluid/tests/unittests/test_split_var.py
@@ -14,7 +14,7 @@
 
 import math
 import unittest
-from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
+from paddle.fluid.transpiler.distribute_transpiler import split_variable
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import random
@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
 #             dtype=core.VarDesc.VarType.LOD_TENSOR,
                 shape=shape)
             var_list.append(var)
-        blocks = split_dense_variable(var_list, 10, min_size)
+        blocks = split_variable(var_list, 10, min_size)
         all_sizes = []
         for s in expected_sizes:
             for s2 in s:
diff --git a/python/paddle/fluid/transpiler/details/__init__.py b/python/paddle/fluid/transpiler/details/__init__.py
new file mode 100644
index 000000000..dc597c338
--- /dev/null
+++ b/python/paddle/fluid/transpiler/details/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from program_utils import *
+from ufind import *
diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py
new file mode 100644
index 000000000..f10b49630
--- /dev/null
+++ b/python/paddle/fluid/transpiler/details/program_utils.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def delete_ops(block, ops):
+    try:
+        start = list(block.ops).index(ops[0])
+        end = list(block.ops).index(ops[-1])
+        [block.remove_op(start) for _ in xrange(end - start + 1)]
+    except Exception, e:
+        raise e
+    block.program.sync_with_cpp()
+
+
+def find_op_by_input_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.input_arg_names:
+            return index
+    return -1
+
+
+def find_op_by_output_arg(block, arg_name):
+    for index, op in enumerate(block.ops):
+        if arg_name in op.output_arg_names:
+            return index
+    return -1
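
The two lookup helpers and delete_ops compose naturally: locate an op by a variable it consumes, then remove a contiguous run of ops and re-sync the Python view with the C++ desc. A minimal sketch, assuming the 2018-era fluid API; the network and the variable name 'x' are illustrative, not from this patch:

    import paddle.fluid as fluid
    from paddle.fluid.transpiler.details import delete_ops, find_op_by_input_arg

    program = fluid.Program()
    with fluid.program_guard(program):
        x = fluid.layers.data(name='x', shape=[4], dtype='float32')
        y = fluid.layers.fc(input=x, size=2)

    block = program.global_block()
    idx = find_op_by_input_arg(block, 'x')  # index of the first op reading 'x'
    if idx >= 0:
        # passing the same op as first and last element removes exactly one op
        delete_ops(block, [block.ops[idx]])
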
diff --git a/python/paddle/fluid/transpiler/details/ufind.py b/python/paddle/fluid/transpiler/details/ufind.py
new file mode 100644
index 000000000..0e30d0e3f
--- /dev/null
+++ b/python/paddle/fluid/transpiler/details/ufind.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class UnionFind(object):
+    """ Union-find data structure.
+
+    Union-find is a data structure that keeps track of a set of elements partitioned
+    into a number of disjoint (non-overlapping) subsets.
+
+    Reference:
+    https://en.wikipedia.org/wiki/Disjoint-set_data_structure
+
+    Args:
+      elements(list): The initial element list.
+    """
+
+    def __init__(self, elements=None):
+        self._parents = []  # index -> parent index
+        self._index = {}  # element -> index
+        self._curr_idx = 0
+        if not elements:
+            elements = []
+        for ele in elements:
+            self._parents.append(self._curr_idx)
+            self._index.update({ele: self._curr_idx})
+            self._curr_idx += 1
+
+    def find(self, x):
+        # Find the root index of the given element x and
+        # compress the path while finding the root index.
+        if not x in self._index:
+            return -1
+        idx = self._index[x]
+        while idx != self._parents[idx]:
+            t = self._parents[idx]
+            self._parents[idx] = self._parents[t]
+            idx = t
+        return idx
+
+    def union(self, x, y):
+        # Union the two given elements.
+        x_root = self.find(x)
+        y_root = self.find(y)
+
+        if x_root == y_root:
+            return
+        self._parents[x_root] = y_root
+
+    def is_connected(self, x, y):
+        # Two given elements are connected iff they share the same root index.
+        return self.find(x) == self.find(y)
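
A quick usage sketch (element names are illustrative; the transpiler itself unions ops so that connected components of the optimize graph can be grouped per parameter):

    from paddle.fluid.transpiler.details import UnionFind

    ufind = UnionFind(["scale_op", "sum_op", "sgd_op"])
    ufind.union("scale_op", "sum_op")
    assert ufind.is_connected("scale_op", "sum_op")
    assert not ufind.is_connected("scale_op", "sgd_op")
    assert ufind.find("never_added") == -1  # unknown elements report -1

Note that union() links roots directly without union-by-rank; path compression in find() keeps the trees shallow enough for op graphs of this size.
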
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index e9b7d9e9d..06b0a1375 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -11,6 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Transpile the program to distributed data-parallelism programs.
+The main_program will be transformed to use a remote parameter server
+to do parameter optimization, and the optimization graph will be put
+into a parameter server program.
+
+Use different methods to split trainable variables to different
+parameter servers.
+
+Steps to transpile trainer:
+1. split variables into multiple blocks, aligned by product(dim[1:]) (width).
+2. rename the split grad variables to add a trainer_id suffix ".trainer_%d".
+3. modify the trainer program to add a split_op for each grad variable.
+4. append send_op to send the split variables to the servers and fetch
+   params (split blocks or origin params) from the servers.
+5. append concat_op to merge the split blocks and update local weights.
+
+Steps to transpile pserver:
+1. create a new program for the parameter server.
+2. create params and grad variables assigned to the current server instance.
+3. create a sub-block in the server side program.
+4. append ops that should run on the current server instance.
+5. add listen_and_serv op.
+"""
 
 from __future__ import print_function
 
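
A minimal driving sketch for the steps above, assuming the era's public entry points (get_trainer_program/get_pserver_program/get_startup_program as remembered from this codebase, not shown in this patch; endpoints hypothetical):

    import paddle.fluid as fluid
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler

    # build forward/backward/optimize into fluid.default_main_program() first
    t = DistributeTranspiler()
    t.transpile(
        trainer_id=0,
        pservers="192.168.0.1:6174,192.168.0.2:6174",
        trainers=2)

    # on a trainer process: run the rewritten main program
    trainer_prog = t.get_trainer_program()

    # on a pserver process: one program per endpoint, plus its startup program
    ep = "192.168.0.1:6174"
    pserver_prog = t.get_pserver_program(ep)
    pserver_startup = t.get_startup_program(ep, pserver_prog)
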
@@ -21,9 +45,11 @@ from .. import core, framework
 from ..framework import Program, default_main_program, \
     default_startup_program, \
     Variable, Parameter, grad_var_name
+from details import *
 
 LOOKUP_TABLE_TYPE = "lookup_table"
 LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
+OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
@@ -40,62 +66,11 @@ class VarBlock:
         return "%s:%d:%d" % (self.varname, self.offset, self.size)
 
 
-class UnionFind(object):
-    """ Union-find data structure.
-
-    Union-find is a data structure that keeps track of a set of elements partitioned
-    into a number of disjoint (non-overlapping) subsets.
-
-    Reference:
-    https://en.wikipedia.org/wiki/Disjoint-set_data_structure
-
-    Args:
-      elements(list): The initialize element list.
-    """
-
-    def __init__(self, elementes=None):
-        self._parents = []  # index -> parent index
-        self._index = {}  # element -> index
-        self._curr_idx = 0
-        if not elementes:
-            elementes = []
-        for ele in elementes:
-            self._parents.append(self._curr_idx)
-            self._index.update({ele: self._curr_idx})
-            self._curr_idx += 1
-
-    def find(self, x):
-        # Find the root index of given element x,
-        # execute the path compress while findind the root index
-        if not x in self._index:
-            return -1
-        idx = self._index[x]
-        while idx != self._parents[idx]:
-            t = self._parents[idx]
-            self._parents[idx] = self._parents[t]
-            idx = t
-        return idx
-
-    def union(self, x, y):
-        # Union two given element
-        x_root = self.find(x)
-        y_root = self.find(y)
-
-        if x_root == y_root:
-            return
-        self._parents[x_root] = y_root
-
-    def is_connected(self, x, y):
-        # If two given elements have the same root index,
-        # then they are connected.
-        return self.find(x) == self.find(y)
-
-
 def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")
 
 
-def split_dense_variable(var_list, service_count, min_block_size=8192):
+def split_variable(var_list, service_count, min_block_size=8192):
     """
         We may need to split dense tensor to one or more blocks and put
         them equally onto parameter server. One block is a sub-tensor
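
Only the function's name changes in this hunk, so its body is elided by the diff. Per the docstring and step 1 of the module docstring, each block is at least min_block_size elements and is rounded up to a whole number of rows (product(dim[1:])). A plausible standalone re-implementation of that rule — a sketch under those assumptions, not the patch's exact code:

    import math

    def block_sizes(shape, service_count, min_block_size=8192):
        # width = product(dim[1:]): number of elements in one row
        width = 1
        for d in shape[1:]:
            width *= d
        total = width * shape[0]
        size = int(math.ceil(total / float(service_count)))
        if size < min_block_size:
            size = min_block_size
        # round up to whole rows so blocks concat back together on dim[0]
        if size % width != 0:
            size += width - size % width
        sizes = [size] * (total // size)
        if total % size != 0:
            sizes.append(total % size)
        return sizes

    # A 1000x10 parameter for 3 pservers: ceil(10000/3)=3334 is below
    # min_block_size, so 8192 is rounded up to 8200 (a multiple of 10):
    print(block_sizes([1000, 10], 3))  # [8200, 1800]
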
@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
     return blocks
 
 
-def delete_ops(block, ops):
-    try:
-        start = list(block.ops).index(ops[0])
-        end = list(block.ops).index(ops[-1])
-        [block.remove_op(start) for _ in xrange(end - start + 1)]
-    except Exception, e:
-        raise e
-    block.program.sync_with_cpp()
-
-
-def find_op_by_input_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.input_arg_names:
-            return index
-    return -1
-
-
-def find_op_by_output_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.output_arg_names:
-            return index
-    return -1
-
-
 class DistributeTranspiler:
-    def transpile(self,
-                  trainer_id,
-                  program=None,
-                  pservers="127.0.0.1:6174",
-                  trainers=1,
-                  split_method=RoundRobin,
-                  sync_mode=True):
-        """
-        Transpile the program to distributed data-parallelism programs.
-        The main_program will be transformed to use a remote parameter server
-        to do parameter optimization. And the optimization graph will be put
-        into a parameter server program.
-
-        Use different methods to split trainable variables to different
-        parameter servers.
-
-        Steps to transpile trainer:
-        1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
-        2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
-        3. modify trainer program add split_op to each grad variable.
-        4. append send_op to send splited variables to server and fetch
-            params(splited blocks or origin param) from server.
-        5. append concat_op to merge splited blocks to update local weights.
-
-        Steps to transpile pserver:
-        1. create new program for parameter server.
-        2. create params and grad variables that assigned to current server instance.
-        3. create a sub-block in the server side program
-        4. append ops that should run on current server instance.
-        5. add listen_and_serv op
-
-        :param trainer_id: one unique id for each trainer in a job.
-        :type trainer_id: int
-        :param program: program to transpile, default is default_main_program
-        :type program: Program
-        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
-        :type pservers: string
-        :param trainers: total number of workers/trainers in the job
-        :type trainers: int
-        :param split_method: A function to determin how to split variables
-            to different servers equally.
-        :type split_method: function
-        :type sync_mode: boolean default True
-        :param sync_mode: if sync_mode is set True, it means that dist transpiler
-            will transpile the program into sync_mode pserver and trainer program.
-        """
-        assert (split_method.__bases__[0] == PSDispatcher)
-        if program is None:
-            program = default_main_program()
-        self.origin_program = program
-        self.trainer_num = trainers
-        self.sync_mode = sync_mode
-        # TODO(typhoonzero): currently trainer_id is fetched from cluster system
-        # like Kubernetes, we should port this to use etcd later when developing
-        # fluid distributed training with fault-tolerance.
-        self.trainer_id = trainer_id
-        pserver_endpoints = pservers.split(",")
-        self.pserver_endpoints = pserver_endpoints
-        self.optimize_ops, params_grads = self._get_optimize_pass()
-        ps_dispatcher = split_method(pserver_endpoints)
-
+    def _has_distributed_lookup_table(self):
         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
         # 2. check all lookup_table_op share the same table.
         distributed_lookup_table_ops = []
         # support only one distributed_lookup_table now
         self.table_name = None
-        for op in program.global_block().ops:
+        for op in self.origin_program.global_block().ops:
             if op.type == LOOKUP_TABLE_TYPE:
                 if op.attrs['is_distributed'] is True:
                     if self.table_name is None:
@@ -246,20 +137,13 @@
                 if self.table_name is not None:
                     assert op.input("W")[0] != self.table_name
 
-        self.has_distributed_lookup_table = len(
-            distributed_lookup_table_ops) > 0
-
-        # step1: For large parameters and gradients, split them into smaller
-        # blocks.
-        param_list = []
-        grad_list = []
-        for p, g in params_grads:
-            # skip parameter marked not trainable
-            if type(p) == Parameter and p.trainable == False:
-                continue
-            param_list.append(p)
-            grad_list.append(g)
+        return len(distributed_lookup_table_ops) > 0
 
+    def _update_dist_lookup_table_vars(self, param_list, grad_list,
+                                       params_grads):
+        # TODO(wuyi): find a way to put dist lookup table stuff all together.
+        # update self.table_param_grad and self.trainer_side_table_grad_list
+        program = self.origin_program
         if self.has_distributed_lookup_table:
             param_list = [
                 param for param in param_list if param.name != self.table_name
@@ -277,7 +161,7 @@
             self.trainer_side_table_grad_list = [
                 program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
+                    (table_grad_var.name, self.trainer_id, index),
                     type=table_grad_var.type,
                     shape=table_grad_var.shape,
                     dtype=table_grad_var.dtype)
@@ -293,23 +177,41 @@
                 for index in range(len(self.pserver_endpoints))
             ]
 
-        grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
-        param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+    def _init_splited_vars(self, split_method):
+        # update these mappings for further transpile:
+        # 1. param_var_mapping: param var name -> [split param vars]
+        # 2. grad_var_mapping: grad var name -> [split grad vars]
+        # 3. grad_param_mapping: grad.blockx -> param.blockx
+        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
+
+        param_list = []
+        grad_list = []
+        for p, g in self.params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            param_list.append(p)
+            grad_list.append(g)
+
+        self._update_dist_lookup_table_vars(param_list, grad_list,
+                                            self.params_grads)
+
+        grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
+        param_blocks = split_variable(param_list, len(self.pserver_endpoints))
         assert (len(grad_blocks) == len(param_blocks))
-        # step2: Create new vars for the parameters and gradients blocks and
-        # add ops to do the split.
-        param_var_mapping = self._create_vars_from_blocklist(program,
-                                                             param_blocks)
-        grad_var_mapping = self._create_vars_from_blocklist(
-            program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
-        grad_param_mapping = dict()
+        # origin_varname -> [splited_var]
+        self.param_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program, param_blocks)
+        self.grad_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program,
+            grad_blocks,
+            add_trainer_suffix=self.trainer_num > 1)
+        self.grad_param_mapping = dict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
-            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
-                param_var_mapping[p_name][int(p_bid)]
-
-        # step 3: transpile trainer side program, insert recv op and send op.
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+                self.param_var_mapping[p_name][int(p_bid)]
 
         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = dict()
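
The descriptors zipped above are the VarBlock strings from earlier in this file, "varname:block_id:size" (the middle field is consumed as a block index here). A toy rendering with hypothetical names:

    grad_blocks = ["fc_w@GRAD:0:8200", "fc_w@GRAD:1:1800"]
    param_blocks = ["fc_w:0:8200", "fc_w:1:1800"]
    for g, p in zip(grad_blocks, param_blocks):
        g_name, g_bid, _ = g.split(":")
        p_name, p_bid, _ = p.split(":")
        print("%s block %s -> %s block %s" % (g_name, g_bid, p_name, p_bid))
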
@@ -322,10 +224,50 @@
             })
             for ep in self.pserver_endpoints
         ]
 
+    def transpile(self,
+                  trainer_id,
+                  program=None,
+                  pservers="127.0.0.1:6174",
+                  trainers=1,
+                  split_method=RoundRobin,
+                  sync_mode=True):
+        """
+        :param trainer_id: one unique id for each trainer in a job.
+        :type trainer_id: int
+        :param program: program to transpile, default is default_main_program
+        :type program: Program
+        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
+        :type pservers: string
+        :param trainers: total number of workers/trainers in the job
+        :type trainers: int
+        :param split_method: A function to determine how to split variables
+            to different servers equally.
+        :type split_method: function
+        :type sync_mode: boolean default True
+        :param sync_mode: if sync_mode is set True, it means that dist transpiler
+            will transpile the program into sync_mode pserver and trainer program.
+        """
+        assert (split_method.__bases__[0] == PSDispatcher)
+        if program is None:
+            program = default_main_program()
+        self.origin_program = program
+        self.trainer_num = trainers
+        self.sync_mode = sync_mode
+        self.trainer_id = trainer_id
+        pserver_endpoints = pservers.split(",")
+        self.pserver_endpoints = pserver_endpoints
+        self.optimize_ops, self.params_grads = self._get_optimize_pass()
+
+        ps_dispatcher = split_method(self.pserver_endpoints)
+        self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+
+        # split and create vars, then put split vars in dicts for later use.
+        self._init_splited_vars(split_method)
+
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
         send_vars = []
-        for orig_varname, splited_vars in grad_var_mapping.items():
+        for orig_varname, splited_vars in self.grad_var_mapping.items():
             eplist = ps_dispatcher.dispatch(splited_vars)
             if len(splited_vars) == 1:
                 orig_varname = splited_vars[0].name
@@ -367,7 +309,7 @@
         # step 3.2: insert recv op to receive parameters from parameter server
         recv_vars = []
         for _, var in enumerate(send_vars):
-            recv_vars.append(grad_param_mapping[var])
+            recv_vars.append(self.grad_param_mapping[var])
         ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)
 
@@ -375,7 +317,7 @@
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
 
         # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in param_var_mapping.iteritems():
+        for varname, splited_var in self.param_var_mapping.iteritems():
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
@@ -399,7 +341,7 @@
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        for varname, splited_var in param_var_mapping.iteritems():
+        for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[varname]
@@ -440,7 +382,6 @@
             # we don't need to create them when grad arrives.
             # change client side var name to origin name by
             # removing ".trainer_%d" suffix
-
             suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
@@ -477,24 +418,14 @@
         # located on current pserver
         opt_op_on_pserver = []
         for _, op in enumerate(self.optimize_ops):
-            if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
+            if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
+                    endpoint, op):
                 opt_op_on_pserver.append(op)
         # step 3.3
         # Iterate through the ops, and if an op and the optimize ops
         # which located on current pserver are in one set, then
         # append it into the sub program.
-        # We try to put optimization program run parallelly, assume
-        # optimization program always looks like:
-        #
-        # prevop -> prevop -> opt op -> following op -> following op; ->
-        # prevop -> prevop -> opt op -> following op -> following op; ->
-        # global op -> global op
-        #
-        # we put operators that can run parallelly to many program blocks.
-        # in above example, we seperate ops by the ";". Global ops must run
-        # after all the optimize ops finished.
-
         global_ops = []
         # HACK: optimization global ops only used to scale beta1 and beta2
         # replace it with dependency engine.
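
The hunks below are the heart of this change: gradients from all trainers are merged (summed, then scaled) before any clip or weight-decay ops run on the pserver. The ordering matters because clipping is nonlinear; an illustrative numpy-only check (not fluid code, values hypothetical):

    import numpy as np

    def clip_by_norm(x, max_norm):
        norm = np.linalg.norm(x)
        return x * (max_norm / norm) if norm > max_norm else x

    g0, g1 = np.array([3.0, -4.0]), np.array([1.0, 2.0])  # two trainers' grads

    # merge first, then clip (what the pserver now does):
    merged = clip_by_norm((g0 + g1) / 2.0, 2.0)                    # ~[1.79, -0.89]

    # clip each trainer's grad first, then average -- a different result:
    other = (clip_by_norm(g0, 2.0) + clip_by_norm(g1, 2.0)) / 2.0  # ~[1.05, 0.09]
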
@@ -502,12 +433,18 @@
             if self._is_adam_connected_op(op):
                 global_ops.append(op)
 
-        def __append_optimize_op__(op, block, grad_to_block_id):
-            if self._is_opt_op(op):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+            if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
-                                         self.origin_program)
+                                         self.origin_program, merged_var)
             else:
-                self._append_pserver_non_opt_ops(block, op)
+                self._append_pserver_non_opt_ops(block, op, endpoint)
+
+        def __op_have_grad_input__(op):
+            for varname in op.input_arg_names:
+                if varname.find("@GRAD") >= 0:
+                    return varname
+            return ""
 
         # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
@@ -515,17 +452,26 @@
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for _, op in enumerate(lr_ops):
-                self._append_pserver_non_opt_ops(lr_decay_block, op)
+                self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
 
         # append op to the current block
         grad_to_block_id = []
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
+            # append grad merging ops before clip and weight decay
+            for _, op in enumerate(self.optimize_ops):
+                # find the origin @GRAD var before clipping
+                grad_varname_for_block = __op_have_grad_input__(op)
+                if ufind.is_connected(op, opt_op) and grad_varname_for_block:
+                    merged_var = self._append_pserver_grad_merge_ops(
+                        per_opt_block, grad_varname_for_block, endpoint,
+                        grad_to_block_id, self.origin_program)
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id,
+                                           merged_var)
 
         # append global ops
         if global_ops:
@@ -533,15 +479,7 @@
                 pserver_program.num_blocks - 1)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id)
-
-        # NOT USED: single block version:
-        #
-        # for _, op in enumerate(self.optimize_ops):
-        #     for _, opt_op in enumerate(opt_op_on_pserver):
-        #         if ufind.is_connected(op, opt_op):
-        #             __append_optimize_op__(glb_op, optimize_block)
-        #             break
+                                       grad_to_block_id, None)
 
         # process distributed lookup_table
         prefetch_block = None
@@ -631,6 +569,8 @@
                     attrs=op.attrs)
         return s_prog
 
+    # ====================== private transpiler functions =====================
+
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program,
                                                pserver_endpoints):
@@ -836,7 +776,6 @@
 
         return table_opt_block
 
-    # ====================== private transpiler functions =====================
     def _create_vars_from_blocklist(self,
                                     program,
                                     block_list,
@@ -979,17 +918,74 @@
             pass
         return orig_shape
 
-    def _orig_varname(self, varname):
-        suff_idx = varname.find(".trainer_")
+    def _get_varname_parts(self, varname):
+        # returns origin, blockid, trainerid
         orig_var_name = ""
-        if suff_idx >= 0:
-            orig_var_name = varname[:suff_idx]
+        trainer_part = ""
+        block_part = ""
+        trainer_idx = varname.find(".trainer_")
+        if trainer_idx >= 0:
+            trainer_part = varname[trainer_idx + 1:]
+        else:
+            trainer_idx = len(varname)
+        block_index = varname.find(".block")
+        if block_index >= 0:
+            block_part = varname[block_index + 1:trainer_idx]
         else:
-            orig_var_name = varname
-        return orig_var_name
+            block_index = len(varname)
+        orig_var_name = varname[0:min(block_index, trainer_idx)]
+        return orig_var_name, block_part, trainer_part
+
+    def _orig_varname(self, varname):
+        orig, _, _ = self._get_varname_parts(varname)
+        return orig
+
+    def _append_pserver_grad_merge_ops(self, optimize_block,
+                                       grad_varname_for_block, endpoint,
+                                       grad_to_block_id, origin_program):
+        program = optimize_block.program
+        pserver_block = program.global_block()
+        grad_block = None
+        for g in self.param_grad_ep_mapping[endpoint]["grads"]:
+            if self._orig_varname(g.name) == \
+                    self._orig_varname(grad_varname_for_block):
+                grad_block = g
+                break
+        if not grad_block:
+            # do not append this op if current endpoint
+            # is not dealing with this grad block
+            return
+        orig_varname, block_name, trainer_name = self._get_varname_parts(
+            grad_block.name)
+        if block_name:
+            merged_var_name = '.'.join([orig_varname, block_name])
+        else:
+            merged_var_name = orig_varname
+        merged_var = \
+            pserver_block.vars[merged_var_name]
+        grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
+        if self.sync_mode and self.trainer_num > 1:
+            vars2merge = []
+            for i in xrange(self.trainer_num):
+                per_trainer_name = "%s.trainer_%d" % \
+                    (merged_var_name, i)
+                vars2merge.append(pserver_block.vars[per_trainer_name])
+
+            optimize_block.append_op(
+                type="sum",
+                inputs={"X": vars2merge},
+                outputs={"Out": merged_var})
+            # TODO(panyx0718): What if it's SELECTED_ROWS.
+            if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+                optimize_block.append_op(
+                    type="scale",
+                    inputs={"X": merged_var},
+                    outputs={"Out": merged_var},
+                    attrs={"scale": 1.0 / float(self.trainer_num)})
+        return merged_var
 
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            grad_to_block_id, origin_program):
+                            grad_to_block_id, origin_program, merged_var):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
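
_get_varname_parts recognizes the name grammar origin[.blockN][.trainer_M]. A doctest-style sketch with hypothetical names (the method never touches self, so a bare instance suffices):

    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler

    t = DistributeTranspiler()
    assert t._get_varname_parts("w@GRAD.block0.trainer_3") == \
        ("w@GRAD", "block0", "trainer_3")
    assert t._get_varname_parts("w.block7") == ("w", "block7", "")
    assert t._get_varname_parts("fc_b") == ("fc_b", "", "")

This is also why _append_pserver_grad_merge_ops above can rebuild the merged name as origin + "." + block, and each per-trainer copy as merged + ".trainer_%d".
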
@@ -997,40 +993,6 @@
         # moment can use the updated shape
         for key in opt_op.input_names:
             if key == "Grad":
-                grad_block = None
-                for g in self.param_grad_ep_mapping[endpoint]["grads"]:
-                    if same_or_split_var(
-                            self._orig_varname(g.name),
-                            self._orig_varname(opt_op.input(key)[0])):
-                        grad_block = g
-                        break
-                if not grad_block:
-                    # do not append this op if current endpoint
-                    # is not dealing with this grad block
-                    return
-                merged_var = \
-                    pserver_block.vars[self._orig_varname(grad_block.name)]
-                grad_to_block_id.append(merged_var.name + ":" + str(
-                    optimize_block.idx))
-                if self.sync_mode and self.trainer_num > 1:
-                    vars2merge = []
-                    for i in xrange(self.trainer_num):
-                        per_trainer_name = "%s.trainer_%d" % \
-                            (self._orig_varname(grad_block.name), i)
-                        vars2merge.append(pserver_block.vars[per_trainer_name])
-
-                    optimize_block.append_op(
-                        type="sum",
-                        inputs={"X": vars2merge},
-                        outputs={"Out": merged_var})
-                    # TODO(panyx0718): What if it's SELECTED_ROWS.
-                    if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
-                        optimize_block.append_op(
-                            type="scale",
-                            inputs={"X": merged_var},
-                            outputs={"Out": merged_var},
-                            attrs={"scale": 1.0 / float(self.trainer_num)})
-
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
@@ -1089,17 +1051,31 @@
             outputs=outputs,
             attrs=opt_op.attrs)
 
-    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
+    def _is_splited_grad_var(self, var, var_dict):
+        grad_block = None
+        for _, g in var_dict.iteritems():
+            if self._orig_varname(g.name) == self._orig_varname(var.name):
+                if g.name.find(".trainer_") == -1:
+                    grad_block = g
+                    break
+        return grad_block
+
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(
             self.origin_program.global_block().vars, opt_op)
-        for varlist in inputs.itervalues():
+        for key, varlist in inputs.iteritems():
             if not isinstance(varlist, list):
                 varlist = [varlist]
             for var in varlist:
-                if not program.global_block().vars.has_key(var.name):
+                # for ops like clipping and weight decay, get the split var
+                # for inputs/outputs
+                grad_block = self._is_splited_grad_var(
+                    var, program.global_block().vars)
+                if grad_block:
+                    inputs[key] = grad_block
+                elif not program.global_block().vars.has_key(var.name):
                     program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,
@@ -1108,13 +1084,16 @@
 
         outputs = self._get_output_map_from_op(
             self.origin_program.global_block().vars, opt_op)
-
-        for varlist in outputs.itervalues():
+        for key, varlist in outputs.iteritems():
             if not isinstance(varlist, list):
                 varlist = [varlist]
             for var in varlist:
-                program.global_block().clone_variable(var)
+                grad_block = self._is_splited_grad_var(
+                    var, program.global_block().vars)
+                if grad_block:
+                    outputs[key] = grad_block
+                elif not program.global_block().vars.has_key(var.name):
+                    program.global_block().clone_variable(var)
 
         optimize_block.append_op(
             type=opt_op.type,
@@ -1160,9 +1139,17 @@
                     ufind.union(op1, op2)
         return ufind
 
-    def _is_opt_op(self, op):
-        # NOTE: It's a HACK implement.
-        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
+    def _is_opt_role_op(self, op):
+        # NOTE: depend on oprole to find out whether this op is for
+        # optimize
+        op_maker = core.op_proto_and_checker_maker
+        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
+        if op_maker.kOpRoleAttrName() in op.attrs and \
+                int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
+            return True
+        return False
+
+    def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
                 "LearningRate" in op.input_names:
             return True
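
_is_opt_role_op keys off the role attribute that the op maker stamps on every op, instead of guessing from input names. A sketch of reading those attributes from an already-built program (an assumption-labeled example: `program` is built elsewhere, and op_role_var is assumed to carry the (param, grad) name pair with the param first, as _get_optimize_pass below relies on):

    from paddle.fluid import core

    op_maker = core.op_proto_and_checker_maker
    role_attr = op_maker.kOpRoleAttrName()
    role_var_attr = op_maker.kOpRoleVarAttrName()

    for op in program.global_block().ops:
        if role_attr in op.attrs and \
                int(op.attrs[role_attr]) == int(op_maker.OpRole.Optimize):
            role_vars = op.attrs.get(role_var_attr, [])
            if role_vars:
                param_name, grad_name = role_vars[0], role_vars[1]
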
@@ -1212,7 +1199,7 @@
         # find learning rate variables by optimize op
         lr_vars = set()
         for op in self.optimize_ops:
-            if self._is_opt_op(op):
+            if self._is_optimizer_op(op):
                 lr_vars.add(op.input("LearningRate")[0])
 
         find_ops = []
@@ -1229,7 +1216,7 @@
             # NOTE: we need to skip all optimize ops, since it is connected
             # with forward/backward ops and lr ops, we only need the lr ops.
             if op1 != op2 and self._is_op_connected(op1, op2) and \
-                    not self._is_opt_op(op1) and not self._is_opt_op(op2):
+                    not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
                 ufind.union(op1, op2)
         # find all ops which is related with lr var
         for op1 in block.ops:
@@ -1250,13 +1237,21 @@
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
+        origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
-            if self._is_opt_op(op):
+            if self._is_opt_role_op(op):
                 opt_ops.append(op)
-                params_grads.append((self.origin_program.global_block().var(
-                    op.input("Param")[0]),
-                                     self.origin_program.global_block().var(
-                                         op.input("Grad")[0])))
+                # HACK(wuyi): if we find grad vars from input of optimize
+                # ops, we may get the output of clip op. Use syntax "@GRAD"
+                # and op_role_var to get the pair.
+                for input_name in op.input_arg_names:
+                    if input_name.find("@GRAD") != -1 and \
+                            op.attrs[RPC_OP_ROLE_ATTR_NAME]:
+                        param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0]
+                        params_grads.append([
+                            origin_var_dict[param_name],
+                            origin_var_dict[input_name]
+                        ])
             elif self._is_adam_connected_op(op):
                 opt_ops.append(op)
             else:
-- 
GitLab