Unverified commit d9061062, authored by Wu Yi, committed by GitHub

Cleanup transpiler and move weight decay and clip on pservers (#11039)

* WIP move weight decay

* weight decay ok

* wip

* clean up transpiler

* add details folder

* update

* fix split var test

* follow comments
Parent 1af0b28c
@@ -14,7 +14,7 @@
import math
import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
from paddle.fluid.transpiler.distribute_transpiler import split_variable
import paddle.fluid as fluid
import paddle.fluid.core as core
import random
@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
blocks = split_dense_variable(var_list, 10, min_size)
blocks = split_variable(var_list, 10, min_size)
all_sizes = []
for s in expected_sizes:
for s2 in s:
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from program_utils import *
from ufind import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
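
A rough usage sketch of these helpers; everything here is illustrative (main_prog stands for a fluid.Program that already contains backward ops, the "@GRAD" argument name is made up, and the helpers are assumed to be importable alongside the transpiler from this new details package):

import paddle.fluid as fluid

main_prog = fluid.default_main_program()   # assumed to be already built
block = main_prog.global_block()

idx = find_op_by_input_arg(block, "fc_0.w_0@GRAD")   # hypothetical arg name
if idx >= 0:
    # drop the located op and everything after it in the block, much like the
    # transpiler strips optimizer ops from the trainer-side program
    delete_ops(block, list(block.ops)[idx:])
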
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initial element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of the given element x,
# performing path compression while finding the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union the two given elements
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
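
A quick sketch of the behaviour documented above, using plain strings as elements (the transpiler itself unions ops, but any hashable elements work):

uf = UnionFind(["a", "b", "c", "d"])
uf.union("a", "b")
uf.union("b", "c")
assert uf.is_connected("a", "c")       # a-b and b-c imply a-c
assert not uf.is_connected("a", "d")   # d was never unioned with anything
assert uf.find("missing") == -1        # unknown elements return -1
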
@@ -11,6 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Transpile the program into distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
for parameter optimization, and the optimization graph will be put
into a parameter server program.
Trainable variables are split across the parameter servers using a
configurable split method.
Steps to transpile the trainer:
1. split each variable into multiple blocks, aligned by product(dim[1:]) (width).
2. rename the split grad variables, adding the trainer_id suffix ".trainer_%d".
3. modify the trainer program: add a split_op for each grad variable.
4. append send_op to send the split variables to the servers, and fetch
params (split blocks or the original param) from the servers.
5. append concat_op to merge the split blocks and update local weights.
Steps to transpile the pserver:
1. create a new program for the parameter server.
2. create the param and grad variables assigned to the current server instance.
3. create a sub-block in the server-side program.
4. append the ops that should run on the current server instance.
5. add the listen_and_serv op.
"""
from __future__ import print_function
@@ -21,9 +45,11 @@ from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, \
Variable, Parameter, grad_var_name
from details import *
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
@@ -40,62 +66,11 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size)
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initial element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of the given element x,
# performing path compression while finding the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union the two given elements
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
def split_dense_variable(var_list, service_count, min_block_size=8192):
def split_variable(var_list, service_count, min_block_size=8192):
"""
We may need to split dense tensor to one or more blocks and put
them equally onto parameter server. One block is a sub-tensor
@@ -141,99 +116,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):
return blocks
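
To make the block format concrete, a hedged calling sketch (fc_w and fc_b stand for parameters of an already-built program, and the exact split depends on min_block_size and how evenly the variables divide). Each returned entry is a string of the form produced by VarBlock.__str__ above, which the transpiler later splits back apart on ":":

blocks = split_variable([fc_w, fc_b], service_count=3)
for blk in blocks:
    varname, block_idx, size = blk.split(":")
    # the transpiler re-parses these pieces to build the per-block variables
    print("%s / block %s / %s elements" % (varname, block_idx, size))
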
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler:
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables to different
parameter servers.
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determine how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean, default True
:param sync_mode: if sync_mode is True, the transpiler generates
synchronous (sync_mode) pserver and trainer programs.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
def _has_distributed_lookup_table(self):
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = []
# support only one distributed_lookup_table now
self.table_name = None
for op in program.global_block().ops:
for op in self.origin_program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True:
if self.table_name is None:
@@ -246,20 +137,13 @@ class DistributeTranspiler:
if self.table_name is not None:
assert op.input("W")[0] != self.table_name
self.has_distributed_lookup_table = len(
distributed_lookup_table_ops) > 0
# step1: For large parameters and gradients, split them into smaller
# blocks.
param_list = []
grad_list = []
for p, g in params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
return len(distributed_lookup_table_ops) > 0
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
# TODO(wuyi): find a way to put all the dist lookup table stuff together.
# update self.table_param_grad and self.trainer_side_table_grad_list
program = self.origin_program
if self.has_distributed_lookup_table:
param_list = [
param for param in param_list if param.name != self.table_name
@@ -277,7 +161,7 @@ class DistributeTranspiler:
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
(table_grad_var.name, self.trainer_id, index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
@@ -293,23 +177,41 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints))
]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
def _init_splited_vars(self, split_method):
# update these mappings for further transpile:
# 1. param_var_mapping: param var name -> [splited params vars]
# 2. grad_var_mapping: grad var name -> [splited grads vars]
# 3. grad_param_mapping: grad.blockx -> param.blockx
# 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
param_list = []
grad_list = []
for p, g in self.params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads)
grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
param_blocks = split_variable(param_list, len(self.pserver_endpoints))
assert (len(grad_blocks) == len(param_blocks))
# step2: Create new vars for the parameters and gradients blocks and
# add ops to do the split.
param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks)
grad_var_mapping = self._create_vars_from_blocklist(
program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
grad_param_mapping = dict()
# origin_varname -> [splited_var]
self.param_var_mapping = self._create_vars_from_blocklist(
self.origin_program, param_blocks)
self.grad_var_mapping = self._create_vars_from_blocklist(
self.origin_program,
grad_blocks,
add_trainer_suffix=self.trainer_num > 1)
self.grad_param_mapping = dict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":")
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)]
# step 3: transpile trainer side program, insert recv op and send op.
self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
self.param_var_mapping[p_name][int(p_bid)]
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
@@ -322,10 +224,50 @@ class DistributeTranspiler:
}) for ep in self.pserver_endpoints
]
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=RoundRobin,
sync_mode=True):
"""
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determine how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean, default True
:param sync_mode: if sync_mode is True, the transpiler generates
synchronous (sync_mode) pserver and trainer programs.
"""
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
self.trainer_num = trainers
self.sync_mode = sync_mode
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
# split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars(split_method)
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset()
send_vars = []
for orig_varname, splited_vars in grad_var_mapping.items():
for orig_varname, splited_vars in self.grad_var_mapping.items():
eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1:
orig_varname = splited_vars[0].name
@@ -367,7 +309,7 @@ class DistributeTranspiler:
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars = []
for _, var in enumerate(send_vars):
recv_vars.append(grad_param_mapping[var])
recv_vars.append(self.grad_param_mapping[var])
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
@@ -375,7 +317,7 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
for varname, splited_var in self.param_var_mapping.iteritems():
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
@@ -399,7 +341,7 @@ class DistributeTranspiler:
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in param_var_mapping.iteritems():
for varname, splited_var in self.param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
@@ -440,7 +382,6 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
@@ -477,24 +418,14 @@ class DistributeTranspiler:
# located on current pserver
opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
endpoint, op):
opt_op_on_pserver.append(op)
# step 3.3
# Iterate through the ops; if an op and the optimize ops
# located on the current pserver are in one set, then
# append it to the sub program.
# We try to put optimization program run parallelly, assume
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
@@ -502,12 +433,18 @@ class DistributeTranspiler:
if self._is_adam_connected_op(op):
global_ops.append(op)
def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op):
def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
if self._is_optimizer_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
self.origin_program)
self.origin_program, merged_var)
else:
self._append_pserver_non_opt_ops(block, op)
self._append_pserver_non_opt_ops(block, op, endpoint)
def __op_have_grad_input__(op):
for varname in op.input_arg_names:
if varname.find("@GRAD") >= 0:
return varname
return ""
# append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops()
@@ -515,17 +452,26 @@ class DistributeTranspiler:
lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op)
self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
# append op to the current block
grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
# append grad merging ops before clip and weight decay
for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping
grad_varname_for_block = __op_have_grad_input__(op)
if ufind.is_connected(op, opt_op) and grad_varname_for_block:
merged_var = self._append_pserver_grad_merge_ops(
per_opt_block, grad_varname_for_block, endpoint,
grad_to_block_id, self.origin_program)
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block, grad_to_block_id)
__append_optimize_op__(op, per_opt_block, grad_to_block_id,
merged_var)
# append global ops
if global_ops:
@@ -533,15 +479,7 @@ class DistributeTranspiler:
pserver_program.num_blocks - 1)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
grad_to_block_id, None)
# process distributed lookup_table
prefetch_block = None
@@ -631,6 +569,8 @@ class DistributeTranspiler:
attrs=op.attrs)
return s_prog
# ====================== private transpiler functions =====================
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
@@ -836,7 +776,6 @@ class DistributeTranspiler:
return table_opt_block
# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self,
program,
block_list,
@@ -979,44 +918,57 @@ class DistributeTranspiler:
pass
return orig_shape
def _orig_varname(self, varname):
suff_idx = varname.find(".trainer_")
def _get_varname_parts(self, varname):
# returns origin, blockid, trainerid
orig_var_name = ""
if suff_idx >= 0:
orig_var_name = varname[:suff_idx]
trainer_part = ""
block_part = ""
trainer_idx = varname.find(".trainer_")
if trainer_idx >= 0:
trainer_part = varname[trainer_idx + 1:]
else:
trainer_idx = len(varname)
block_index = varname.find(".block")
if block_index >= 0:
block_part = varname[block_index + 1:trainer_idx]
else:
orig_var_name = varname
return orig_var_name
block_index = len(varname)
orig_var_name = varname[0:min(block_index, trainer_idx)]
return orig_var_name, block_part, trainer_part
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
def _orig_varname(self, varname):
orig, _, _ = self._get_varname_parts(varname)
return orig
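
A small illustration of the naming convention these two helpers parse (the variable names below are made up, but follow the ".blockN" / ".trainer_N" suffix scheme used throughout the transpiler):

t = DistributeTranspiler()
assert t._get_varname_parts("fc_0.w_0.block0.trainer_1") == \
    ("fc_0.w_0", "block0", "trainer_1")
assert t._get_varname_parts("fc_0.w_0.block2") == ("fc_0.w_0", "block2", "")
assert t._get_varname_parts("fc_0.w_0") == ("fc_0.w_0", "", "")
assert t._orig_varname("fc_0.w_0.block0.trainer_1") == "fc_0.w_0"
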
def _append_pserver_grad_merge_ops(self, optimize_block,
grad_varname_for_block, endpoint,
grad_to_block_id, origin_program):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key in opt_op.input_names:
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(
self._orig_varname(g.name),
self._orig_varname(opt_op.input(key)[0])):
if self._orig_varname(g.name) == \
self._orig_varname(grad_varname_for_block):
grad_block = g
break
if not grad_block:
# do not append this op if current endpoint
# is not dealing with this grad block
return
orig_varname, block_name, trainer_name = self._get_varname_parts(
grad_block.name)
if block_name:
merged_var_name = '.'.join([orig_varname, block_name])
else:
merged_var_name = orig_varname
merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)]
grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
pserver_block.vars[merged_var_name]
grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
(self._orig_varname(grad_block.name), i)
(merged_var_name, i)
vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op(
@@ -1030,7 +982,17 @@ class DistributeTranspiler:
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)})
return merged_var
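
In sync mode with multiple trainers, the ops appended above boil down to averaging the per-trainer copies of a gradient block before any clip, weight-decay, or optimizer op on the pserver sees it. A plain-numpy sketch of the same arithmetic (shapes and names are illustrative; the op elided by the hunk above, together with the visible scale op, sums the copies and divides by trainer_num):

import numpy as np

trainer_num = 3
# per-trainer copies of one gradient block, e.g. w.block0.trainer_0 .. trainer_2
per_trainer_grads = [np.random.rand(4, 8) for _ in range(trainer_num)]

merged = sum(per_trainer_grads)        # merge the received copies
merged *= 1.0 / float(trainer_num)     # matches attrs={"scale": 1.0 / trainer_num}
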
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
grad_to_block_id, origin_program, merged_var):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key in opt_op.input_names:
if key == "Grad":
new_inputs[key] = merged_var
elif key == "Param":
# param is already created on global program
@@ -1089,17 +1051,31 @@ class DistributeTranspiler:
outputs=outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
def _is_splited_grad_var(self, var, var_dict):
grad_block = None
for _, g in var_dict.iteritems():
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
break
return grad_block
def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op)
for varlist in inputs.itervalues():
for key, varlist in inputs.iteritems():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if not program.global_block().vars.has_key(var.name):
# for ops like clipping and weight decay, get the splited var
# for inputs/outputs
grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
inputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
@@ -1108,12 +1084,15 @@ class DistributeTranspiler:
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for varlist in outputs.itervalues():
for key, varlist in outputs.iteritems():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
grad_block = self._is_splited_grad_var(
var, program.global_block().vars)
if grad_block:
outputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var)
optimize_block.append_op(
@@ -1160,9 +1139,17 @@ class DistributeTranspiler:
ufind.union(op1, op2)
return ufind
def _is_opt_op(self, op):
# NOTE: It's a HACK implement.
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
def _is_opt_role_op(self, op):
# NOTE: depend on oprole to find out whether this op is for
# optimize
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attrs and \
int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def _is_optimizer_op(self, op):
if "Param" in op.input_names and \
"LearningRate" in op.input_names:
return True
@@ -1212,7 +1199,7 @@ class DistributeTranspiler:
# find learning rate variables by optimize op
lr_vars = set()
for op in self.optimize_ops:
if self._is_opt_op(op):
if self._is_optimizer_op(op):
lr_vars.add(op.input("LearningRate")[0])
find_ops = []
@@ -1229,7 +1216,7 @@ class DistributeTranspiler:
# NOTE: we need to skip all optimize ops, since they are connected
# with forward/backward ops and lr ops; we only need the lr ops.
if op1 != op2 and self._is_op_connected(op1, op2) and \
not self._is_opt_op(op1) and not self._is_opt_op(op2):
not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
ufind.union(op1, op2)
# find all ops which is related with lr var
for op1 in block.ops:
@@ -1250,13 +1237,21 @@ class DistributeTranspiler:
block = self.origin_program.global_block()
opt_ops = []
params_grads = []
origin_var_dict = self.origin_program.global_block().vars
for op in block.ops:
if self._is_opt_op(op):
if self._is_opt_role_op(op):
opt_ops.append(op)
params_grads.append((self.origin_program.global_block().var(
op.input("Param")[0]),
self.origin_program.global_block().var(
op.input("Grad")[0])))
# HACK(wuyi): if we find grad vars from input of optimize
# ops, we may get the output of clip op. Use syntax "@GRAD"
# and op_role_var to get the pair.
for input_name in op.input_arg_names:
if input_name.find("@GRAD") != -1 and \
op.attrs[RPC_OP_ROLE_ATTR_NAME]:
param_name = op.attrs[OP_ROLE_VAR_ATTR_NAME][0]
params_grads.append([
origin_var_dict[param_name],
origin_var_dict[input_name]
])
elif self._is_adam_connected_op(op):
opt_ops.append(op)
else: