未验证 提交 d9061062 编写于 作者: W Wu Yi 提交者: GitHub

Cleanup transpiler and move weight decay and clip on pservers (#11039)

* WIP move weight decay

* weight decay ok

* wip

* clean up transpiler

* add details folder

* update

* fix split var test

* follow comments
上级 1af0b28c
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import math import math
import unittest import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable from paddle.fluid.transpiler.distribute_transpiler import split_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import random import random
...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR, # dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = split_dense_variable(var_list, 10, min_size) blocks = split_variable(var_list, 10, min_size)
all_sizes = [] all_sizes = []
for s in expected_sizes: for s in expected_sizes:
for s2 in s: for s2 in s:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from program_utils import *
from ufind import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of given element x,
# execute the path compress while findind the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union two given element
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
...@@ -11,6 +11,30 @@ ...@@ -11,6 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
"""
from __future__ import print_function from __future__ import print_function
...@@ -21,9 +45,11 @@ from .. import core, framework ...@@ -21,9 +45,11 @@ from .. import core, framework
from ..framework import Program, default_main_program, \ from ..framework import Program, default_main_program, \
default_startup_program, \ default_startup_program, \
Variable, Parameter, grad_var_name Variable, Parameter, grad_var_name
from details import *
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
...@@ -40,62 +66,11 @@ class VarBlock: ...@@ -40,62 +66,11 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size) return "%s:%d:%d" % (self.varname, self.offset, self.size)
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of given element x,
# execute the path compress while findind the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union two given element
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
def same_or_split_var(p_name, var_name): def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block") return p_name == var_name or p_name.startswith(var_name + ".block")
def split_dense_variable(var_list, service_count, min_block_size=8192): def split_variable(var_list, service_count, min_block_size=8192):
""" """
We may need to split dense tensor to one or more blocks and put We may need to split dense tensor to one or more blocks and put
them equally onto parameter server. One block is a sub-tensor them equally onto parameter server. One block is a sub-tensor
...@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192): ...@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
return blocks return blocks
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def _has_distributed_lookup_table(self):
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table. # 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = [] distributed_lookup_table_ops = []
# support only one distributed_lookup_table now # support only one distributed_lookup_table now
self.table_name = None self.table_name = None
for op in program.global_block().ops: for op in self.origin_program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True: if op.attrs['is_distributed'] is True:
if self.table_name is None: if self.table_name is None:
...@@ -246,20 +137,13 @@ class DistributeTranspiler: ...@@ -246,20 +137,13 @@ class DistributeTranspiler:
if self.table_name is not None: if self.table_name is not None:
assert op.input("W")[0] != self.table_name assert op.input("W")[0] != self.table_name
self.has_distributed_lookup_table = len( return len(distributed_lookup_table_ops) > 0
distributed_lookup_table_ops) > 0
# step1: For large parameters and gradients, split them into smaller
# blocks.
param_list = []
grad_list = []
for p, g in params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
# TODO(wuyi): put find a way to put dist lookup table stuff all together.
# update self.table_param_grad and self.trainer_side_table_grad_list
program = self.origin_program
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
param_list = [ param_list = [
param for param in param_list if param.name != self.table_name param for param in param_list if param.name != self.table_name
...@@ -277,7 +161,7 @@ class DistributeTranspiler: ...@@ -277,7 +161,7 @@ class DistributeTranspiler:
self.trainer_side_table_grad_list = [ self.trainer_side_table_grad_list = [
program.global_block().create_var( program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" % name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index), (table_grad_var.name, self.trainer_id, index),
type=table_grad_var.type, type=table_grad_var.type,
shape=table_grad_var.shape, shape=table_grad_var.shape,
dtype=table_grad_var.dtype) dtype=table_grad_var.dtype)
...@@ -293,23 +177,41 @@ class DistributeTranspiler: ...@@ -293,23 +177,41 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints)) for index in range(len(self.pserver_endpoints))
] ]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) def _init_splited_vars(self, split_method):
param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) # update these mappings for further transpile:
# 1. param_var_mapping: param var name -> [splited params vars]
# 2. grad_var_mapping: grad var name -> [splited grads vars]
# 3. grad_param_mapping: grad.blockx -> param.blockx
# 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
param_list = []
grad_list = []
for p, g in self.params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads)
grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
param_blocks = split_variable(param_list, len(self.pserver_endpoints))
assert (len(grad_blocks) == len(param_blocks)) assert (len(grad_blocks) == len(param_blocks))
# step2: Create new vars for the parameters and gradients blocks and # origin_varname -> [splited_var]
# add ops to do the split. self.param_var_mapping = self._create_vars_from_blocklist(
param_var_mapping = self._create_vars_from_blocklist(program, self.origin_program, param_blocks)
param_blocks) self.grad_var_mapping = self._create_vars_from_blocklist(
grad_var_mapping = self._create_vars_from_blocklist( self.origin_program,
program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) grad_blocks,
grad_param_mapping = dict() add_trainer_suffix=self.trainer_num > 1)
self.grad_param_mapping = dict()
for g, p in zip(grad_blocks, param_blocks): for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":") g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":") p_name, p_bid, _ = p.split(":")
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)] self.param_var_mapping[p_name][int(p_bid)]
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program # create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict() self.param_grad_ep_mapping = dict()
...@@ -322,10 +224,50 @@ class DistributeTranspiler: ...@@ -322,10 +224,50 @@ class DistributeTranspiler:
}) for ep in self.pserver_endpoints }) for ep in self.pserver_endpoints
] ]
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
# split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars(split_method)
# step 3.1: insert send op to send gradient vars to parameter servers # step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset() ps_dispatcher.reset()
send_vars = [] send_vars = []
for orig_varname, splited_vars in grad_var_mapping.items(): for orig_varname, splited_vars in self.grad_var_mapping.items():
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1: if len(splited_vars) == 1:
orig_varname = splited_vars[0].name orig_varname = splited_vars[0].name
...@@ -367,7 +309,7 @@ class DistributeTranspiler: ...@@ -367,7 +309,7 @@ class DistributeTranspiler:
# step 3.2: insert recv op to receive parameters from parameter server # step 3.2: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
for _, var in enumerate(send_vars): for _, var in enumerate(send_vars):
recv_vars.append(grad_param_mapping[var]) recv_vars.append(self.grad_param_mapping[var])
ps_dispatcher.reset() ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars) eplist = ps_dispatcher.dispatch(recv_vars)
...@@ -375,7 +317,7 @@ class DistributeTranspiler: ...@@ -375,7 +317,7 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in self.param_var_mapping.iteritems():
eps = [] eps = []
for var in splited_var: for var in splited_var:
index = [v.name for v in recv_vars].index(var.name) index = [v.name for v in recv_vars].index(var.name)
...@@ -399,7 +341,7 @@ class DistributeTranspiler: ...@@ -399,7 +341,7 @@ class DistributeTranspiler:
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in self.param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = program.global_block().vars[varname] orig_param = program.global_block().vars[varname]
...@@ -440,7 +382,6 @@ class DistributeTranspiler: ...@@ -440,7 +382,6 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives. # we don't need to create them when grad arrives.
# change client side var name to origin name by # change client side var name to origin name by
# removing ".trainer_%d" suffix # removing ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
...@@ -477,24 +418,14 @@ class DistributeTranspiler: ...@@ -477,24 +418,14 @@ class DistributeTranspiler:
# located on current pserver # located on current pserver
opt_op_on_pserver = [] opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op): if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
endpoint, op):
opt_op_on_pserver.append(op) opt_op_on_pserver.append(op)
# step 3.3 # step 3.3
# Iterate through the ops, and if an op and the optimize ops # Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then # which located on current pserver are in one set, then
# append it into the sub program. # append it into the sub program.
# We try to put optimization program run parallelly, assume
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops = [] global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2 # HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine. # replace it with dependency engine.
...@@ -502,12 +433,18 @@ class DistributeTranspiler: ...@@ -502,12 +433,18 @@ class DistributeTranspiler:
if self._is_adam_connected_op(op): if self._is_adam_connected_op(op):
global_ops.append(op) global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id): def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
if self._is_opt_op(op): if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id, self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program) self.origin_program, merged_var)
else: else:
self._append_pserver_non_opt_ops(block, op) self._append_pserver_non_opt_ops(block, op, endpoint)
def __op_have_grad_input__(op):
for varname in op.input_arg_names:
if varname.find("@GRAD") >= 0:
return varname
return ""
# append lr decay ops to the child block if exists # append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
...@@ -515,17 +452,26 @@ class DistributeTranspiler: ...@@ -515,17 +452,26 @@ class DistributeTranspiler:
lr_decay_block = pserver_program.create_block( lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1 pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx) per_opt_block = pserver_program.create_block(pre_block_idx)
# append grad merging ops before clip and weight decay
for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping
grad_varname_for_block = __op_have_grad_input__(op)
if ufind.is_connected(op, opt_op) and grad_varname_for_block:
merged_var = self._append_pserver_grad_merge_ops(
per_opt_block, grad_varname_for_block, endpoint,
grad_to_block_id, self.origin_program)
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself # optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops: if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id) __append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var)
# append global ops # append global ops
if global_ops: if global_ops:
...@@ -533,15 +479,7 @@ class DistributeTranspiler: ...@@ -533,15 +479,7 @@ class DistributeTranspiler:
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
for glb_op in global_ops: for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block, __append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id) grad_to_block_id, None)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# process distributed lookup_table # process distributed lookup_table
prefetch_block = None prefetch_block = None
...@@ -631,6 +569,8 @@ class DistributeTranspiler: ...@@ -631,6 +569,8 @@ class DistributeTranspiler:
attrs=op.attrs) attrs=op.attrs)
return s_prog return s_prog
# ====================== private transpiler functions =====================
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
...@@ -836,7 +776,6 @@ class DistributeTranspiler: ...@@ -836,7 +776,6 @@ class DistributeTranspiler:
return table_opt_block return table_opt_block
# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self, def _create_vars_from_blocklist(self,
program, program,
block_list, block_list,
...@@ -979,44 +918,57 @@ class DistributeTranspiler: ...@@ -979,44 +918,57 @@ class DistributeTranspiler:
pass pass
return orig_shape return orig_shape
def _orig_varname(self, varname): def _get_varname_parts(self, varname):
suff_idx = varname.find(".trainer_") # returns origin, blockid, trainerid
orig_var_name = "" orig_var_name = ""
if suff_idx >= 0: trainer_part = ""
orig_var_name = varname[:suff_idx] block_part = ""
trainer_idx = varname.find(".trainer_")
if trainer_idx >= 0:
trainer_part = varname[trainer_idx + 1:]
else:
trainer_idx = len(varname)
block_index = varname.find(".block")
if block_index >= 0:
block_part = varname[block_index + 1:trainer_idx]
else: else:
orig_var_name = varname block_index = len(varname)
return orig_var_name orig_var_name = varname[0:min(block_index, trainer_idx)]
return orig_var_name, block_part, trainer_part
def _append_pserver_ops(self, optimize_block, opt_op, endpoint, def _orig_varname(self, varname):
orig, _, _ = self._get_varname_parts(varname)
return orig
def _append_pserver_grad_merge_ops(self, optimize_block,
grad_varname_for_block, endpoint,
grad_to_block_id, origin_program): grad_to_block_id, origin_program):
program = optimize_block.program program = optimize_block.program
pserver_block = program.global_block() pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key in opt_op.input_names:
if key == "Grad":
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var( if self._orig_varname(g.name) == \
self._orig_varname(g.name), self._orig_varname(grad_varname_for_block):
self._orig_varname(opt_op.input(key)[0])):
grad_block = g grad_block = g
break break
if not grad_block: if not grad_block:
# do not append this op if current endpoint # do not append this op if current endpoint
# is not dealing with this grad block # is not dealing with this grad block
return return
orig_varname, block_name, trainer_name = self._get_varname_parts(
grad_block.name)
if block_name:
merged_var_name = '.'.join([orig_varname, block_name])
else:
merged_var_name = orig_varname
merged_var = \ merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)] pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str( grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
optimize_block.idx))
if self.sync_mode and self.trainer_num > 1: if self.sync_mode and self.trainer_num > 1:
vars2merge = [] vars2merge = []
for i in xrange(self.trainer_num): for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \ per_trainer_name = "%s.trainer_%d" % \
(self._orig_varname(grad_block.name), i) (merged_var_name, i)
vars2merge.append(pserver_block.vars[per_trainer_name]) vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op( optimize_block.append_op(
...@@ -1030,7 +982,17 @@ class DistributeTranspiler: ...@@ -1030,7 +982,17 @@ class DistributeTranspiler:
inputs={"X": merged_var}, inputs={"X": merged_var},
outputs={"Out": merged_var}, outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)}) attrs={"scale": 1.0 / float(self.trainer_num)})
return merged_var
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
grad_to_block_id, origin_program, merged_var):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key in opt_op.input_names:
if key == "Grad":
new_inputs[key] = merged_var new_inputs[key] = merged_var
elif key == "Param": elif key == "Param":
# param is already created on global program # param is already created on global program
...@@ -1089,17 +1051,31 @@ class DistributeTranspiler: ...@@ -1089,17 +1051,31 @@ class DistributeTranspiler:
outputs=outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _is_splited_grad_var(self, var, var_dict):
grad_block = None
for _, g in var_dict.iteritems():
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
break
return grad_block
def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op( inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for varlist in inputs.itervalues(): for key, varlist in inputs.iteritems():
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if not program.global_block().vars.has_key(var.name): # for ops like clipping and weight decay, get the splited var
# for inputs/outputs
grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
inputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().create_var( program.global_block().create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
...@@ -1108,12 +1084,15 @@ class DistributeTranspiler: ...@@ -1108,12 +1084,15 @@ class DistributeTranspiler:
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in outputs.iteritems():
for varlist in outputs.itervalues():
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
outputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var) program.global_block().clone_variable(var)
optimize_block.append_op( optimize_block.append_op(
...@@ -1160,9 +1139,17 @@ class DistributeTranspiler: ...@@ -1160,9 +1139,17 @@ class DistributeTranspiler:
ufind.union(op1, op2) ufind.union(op1, op2)
return ufind return ufind
def _is_opt_op(self, op): def _is_opt_role_op(self, op):
# NOTE: It's a HACK implement. # NOTE: depend on oprole to find out whether this op is for
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc... # optimize
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attrs and \
int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def _is_optimizer_op(self, op):
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
return True return True
...@@ -1212,7 +1199,7 @@ class DistributeTranspiler: ...@@ -1212,7 +1199,7 @@ class DistributeTranspiler:
# find learning rate variables by optimize op # find learning rate variables by optimize op
lr_vars = set() lr_vars = set()
for op in self.optimize_ops: for op in self.optimize_ops:
if self._is_opt_op(op): if self._is_optimizer_op(op):
lr_vars.add(op.input("LearningRate")[0]) lr_vars.add(op.input("LearningRate")[0])
find_ops = [] find_ops = []
...@@ -1229,7 +1216,7 @@ class DistributeTranspiler: ...@@ -1229,7 +1216,7 @@ class DistributeTranspiler:
# NOTE: we need to skip all optimize ops, since it is connected # NOTE: we need to skip all optimize ops, since it is connected
# with forward/backward ops and lr ops, we only need the lr ops. # with forward/backward ops and lr ops, we only need the lr ops.
if op1 != op2 and self._is_op_connected(op1, op2) and \ if op1 != op2 and self._is_op_connected(op1, op2) and \
not self._is_opt_op(op1) and not self._is_opt_op(op2): not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
ufind.union(op1, op2) ufind.union(op1, op2)
# find all ops which is related with lr var # find all ops which is related with lr var
for op1 in block.ops: for op1 in block.ops:
...@@ -1250,13 +1237,21 @@ class DistributeTranspiler: ...@@ -1250,13 +1237,21 @@ class DistributeTranspiler:
block = self.origin_program.global_block() block = self.origin_program.global_block()
opt_ops = [] opt_ops = []
params_grads = [] params_grads = []
origin_var_dict = self.origin_program.global_block().vars
for op in block.ops: for op in block.ops:
if self._is_opt_op(op): if self._is_opt_role_op(op):
opt_ops.append(op) opt_ops.append(op)
params_grads.append((self.origin_program.global_block().var( # HACK(wuyi): if we find grad vars from input of optimize
op.input("Param")[0]), # ops, we may get the output of clip op. Use syntax "@GRAD"
self.origin_program.global_block().var( # and op_role_var to get the pair.
op.input("Grad")[0]))) for input_name in op.input_arg_names:
if input_name.find("@GRAD") != -1 and \
op.attrs[RPC_OP_ROLE_ATTR_NAME]:
param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0]
params_grads.append([
origin_var_dict[param_name],
origin_var_dict[input_name]
])
elif self._is_adam_connected_op(op): elif self._is_adam_connected_op(op):
opt_ops.append(op) opt_ops.append(op)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册