Commit e3334f3e authored by: M mapingshuo

add zero

Parent 43240a1b
...@@ -24,6 +24,14 @@ enum Mode {
message RecomputeConfig { repeated string checkpoints = 1; }
message ZeROConfig {
optional bool amp = 1 [ default = true ];
optional int32 nrings = 2 [ default = 3 ];
optional float fuse_broadcast_MB_bytes = 3 [ default = 64.0 ];
repeated string checkpoints = 4;
optional bool allreduce = 5 [ default = false ];
}
message AMPConfig {
optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 incr_every_n_steps = 2 [ default = 1000 ];
...@@ -127,6 +135,7 @@ message DistributedStrategy {
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool zero = 25 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
...@@ -138,6 +147,7 @@ message DistributedStrategy {
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ZeROConfig zero_configs = 111;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
...
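The ZeROConfig fields above surface on the Python side as the zero_configs dictionary (see the DistributedStrategy changes below). As a rough sketch, assuming the dict keys simply mirror the proto field names, the defaults correspond to:

# Sketch only: default zero_configs, keys assumed to mirror the ZeROConfig fields.
default_zero_configs = {
    "amp": True,                      # wrap the inner optimizer with mixed precision
    "nrings": 3,                      # number of NCCL communication rings
    "fuse_broadcast_MB_bytes": 64.0,  # parameter-memory threshold (MB) for fusing broadcasts
    "checkpoints": [],                # recompute checkpoints (empty means recompute is off)
    "allreduce": False,               # fall back to plain allreduce instead of sharding
}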
...@@ -55,8 +55,10 @@ class CSyncCommStreamOp : public framework::OperatorBase {
class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddOutput("Out", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0); AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
CSyncCommStream Operator CSyncCommStream Operator
......
...@@ -611,6 +611,39 @@ class DistributedStrategy(object):
"checkpoint_configs")
assign_configs_value(self.strategy.recompute_configs, configs)
@property
def zero(self):
"""
Indicate whether to use the Zero Redundancy Optimizer (ZeRO) for memory
optimization.
Default value: False
Examples:
.. code-block:: python
import paddle.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.zero = True
"""
return self.strategy.zero
@zero.setter
def zero(self, flag):
if isinstance(flag, bool):
self.strategy.zero = flag
else:
print("WARNING: zero should have value of bool type")
@property
def zero_configs(self):
"""
Set zero configurations.
"""
return get_msg_dict(self.strategy.zero_configs)
@zero_configs.setter
def zero_configs(self, configs):
check_configs_key(self.strategy.zero_configs, configs, "zero_configs")
assign_configs_value(self.strategy.zero_configs, configs)
@property
def pipeline(self):
"""
...
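For context, a minimal usage sketch of the new flags, assuming the paddle.fleet entry points shown in the docstrings above (optimizer construction and fleet.init are omitted):

import paddle.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.zero = True
strategy.zero_configs = {
    "amp": True,
    "nrings": 3,
    "fuse_broadcast_MB_bytes": 64.0,
}
# The strategy is then passed to fleet's distributed optimizer as usual.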
...@@ -1086,6 +1086,9 @@ class Fleet(object):
context["program_optimize_ops"] = optimize_ops
context["program_params_grads"] = params_grads
if self.user_defined_strategy.zero:
graph_optimizer = None
if graph_optimizer:
optimize_ops, params_grads = graph_optimizer.minimize(
loss,
...
...@@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer
from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .zero_optimizer import ZeroOptimizer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper
from .common import is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import unique_name, core
from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
import paddle.fluid as fluid
import math
import re
__all__ = ["ZeroOptimizer"]
def _pretty_op_desc_(op_desc, prefix):
out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \
(prefix + "_op", str(op_desc.type()), prefix + "_input", " ".join(op_desc.input_arg_names()),
prefix + "_output", " ".join(op_desc.output_arg_names()))
return out_s
class SubProgram(object):
def __init__(self, block):
self._block = block
self._allreduce_vars = []
# sub program start idx
self._start_idx = -1
# sub program end idx
self._end_idx = -1
# param name to broadcast name
self._param2broadcast = {}
self._broadcast_vars = []
# cast op pairs, fp16 name (str) -> fp32 name (str)
self._cast_ops = {}
# fill constant vars
self._fill_constant_vars = []
# parameter mems
self._param_mem = 0.0
class ProgramDeps(object):
def __init__(self, block, start_vars, end_vars):
self._block = block
# vars from which to start building the deps
self._start_vars = start_vars
# vars at which to stop building the deps
self._end_vars = end_vars
# var name -> op idxs which depend on this var
self._var_deps = {}
# sub block deps, a subset of this topology
self._sub_block_deps = {}
# var name -> op idxs which generate this var
self._var_to_generate_op = {}
self._should_removed_var = set()
self._father_block_deps = None
self._build_deps()
def get_sub_block_deps(self, idx):
if idx in self._sub_block_deps:
return self._sub_block_deps[idx]
else:
return None
def get_var_deps(self, var_name):
if var_name in self._var_deps:
return self._var_deps[var_name]
else:
return None
def _build_deps(self, ):
for var_name in self._start_vars:
self._var_deps[var_name] = [-1]
self._var_to_generate_op[var_name] = [-1]
for idx, op in enumerate(self._block.ops):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream"
]:
continue
input_vars = op.desc.input_arg_names()
output_vars = op.desc.output_arg_names()
deps_reduce = False
for input_name in input_vars:
if input_name in self._var_deps:
deps_reduce = True
if deps_reduce:
for input_name in input_vars:
if input_name in self._var_deps:
self._var_deps[input_name].append(idx)
for output_name in output_vars:
self._var_deps[output_name] = []
if output_name not in self._var_to_generate_op:
self._var_to_generate_op[output_name] = [idx]
else:
self._var_to_generate_op[output_name].append(idx)
if op.type == "conditional_block":
# subblock
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = ProgramDeps(
self._block.program.block(subblock_idx),
op.desc.input_arg_names(), op.desc.output_arg_names())
self._sub_block_deps[subblock_idx] = subblock_deps
subblock_deps._father_block_deps = self
def crop_input_var_from_op(self, op_idx, var_name):
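# Detach var_name from the inputs of the op at op_idx and decide whether
# the var itself has become removable (no remaining consumers, or only a
# cyclic dependency with its generating op).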
if var_name in self._var_deps:
# update var -> dep_var_op
if self._var_deps[var_name] != []:
assert (op_idx in self._var_deps[var_name])
self._var_deps[var_name].remove(op_idx)
# update _should_removed_var
if var_name in self._start_vars:
self._should_removed_var.discard(var_name)
elif self._var_deps[var_name] == []: # no more deps of this var
self._should_removed_var.add(var_name)
elif self._var_to_generate_op[var_name][-1] >= self._var_deps[
var_name][-1]:
# there is a cycle in the graph
self._should_removed_var.add(var_name)
else: # input_name should not be deleted
self._should_removed_var.discard(var_name)
def crop_output_var_from_op(self, op_idx, var_name):
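# Detach var_name from the outputs of the op at op_idx; once no op
# generates the var any more, drop it from the block.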
if var_name in self._var_to_generate_op:
assert (op_idx in self._var_to_generate_op[var_name])
self._var_to_generate_op[var_name].remove(op_idx)
if self._block.has_var(var_name) and self._var_to_generate_op[
var_name] == []:
print("main_block remove var {}".format(var_name))
self._block._remove_var(var_name)
def remove_op(self, op_idx):
# update deps
op = self._block.ops[op_idx]
print("main_block remove op {}".format(op.type))
for input_name in op.desc.input_arg_names():
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx)
def should_remove_op(self, op_idx):
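# An op can be removed only if every output it produces has been marked
# as removable.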
op = self._block.ops[op_idx]
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
return True
class ZeroOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ZeroOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self._main_program = None
self._startup_program = None
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
# params and fp16 params are for broadcast
self._params = set([])
self._fp16_params = set([])
# fp16 to fp32
self._fp16_to_params = {}
self._broadcast_vars = set([])
# _param(str) -> device_id(int)
self._param2device = {}
# varname(str) -> param(Variable)
# reduced grads to param name
self._reduced_grads_to_param = {}
# self._nrings(int) is for nccl communication
self._nrings = 3
# self._sub_progs
self._sub_progs = []
self._fuse_broadcast_MB_bytes = 64
self._dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
def _get_var_size(self, param):
"""
input:
- param: var
return:
var size in MB
"""
assert -1 not in param.shape
return reduce(
lambda x, y: x * y,
param.shape) * self._dtype_to_size[param.dtype] / 1024.0 / 1024.0
def _can_apply(self):
return self.user_defined_strategy.zero
def _disable_strategy(self, dist_strategy):
dist_strategy.zero = False
def _is_fp16_cast_op(self, block, op):
if op.type != "cast":
return False
if is_optimizer_op(op):
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_name, output_name = op.desc.input_arg_names()[
0], op.desc.output_arg_names()[0]
if input_name not in self._params:
return False
input_var = block.var(input_name)
output_var = block.var(output_name)
if input_var.dtype != core.VarDesc.VarType.FP32 or \
output_var.dtype != core.VarDesc.VarType.FP16:
return False
return True
def _split_params(self, params):
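# Greedily assign parameters to workers in order: move on to the next
# worker once the accumulated size passes an even share of the total
# parameter memory, so each worker owns roughly 1/worker_num of it.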
param2device = {}
total_param_mem = 0.0
param2mem = []
for param in params:
mem = self._get_var_size(param)
total_param_mem += mem
param2mem.append((param.name, mem))
# print(param.name, mem)
# print("total_param_mem: ", total_param_mem)
device_num = self.role_maker.worker_num()
# print("device_num: ", device_num)
device2params = {x: [] for x in range(device_num)}
device_idx = 0
mem_accu = 0.0
for param_name, mem in param2mem:
if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / device_num:
device_idx += 1
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
# for debug
print(device2params)
return param2device
def _is_opti_var(self, var_name):
if var_name in self._params:
return True
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self._params:
return True
return False
def _var_device_id(self, var_name):
if not self._is_opti_var(var_name):
return -1
if var_name in self._param2device:
return self._param2device[var_name]
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self._param2device:
return self._param2device[base_name]
return -1
def _insert_scale_loss_grad_ops(self, block, scale=1.0):
'''
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
'''
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
def _split_program(self, block):
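# Walk the non-optimize ops in reverse and cut them into SubPrograms: a new
# segment starts whenever the accumulated broadcast-parameter memory exceeds
# self._fuse_broadcast_MB_bytes. For each segment, record the vars to
# broadcast/allreduce and the fp16 cast ops it owns.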
for op_idx, op in reversed(list(enumerate(block.ops))):
if int(op.attr('op_role')) != int(OpRole.Optimize):
last_backward_op_idx = op_idx + 1
break
sub_prog = SubProgram(block)
sub_prog._end_idx = last_backward_op_idx
for op_idx in reversed(range(last_backward_op_idx)):
op = block.ops[op_idx]
assert (int(op.attr('op_role')) != int(OpRole.Optimize))
if sub_prog._param_mem >= self._fuse_broadcast_MB_bytes:
sub_prog._start_idx = op_idx + 1
self._sub_progs.insert(0, sub_prog)
sub_prog = SubProgram(block)
sub_prog._end_idx = op_idx + 1
# find broadcast vars
for input_name in op.desc.input_arg_names():
if input_name not in self._broadcast_vars:
continue
root_device = self._param2device[input_name]
if input_name in sub_prog._param2broadcast:
# skip broadcast because it reuses the old broadcast var
broadcast_name = sub_prog._param2broadcast[input_name]
if input_name != broadcast_name:
op._rename_input(input_name, broadcast_name)
continue
if root_device == self.role_maker.worker_index():
broadcast_var_name = input_name
else:
broadcast_var_name = unique_name.generate(input_name +
"@BroadCast")
sub_prog._fill_constant_vars.append(broadcast_var_name)
sub_prog._param2broadcast[input_name] = broadcast_var_name
sub_prog._broadcast_vars.append(
(broadcast_var_name, self._param2device[input_name]))
sub_prog._param_mem += self._get_var_size(
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
sub_prog._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if self._is_fp16_cast_op(block, op):
fp32_param = op.desc.input_arg_names()[0]
fp16_param = op.desc.output_arg_names()[0]
if self._param2device[
fp32_param] == self.role_maker.worker_index():
sub_prog._cast_ops[fp16_param] = fp32_param
if sub_prog._param_mem > 0:
sub_prog._start_idx = 0
self._sub_progs.insert(0, sub_prog)
return
def _is_gradient_clip_sum_op(self, op):
return op.type == "sum" and op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip_@CLIP")
def _is_amp_sum_op(self, op):
return op.type == "sum" and op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/mixed_precision")
def _is_amp_subblock(self, op):
return op.type == "conditional_block" and op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/mixed_precision")
def _prune_main_program(self, block):
"""
calculate deps from allreduce op to optimize op,
remove ops and vars not needed in this worker
"""
# build prog deps
reduced_grads = []
var_to_reduce_var = {}
for idx, op in enumerate(block.ops):
input_names = op.desc.input_arg_names()
output_names = op.desc.output_arg_names()
if op.type == "c_allreduce_sum":
assert (len(output_names) == 1)
output_name = output_names[0]
reduced_grads.append(output_name)
var_to_reduce_var[output_name] = output_name
else:
non_persistable_input = [
x for x in input_names if not block.var(x).persistable
]
if len(non_persistable_input) == 1 and len(
output_names) == 1 and non_persistable_input[
0] in var_to_reduce_var:
var_to_reduce_var[output_names[0]] = var_to_reduce_var[
non_persistable_input[0]]
params = []
for var_name, _ in block.vars.items():
if self._is_opti_var(var_name) and \
self._var_device_id(var_name) != self.role_maker.worker_index():
params.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, params)
# Init
for var_name in program_deps._end_vars:
program_deps._should_removed_var.add(var_name)
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
]:
pass
elif self._is_gradient_clip_sum_op(op) or self._is_amp_sum_op(op):
reversed_input_vars = []
for input_name in op.desc.input_arg_names():
assert (input_name in var_to_reduce_var)
reduce_var = var_to_reduce_var[input_name]
param_name = self._reduced_grads_to_param[reduce_var]
if self._param2device[
param_name] != self.role_maker.worker_index():
program_deps.crop_input_var_from_op(idx, input_name)
else:
reversed_input_vars.append(input_name)
op.desc.set_input("X", reversed_input_vars)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
elif op.type == "conditional_block":
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
# only prune amp subblock
if subblock_deps is None or not self._is_amp_subblock(op):
continue
# init
reversed_output_vars = []
for output_name in op.desc.output("Out"):
if output_name in program_deps._should_removed_var:
subblock_deps._should_removed_var.add(output_name)
program_deps.crop_output_var_from_op(idx, output_name)
else:
reversed_output_vars.append(output_name)
# prune
for sub_op_idx, _ in reversed(
list(enumerate(subblock_deps._block.ops))):
if subblock_deps.should_remove_op(sub_op_idx):
subblock_deps.remove_op(sub_op_idx)
reversed_input_vars = []
for input_name in op.desc.input('Input'):
if input_name not in subblock_deps._should_removed_var:
reversed_input_vars.append(input_name)
else:
program_deps.crop_input_var_from_op(idx, input_name)
op.desc.set_input('Input', reversed_input_vars)
op.desc.set_output('Out', reversed_output_vars)
else:
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
block._sync_with_cpp()
return
def _remove_cast_op(self, block, sub_prog, offset):
inserted_op_num = 0
for op_idx in reversed(
range(offset + sub_prog._start_idx, offset +
sub_prog._end_idx)):
op = block.ops[op_idx]
if self._is_fp16_cast_op(block, op):
block._remove_op(op_idx)
inserted_op_num -= 1
block._sync_with_cpp()
return inserted_op_num
def _insert_broadcast_ops(self, block, insert_idx, broadcast2root):
"""
_insert_broadcast_ops
"""
ring_id = -1
# TODO(mapingshuo): correct OP_ROLE_KEY
for broadcast_name, root_device in broadcast2root:
ring_id = (ring_id + 1) % self._nrings
block._insert_op(
insert_idx,
type='c_broadcast',
inputs={'X': broadcast_name},
outputs={'Out': broadcast_name},
attrs={
'ring_id': ring_id,
'root': root_device,
OP_ROLE_KEY: OpRole.Forward
})
return
def _insert_allreduce_ops(self, block, insert_idx, allreduce_vars):
"""
_insert_allreduce_ops
"""
ring_id = -1
for var in allreduce_vars:
ring_id = (ring_id + 1) % self._nrings
block._insert_op(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
return
def _insert_cast_ops(self, block, insert_idx, cast_ops):
"""
_insert_cast_ops
"""
for fp16_name, fp32_name in cast_ops.items():
block._insert_op(
insert_idx,
type="cast",
inputs={"X": fp32_name},
outputs={"Out": fp16_name},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16
})
return
def _insert_fill_constant_ops(self, block, insert_idx, fill_constant_vars):
"""
_insert_fill_constant_ops
"""
for broadcast_name in fill_constant_vars:
broadcast_var = block.var(broadcast_name)
block._insert_op(
insert_idx,
type="fill_constant",
outputs={"Out": broadcast_var.name},
attrs={
"shape": broadcast_var.shape,
"dtype": broadcast_var.dtype,
"value": 0.0,
})
return
def _insert_sync_comm_ops(self, block, insert_idx, comm_dep_vars):
"""
_insert_sync_comm_ops
"""
# TODO(mapingshuo) fix OP_ROLE_KEY
for i in range(self._nrings):
block._insert_op(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: OpRole.Forward})
return
def _insert_sync_calc_op(self, block, insert_idx, calc_dep_vars):
"""
_insert_sync_calc_op
"""
# TODO(mapingshuo) fix OP_ROLE_KEY
block._insert_op(
insert_idx,
type='c_sync_calc_stream',
inputs={'X': calc_dep_vars},
outputs={'Out': calc_dep_vars},
attrs={OP_ROLE_KEY: OpRole.Forward})
return
def _add_broadcast_allreduce_v2(self, block):
"""
_add_broadcast_allreduce_v2
"""
ring_id = -1
if len(self._sub_progs) < 1:
return
if self._sub_progs[-1]._allreduce_vars:
self._insert_sync_comm_ops(block, self._sub_progs[-1]._end_idx,
self._sub_progs[-1]._allreduce_vars)
self._insert_allreduce_ops(block, self._sub_progs[-1]._end_idx,
self._sub_progs[-1]._allreduce_vars)
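# For each segment, insert the allreduce of the previous segment's gradients
# and the broadcast/fill_constant/cast needed by the following segments at
# its boundaries, so communication can overlap with the computation of the
# current segment.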
for idx, subprog in reversed(list(enumerate(self._sub_progs))):
print("subprog_{}: ({}-{})".format(idx, subprog._start_idx,
subprog._end_idx))
allreduce_vars = self._sub_progs[
idx - 1]._allreduce_vars if idx > 0 else []
broadcast_vars = self._sub_progs[idx +
1]._broadcast_vars if idx < len(
self._sub_progs) - 1 else []
fill_constant_vars = self._sub_progs[
idx + 2]._fill_constant_vars if idx < len(
self._sub_progs) - 2 else []
cast_ops = self._sub_progs[idx + 2]._cast_ops if idx < len(
self._sub_progs) - 2 else {}
# for x in fill_constant_vars:
# print("fill_constant_vars: ", x)
# step1: modify calculate ops
# for op_idx in reversed(range(subprog._start_idx, subprog._end_idx)):
# op = block.ops[op_idx]
# print(_pretty_op_desc_(op.desc, "subprog_op"))
for op_idx in reversed(range(subprog._start_idx, subprog._end_idx)):
op = block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if input_name in subprog._param2broadcast and \
input_name != subprog._param2broadcast[input_name]:
op._rename_input(input_name,
subprog._param2broadcast[input_name])
for param_name, broadcast_name in subprog._param2broadcast.items():
if param_name != broadcast_name:
block.create_var(
name=broadcast_name,
shape=self._main_program.global_block().var(
param_name).shape,
dtype=self._main_program.global_block().var(param_name)
.dtype,
persistable=False)
# step2: remove cast ops
block._sync_with_cpp()
subprog._end_idx += self._remove_cast_op(block, subprog, 0)
# step3: add Sync ops
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
self._insert_sync_comm_ops(
block,
subprog._end_idx,
comm_dep_vars, )
calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items()
]
if len(calc_dep_vars) > 0:
self._insert_sync_calc_op(block, subprog._end_idx,
[calc_dep_vars[-1]])
# step4: insert `fill_constant` ops
self._insert_fill_constant_ops(block, subprog._end_idx,
fill_constant_vars)
# step5: add `cast` ops
self._insert_cast_ops(block, subprog._end_idx, cast_ops)
# step6: add broadcast ops
self._insert_broadcast_ops(block, subprog._start_idx,
broadcast_vars)
# step7: add all_reduce ops
self._insert_allreduce_ops(block, subprog._start_idx,
allreduce_vars)
block._sync_with_cpp()
if self._sub_progs[0]._broadcast_vars:
self._insert_sync_comm_ops(
block, self._sub_progs[0]._start_idx,
[x[0] for x in self._sub_progs[0]._broadcast_vars])
self._insert_broadcast_ops(block, self._sub_progs[0]._start_idx,
self._sub_progs[0]._broadcast_vars)
fill_constant_vars = reduce(
lambda x, y: x._fill_constant_vars + y._fill_constant_vars,
self._sub_progs[:2])
# Join
cast_ops = {}
for x in self._sub_progs[:2]:
for k, v in x._cast_ops.items():
cast_ops[k] = v
calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
if fill_constant_vars or cast_ops:
self._insert_sync_calc_op(block, self._sub_progs[0]._start_idx,
[calc_deps_vars[-1]])
if fill_constant_vars:
self._insert_fill_constant_ops(block, self._sub_progs[0]._start_idx,
fill_constant_vars)
if cast_ops:
self._insert_cast_ops(block, self._sub_progs[0]._start_idx,
cast_ops)
return
def _prune_startup_program(self, block):
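# Drop initialization ops and vars for parameters (and their optimizer
# states) that are owned by other workers.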
for idx, op in reversed(list(enumerate(block.ops))):
for output_name in op.desc.output_arg_names():
var_device_id = self._var_device_id(output_name)
if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
):
continue
print("%d: startup_block remove op %s" %
(self.role_maker.worker_index(), op.type))
block._remove_op(idx)
break
for var_name, _ in block.vars.items():
var_device_id = self._var_device_id(var_name)
if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
):
continue
print("%d: startup_block remove var %s" %
(self.role_maker.worker_index(), var_name))
block._remove_var(var_name)
block._sync_with_cpp()
def _find_broadcast_params(self, params, param2device):
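# Collect the vars that need broadcasting: fp16 copies produced by cast ops
# (mapped back to their fp32 params), plus any fp32 param that is still
# consumed directly by some non-optimizer op.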
broadcast_vars = set([])
fp16_params = set([])
fp16_to_fp32 = {}
main_block = self._main_program.global_block()
param_usage = {x: 0 for x in params}
for op in main_block.ops:
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
if input_name in params:
param_usage[input_name] += 1
for op in main_block.ops:
if not self._is_fp16_cast_op(main_block, op):
continue
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
broadcast_vars.add(output_name)
fp16_params.add(output_name)
fp16_to_fp32[output_name] = input_name
param_usage[input_name] -= 1
param2device[output_name] = param2device[input_name]
for param, usage in param_usage.items():
if usage > 0:
broadcast_vars.add(param)
return fp16_params, broadcast_vars, fp16_to_fp32
def _set_up(self, params_grads):
# step 1: initialize nccl
# TODO(mapingshuo) fix get_trainer_endpoints
print("work idx: ", self.role_maker.worker_index())
endpoints = self.role_maker.get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker.worker_index()]
collective_helper = CollectiveHelper(self.role_maker, self._nrings)
for ring_id in range(self._nrings):
collective_helper._init_communicator(
self._startup_program, current_endpoint, endpoints,
self.role_maker.worker_index(), ring_id, '6174')
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
# step 2: split params
self._params = set([x[0].name for x in params_grads])
self._param2device = self._split_params([x[0] for x in params_grads])
# step 3: get broadcast vars
self._fp16_params, self._broadcast_vars, self._fp16_to_params = self._find_broadcast_params(
self._params, self._param2device)
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
if self.user_defined_strategy.zero_configs["allreduce"]:
return self.minimize_impl_allreduce(loss, startup_program,
parameter_list, no_grad_set)
ckpts = list(self.user_defined_strategy.zero_configs["checkpoints"])
optimizer = self.inner_opt
if len(ckpts) > 0:
print("add recompute")
print(ckpts)
optimizer = fluid.optimizer.RecomputeOptimizer(optimizer)
optimizer._set_checkpoints(ckpts)
if self.user_defined_strategy.zero_configs["amp"]:
optimizer = fluid.contrib.mixed_precision.decorate(
optimizer, use_dynamic_loss_scaling=True)
self._nrings = self.user_defined_strategy.zero_configs["nrings"]
self._fuse_broadcast_MB_bytes = self.user_defined_strategy.zero_configs[
"fuse_broadcast_MB_bytes"]
print("doing zero optimize...")
optimize_ops, params_grads = optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
# step1: set_up
self._set_up(params_grads)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
print("insert broadcast and allreduce")
self._add_broadcast_allreduce_v2(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
# step4: insert reduce_sum for grad
self._insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker.worker_num())
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
print("main_block remove ops and vars")
self._prune_main_program(main_block)
print("startup_block remove ops and vars")
self._prune_startup_program(startup_block)
# check op dependency for broadcast
self._check_broadcast(main_block)
return optimize_ops, params_grads
def _check_broadcast(self, block):
"""
if a var is broadcast, there must be a sync_comm op after the broadcast
and before the var is used; if not, raise an error.
if the broadcast var is filled by a fill_constant op, the fill_constant
op must come before a sync_calc op, which in turn must come before the
broadcast op; otherwise, raise an error.
"""
broadcast_vars = {}
for idx, op in enumerate(block.ops):
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
print("error: var_name areadly exist: ", var_name)
print("the old pos is ",
broadcast_vars[var_name]["broadcast_pos"])
print("the new pos is ", idx)
assert (var_name not in broadcast_vars)
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
for idx, op in enumerate(block.ops):
if op.type == "fill_constant":
var_name = op.desc.output_arg_names()[0]
if var_name in broadcast_vars:
broadcast_vars[var_name]["fill_constant_pos"] = idx
continue
last_sync_comm_op_idx = -1
last_sync_calc_op_idx = -1
for idx, op in enumerate(block.ops):
if op.type == "c_sync_comm_stream":
last_sync_comm_op_idx = idx
continue
if op.type == "c_sync_calc_stream":
last_sync_calc_op_idx = idx
continue
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
assert (broadcast_vars[input_name]["broadcast_pos"] <
last_sync_comm_op_idx)
assert (last_sync_comm_op_idx < idx)
print("check done")
return
def _add_broadcast_allreduce(self, block, sub_prog, offset):
"""
add broadcast and allreduce
"""
# insert reduce ops
inserted_op_num = 0
ring_id = -1
if len(sub_prog._allreduce_vars) > 0:
for i in range(self._nrings):
block._insert_op(
offset + sub_prog._end_idx,
type='c_sync_comm_stream',
inputs={'X': sub_prog._allreduce_vars},
outputs={'Out': sub_prog._allreduce_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: OpRole.Forward})
inserted_op_num += self._nrings
for var in sub_prog._allreduce_vars:
ring_id = (ring_id + 1) % self._nrings
block._insert_op(
offset + sub_prog._end_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
inserted_op_num += 1
block._insert_op(
offset + sub_prog._end_idx,
type='c_sync_calc_stream',
inputs={'X': sub_prog._allreduce_vars[-1]},
outputs={'Out': sub_prog._allreduce_vars[-1]},
attrs={OP_ROLE_KEY: OpRole.Forward})
inserted_op_num += 1
block._sync_with_cpp()
# insert broadcast ops
for op_idx in reversed(
range(offset + sub_prog._start_idx, offset +
sub_prog._end_idx)):
op = block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if input_name in sub_prog._param2broadcast and \
input_name != sub_prog._param2broadcast[input_name]:
op._rename_input(input_name,
sub_prog._param2broadcast[input_name])
for param_name, broadcast_name in sub_prog._param2broadcast.items():
if param_name != broadcast_name:
block.create_var(
name=broadcast_name,
shape=self._main_program.global_block().var(
param_name).shape,
dtype=self._main_program.global_block().var(param_name)
.dtype,
persistable=False)
comm_dep_vars = [v for k, v in sub_prog._param2broadcast.items()]
for i in range(self._nrings):
block._insert_op(
offset + sub_prog._start_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: OpRole.Forward})
inserted_op_num += self._nrings
for param_name, broadcast_name in sub_prog._param2broadcast.items():
broadcast_var = block.var(broadcast_name)
root_device = self._param2device[param_name]
ring_id = (ring_id + 1) % self._nrings
block._insert_op(
offset + sub_prog._start_idx,
type='c_broadcast',
inputs={'X': broadcast_var.name},
outputs={'Out': broadcast_var.name},
attrs={
'ring_id': ring_id,
'root': root_device,
OP_ROLE_KEY: OpRole.Forward
})
inserted_op_num += 1
comm_dep_vars = [
v for k, v in sub_prog._param2broadcast.items() if k != v
]
if comm_dep_vars != []:
block._insert_op(
offset + sub_prog._start_idx,
type='c_sync_calc_stream',
inputs={'X': comm_dep_vars[-1]},
outputs={'Out': comm_dep_vars[-1]},
attrs={OP_ROLE_KEY: OpRole.Forward})
inserted_op_num += 1
for param_name, broadcast_name in sub_prog._param2broadcast.items():
if param_name != broadcast_name:
broadcast_var = block.var(broadcast_name)
block._insert_op(
offset + sub_prog._start_idx,
type="fill_constant",
outputs={"Out": broadcast_var.name},
attrs={
"shape": broadcast_var.shape,
"dtype": broadcast_var.dtype,
"value": 0.0,
})
inserted_op_num += 1
for fp16_name, fp32_name in sub_prog._cast_ops.items():
block._insert_op(
offset + sub_prog._start_idx,
type="cast",
inputs={"X": fp32_name},
outputs={"Out": fp16_name},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16
})
inserted_op_num += 1
block._sync_with_cpp()
return inserted_op_num
def _broadcast_params(self, block):
ring_id = -1
for param in block.iter_parameters():
if param.is_distributed:
continue
ring_id = (ring_id + 1) % self._nrings
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
for ring_id in range(self._nrings):
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
# def _insert_broadcast_ops(self, block, fuse_broadcast=False):
# def _insert_cache(cache,
# prepend_comm_sync=False,
# append_comm_sync=False):
# insert_idx = cache["insert_idx"]
# dummy_var_name = cache["dummy_var_name"]
# assert (len(cache["broadcast_ops"]) > 0)
# if prepend_comm_sync:
# insert_idx += self._insert_comm_sync(block, insert_idx,
# [dummy_var_name])
# if len(cache["fill_constant_ops"]) > 0:
# insert_idx += self._insert_fill_constant(
# block, insert_idx, cache["fill_constant_ops"],
# [dummy_var_name])
# insert_idx += self._insert_broadcast_inner(block, insert_idx,
# cache["broadcast_ops"])
# if append_comm_sync:
# insert_idx += self._insert_comm_sync(block, insert_idx,
# [dummy_var_name])
# return insert_idx - cache["insert_idx"]
# print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs])
# move_ahead = 1
# for idx, cache in reversed(list(enumerate(self._sub_progs))):
# if idx < move_ahead:
# cache["insert_idx"] = 0
# else:
# cache["insert_idx"] = self._sub_progs[idx - move_ahead][
# "insert_idx"]
# print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs])
# inserted_op_num = 0
# for idx, cache in enumerate(self._sub_progs):
# prepend_comm_sync = True
# append_comm_sync = True
# cache["insert_idx"] += inserted_op_num
# inserted_op_num += _insert_cache(
# cache,
# prepend_comm_sync=prepend_comm_sync,
# append_comm_sync=append_comm_sync)
# return
def _insert_allreduce_ops_tmp(self, block):
ring_id = -1
grad = None
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
# param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
# TODO(mapingshuo): what is is_distributed
# if param.is_distributed:
# continue
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
offset += 1
# As we search ops in reverse, we insert the c_allreduce_sum
# ops in the same way to keep the ring_id alternating
print("add allreduce op for {}".format(grad.name))
ring_id = (ring_id + 1) % self._nrings
block._insert_op(
offset,
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
if grad is None:
return
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
for ring_id in range(self._nrings):
block._insert_op(
idx + ring_id,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
break
def minimize_impl_allreduce(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
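# Fallback path used when zero_configs["allreduce"] is True: no parameter
# sharding, just broadcast parameters at startup and allreduce gradients,
# i.e. plain data parallelism (optionally with AMP).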
self._nrings = self.user_defined_strategy.zero_configs["nrings"]
optimizer = self.inner_opt
if self.user_defined_strategy.zero_configs["amp"]:
optimizer = fluid.contrib.mixed_precision.decorate(
optimizer, use_dynamic_loss_scaling=True)
optimize_ops, params_grads = optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
print("work idx: ", self.role_maker.worker_index())
endpoints = self.role_maker.get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker.worker_index()]
collective_helper = CollectiveHelper(self.role_maker, self._nrings)
for ring_id in range(self._nrings):
collective_helper._init_communicator(
startup_program, current_endpoint, endpoints,
self.role_maker.worker_index(), ring_id, '6174')
main_block = loss.block
startup_block = startup_program.global_block()
self._broadcast_params(startup_block)
self._insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker.worker_num())
self._insert_allreduce_ops_tmp(main_block)
print("insert allreduce done")
return optimize_ops, params_grads
# def _insert_comm_sync(self, block, insert_idx, var_names):
# for r in range(self._nrings):
# block._insert_op(
# insert_idx,
# type='c_sync_comm_stream',
# inputs={'X': var_names},
# outputs={'Out': var_names},
# attrs={'ring_id': r,
# OP_ROLE_KEY: OpRole.Backward})
# insert_idx += 1
# return self._nrings
# def _insert_broadcast_inner(self, block, insert_idx, broadcast_attrs):
# for attr in broadcast_attrs:
# block._insert_op(insert_idx, **attr)
# insert_idx += 1
# return len(broadcast_attrs)
# def _insert_fill_constant(self, block, insert_idx, fill_constant_attrs,
# var_names):
# for attr in fill_constant_attrs:
# block._insert_op(insert_idx, **attr)
# insert_idx += 1
# block._insert_op(
# insert_idx,
# type='c_sync_calc_stream',
# inputs={'X': var_names},
# outputs={'Out': var_names},
# attrs={OP_ROLE_KEY: OpRole.Backward})
# return len(fill_constant_attrs) + 1
...@@ -847,7 +847,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('gradient_clip_@CLIP'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
...
...@@ -16,6 +16,7 @@ from ... import default_main_program
from ... import default_startup_program
from ... import layers
from ... import unique_name
from ... import framework
from . import fp16_utils
from .fp16_utils import rewrite_program
from .fp16_utils import update_role_var_grad
...@@ -132,6 +133,7 @@ class OptimizerWithMixedPrecision(object):
gradient respectively, and the scaled loss.
"""
rewrite_program(self._train_program, self._amp_lists)
with framework.name_scope('mixed_precision'):
self._scaled_loss = loss * self._loss_scaling
self._params_grads = self._optimizer.backward(
self._scaled_loss, startup_program, parameter_list, no_grad_set,
...@@ -156,11 +158,13 @@ class OptimizerWithMixedPrecision(object):
grads = [g for _, g in params_grads]
with self._train_program._optimized_guard(grads):
with framework.name_scope('mixed_precision'):
grads, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
if self._use_dynamic_loss_scaling:
with self._train_program._optimized_guard(grads):
with framework.name_scope('mixed_precision'):
grads = update_loss_scaling(
grads,
found_inf,
...
...@@ -2063,9 +2063,15 @@ class Operator(object):
% (out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
if isinstance(arg, six.string_types):
out_arg_names.append(arg)
else:
out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
if not in_dygraph_mode():
if isinstance(arg, six.string_types):
block.var(arg).op = self
else:
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
...@@ -2801,7 +2807,6 @@ class Block(object):
return var
def _remove_var(self, name):
self._sync_with_cpp()
self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name]
...@@ -2893,7 +2898,6 @@ class Block(object):
Returns:
Operator: the insert Operator.
"""
self._sync_with_cpp()
op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
...@@ -2909,7 +2913,6 @@ class Block(object):
Returns:
None
"""
self._sync_with_cpp()
self.desc._remove_op(index, index + 1)
del self.ops[index]
...