Unverified · Commit 9942565f authored by 武毅, committed by GitHub

Merge pull request #8386 from typhoonzero/fix_dist_transpiler_develop

Fix dist transpiler develop
@@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
       auto in_stride = framework::stride_numel(in->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                   out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride);
+                                  in->data<T>(), in_stride, in_stride[axis]);
       output_offset += in_stride[axis];
     }
   }
@@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
...
@@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
...
@@ -54,7 +54,8 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                      int64_t axis, T* dst,
                                      const framework::DDim& dst_stride_numel,
                                      const T* src,
-                                     const framework::DDim& src_stride_numel) {
+                                     const framework::DDim& src_stride_numel,
+                                     int64_t size) {
   int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
   int64_t src_after = src_stride_numel[axis];
   int64_t dst_after = dst_stride_numel[axis];
@@ -82,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
-                   src + i * src_after, sizeof(T) * src_after);
+                   src + i * src_after, sizeof(T) * size);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
       memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
-                   src + i * src_after, sizeof(T) * src_after,
-                   cuda_ctx.stream());
+                   src + i * src_after, sizeof(T) * size, cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
 #endif
...
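Note: the extra size argument threaded through StridedNumelCopyWithAxis lets the concat/split kernels state explicitly how many elements to copy per outer block (the destination's extent along the axis for split and concat-grad, the source's for concat), instead of always copying src_stride_numel[axis] elements. Below is a minimal standalone Python sketch of that per-block copy, with illustrative names rather than the PaddlePaddle API:

def strided_copy_with_axis(dst, dst_after, src, src_after, before, size):
    # Copy exactly `size` contiguous elements per outer block; each block starts
    # at a multiple of its buffer's own per-block stride (dst_after / src_after).
    for i in range(before):
        dst[i * dst_after:i * dst_after + size] = src[i * src_after:i * src_after + size]

# Example: splitting a row-major 2x6 buffer along axis 1 into two 2x3 halves.
src = list(range(12))
out0, out1 = [0] * 6, [0] * 6
strided_copy_with_axis(out0, 3, src, 6, before=2, size=3)      # left columns
strided_copy_with_axis(out1, 3, src[3:], 6, before=2, size=3)  # right columns
assert out0 == [0, 1, 2, 6, 7, 8] and out1 == [3, 4, 5, 9, 10, 11]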
@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
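Note: for reference, a small standalone sketch (plain Python, illustrative helper name, not the Fluid API) of the block arithmetic visible in this hunk: a dense variable of var_numel elements is cut into split_count blocks of at most block_size elements, and the last block takes the remainder.

import math

def split_block_sizes(var_numel, block_size):
    # mirrors: split_count = int(math.ceil(var_numel / float(block_size)))
    split_count = int(math.ceil(var_numel / float(block_size)))
    # mirrors: curr_block_size = min(block_size, var_numel - block_id * block_size)
    return [min(block_size, var_numel - block_id * block_size)
            for block_id in range(split_count)]

# e.g. a 10-element variable with block_size 4 is served as blocks of sizes [4, 4, 2]
assert split_block_sizes(10, 4) == [4, 4, 2]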
@@ -191,7 +192,6 @@ class DistributeTranspiler:
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
-
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -230,21 +230,6 @@ class DistributeTranspiler:
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
 
-        self.lr_param_mapping = self._create_lr_param_mapping()
-
-    def _create_lr_param_mapping(self):
-        lr_mapping = dict()
-        for _, opt_op in enumerate(self.optimize_ops):
-            if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
-                    or not opt_op.inputs.has_key("Param"):
-                continue
-            lr = opt_op.inputs["LearningRate"].name
-            param = opt_op.inputs["Param"].name
-            if not lr_mapping.has_key(lr):
-                lr_mapping.update({lr: list()})
-            lr_mapping[lr].append(param)
-        return lr_mapping
-
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
@@ -271,6 +256,7 @@ class DistributeTranspiler:
             splited_shape = [rows]
             if len(orig_shape) >= 2:
                 splited_shape.extend(orig_shape[1:])
+            print("###splited: ", size, rows, splited_shape)
             var = program.global_block().create_var(
                 name="%s.block%d" % (varname, i),
                 psersistable=False,
@@ -278,6 +264,7 @@ class DistributeTranspiler:
                 type=orig_var.type,
                 shape=splited_shape)  # flattend splited var
             var_mapping[varname].append(var)
+            print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -369,18 +356,9 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
-    def _fetch_var_names(self, param_dict):
-        res = []
-        if not param_dict:
-            return res
-        for _, values in param_dict.iteritems():
-            if not isinstance(values, list):
-                values = [values]
-            res += [v.name for v in values]
-        return res
-
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
+        pserver_block = program.global_block()
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -395,11 +373,11 @@ class DistributeTranspiler:
                     # do not append this op if current endpoint
                     # is not dealing with this grad block
                     return
-                merged_var = program.global_block().vars[grad_block.name]
+                merged_var = pserver_block.vars[grad_block.name]
                 # append merging ops if trainers > 1
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
-                        program.global_block(), grad_block, self.trainers)
+                        pserver_block, grad_block, self.trainers)
                     optimize_block.append_op(
                         type="sum",
                         inputs={"X": vars2merge},
@@ -419,29 +397,27 @@ class DistributeTranspiler:
                         break
                 if not param_block:
                     return
-                tmpvar = program.global_block().create_var(
+                tmpvar = pserver_block.create_var(
                     name=param_block.name,
                     persistable=True,
                     dtype=param_block.dtype,
                     shape=param_block.shape)
                 new_inputs[key] = tmpvar
             elif key == "LearningRate":
                 # leraning rate variable has already be created by non-optimize op,
                 # don't create it once again.
-                new_inputs[key] = program.global_block().vars[opt_op.input(key)[
-                    0]]
+                new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
 
         for key in opt_op.input_names:
             new_shape = None
             if key in ["Param", "Grad", "LearningRate"]:
                 continue
-            var = program.global_block().vars[opt_op.input(key)[0]]
+            var = self.program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
-            tmpvar = program.global_block().create_var(
+            tmpvar = pserver_block.create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
@@ -449,11 +425,14 @@ class DistributeTranspiler:
             new_inputs[key] = tmpvar
 
         # change output's ParamOut variable
-        opt_op.outputs["ParamOut"] = new_inputs["Param"]
+        outputs = self._get_output_map_from_op(self.program.global_block().vars,
+                                               opt_op)
+        outputs["ParamOut"] = new_inputs["Param"]
+
         optimize_block.append_op(
             type=opt_op.type,
             inputs=new_inputs,
-            outputs=opt_op.outputs,
+            outputs=outputs,
             attrs=opt_op.attrs)
 
     def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
@@ -497,11 +476,12 @@ class DistributeTranspiler:
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        op1_input_names = self._fetch_var_names(op1.inputs)
-        op1_output_names = self._fetch_var_names(op1.outputs)
+        op1_input_names = op1.desc.input_arg_names()
+        op1_output_names = op1.desc.output_arg_names()
+
+        op2_input_names = op2.desc.input_arg_names()
+        op2_output_names = op2.desc.output_arg_names()
 
-        op2_input_names = self._fetch_var_names(op2.inputs)
-        op2_output_names = self._fetch_var_names(op2.outputs)
         if set(op1_output_names) & set(op2_input_names) or \
            set(op1_input_names) & set(op2_output_names):
             return True
@@ -521,8 +501,8 @@ class DistributeTranspiler:
 
     def _is_opt_op(self, op):
         # NOTE: It's a HACK implement.
        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
-        if op.inputs and op.inputs.has_key("Param") \
-                and op.inputs.has_key("LearningRate"):
+        if "Param" in op.input_names and \
+                "LearningRate" in op.input_names:
             return True
         return False
@@ -530,12 +510,12 @@ class DistributeTranspiler:
         param_names = [
             p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
         ]
-        if op.inputs["Param"].name in param_names:
+        if op.input("Param") in param_names:
             return True
         else:
             for n in param_names:
-                param = op.inputs["Param"].name
-                if same_or_split_var(n, param) and n != op.inputs["Param"].name:
+                param = op.input("Param")[0]
+                if same_or_split_var(n, param) and n != param:
                     return True
             return False
         return False
@@ -551,6 +531,8 @@ class DistributeTranspiler:
         """
         # step5
         pserver_program = Program()
+        print("param mapping on pserver: #### ",
+              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -564,7 +546,6 @@ class DistributeTranspiler:
                 persistable=True,
                 dtype=v.dtype,
                 shape=v.shape)
-
         # step6
         optimize_block = pserver_program.create_block(0)
         # step 6.1
...
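Note: with the Python-side inputs/outputs dicts removed from Operator (see the framework.py hunk below), the transpiler now reads argument names from the underlying op desc. A hedged sketch of the connectivity test as rewritten above, assuming op1 and op2 are fluid Operator objects:

def is_op_connected(op1, op2):
    # Two ops are considered connected when one op's output feeds the other's input.
    op1_inputs = set(op1.desc.input_arg_names())
    op1_outputs = set(op1.desc.output_arg_names())
    op2_inputs = set(op2.desc.input_arg_names())
    op2_outputs = set(op2.desc.output_arg_names())
    return bool((op1_outputs & op2_inputs) or (op1_inputs & op2_outputs))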
@@ -400,9 +400,6 @@ class Operator(object):
         """
         self.block = block
         self.desc = desc
-        # for clone a new operator
-        self.inputs = inputs
-        self.outputs = outputs
         self.attrs = attrs
         if len(self.desc.type()) != 0:
             return
...
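Note: since Operator no longer caches self.inputs/self.outputs, callers such as the transpiler query ops through accessor methods instead of dict lookups. A hedged sketch of the access patterns this PR relies on; describe_op is an illustrative helper, not part of the Fluid API, and op is assumed to be a fluid Operator:

def describe_op(op):
    # Slot names declared by the op, e.g. ["Param", "Grad", "LearningRate"].
    slots = op.input_names
    # Variable names bound to one slot, e.g. op.input("Param") -> ["fc_0.w_0"].
    params = op.input("Param") if "Param" in slots else []
    # Flat argument-name lists, as used by the connectivity check.
    return slots, params, op.desc.input_arg_names(), op.desc.output_arg_names()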