提交 e3334f3e 编写于 作者: M mapingshuo

add zero

上级 43240a1b
......@@ -24,6 +24,14 @@ enum Mode {
message RecomputeConfig { repeated string checkpoints = 1; }
message ZeROConfig {
optional bool amp = 1 [ default = true ];
optional int32 nrings = 2 [ default = 3 ];
optional float fuse_broadcast_MB_bytes = 3 [ default = 64.0 ];
repeated string checkpoints = 4;
optional bool allreduce = 5 [ default = false ];
}
message AMPConfig {
optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 incr_every_n_steps = 2 [ default = 1000 ];
......@@ -127,6 +135,7 @@ message DistributedStrategy {
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool zero = 25 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -138,6 +147,7 @@ message DistributedStrategy {
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ZeROConfig zero_configs = 111;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
......
......@@ -55,8 +55,10 @@ class CSyncCommStreamOp : public framework::OperatorBase {
class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of the variable need to sync");
AddOutput("Out", "(Tensor) Dependency of the variable need to sync");
AddInput("X", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddOutput("Out", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC(
CSyncCommStream Operator
......
......@@ -611,6 +611,39 @@ class DistributedStrategy(object):
"checkpoint_configs")
assign_configs_value(self.strategy.recompute_configs, configs)
@property
def zero(self):
"""
Indicating whether we are using Zero Redundancy Optimizer for memory
optimization
Default value: False
Examples:
.. code-block:: python
import paddle.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.zero = True
"""
return self.strategy.zero
@zero.setter
def zero(self, flag):
if isinstance(flag, bool):
self.strategy.zero = flag
else:
print("WARNING: zero should have value of bool type")
@property
def zero_configs(self):
"""
Set zero configurations.
"""
return get_msg_dict(self.strategy.zero_configs)
@zero_configs.setter
def zero_configs(self, configs):
check_configs_key(self.strategy.zero_configs, configs, "zero_configs")
assign_configs_value(self.strategy.zero_configs, configs)
@property
def pipeline(self):
"""
......
......@@ -1086,6 +1086,9 @@ class Fleet(object):
context["program_optimize_ops"] = optimize_ops
context["program_params_grads"] = params_grads
if self.user_defined_strategy.zero:
graph_optimizer = None
if graph_optimizer:
optimize_ops, params_grads = graph_optimizer.minimize(
loss,
......
......@@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer
from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .zero_optimizer import ZeroOptimizer
......@@ -847,7 +847,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('graident_clip_@CLIP'):
[p, g]), framework.name_scope('gradient_clip_@CLIP'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
......
......@@ -16,6 +16,7 @@ from ... import default_main_program
from ... import default_startup_program
from ... import layers
from ... import unique_name
from ... import framework
from . import fp16_utils
from .fp16_utils import rewrite_program
from .fp16_utils import update_role_var_grad
......@@ -132,7 +133,8 @@ class OptimizerWithMixedPrecision(object):
gradient respectively, and the scaled loss.
"""
rewrite_program(self._train_program, self._amp_lists)
self._scaled_loss = loss * self._loss_scaling
with framework.name_scope('mixed_precision'):
self._scaled_loss = loss * self._loss_scaling
self._params_grads = self._optimizer.backward(
self._scaled_loss, startup_program, parameter_list, no_grad_set,
callbacks)
......@@ -156,22 +158,24 @@ class OptimizerWithMixedPrecision(object):
grads = [g for _, g in params_grads]
with self._train_program._optimized_guard(grads):
grads, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
with framework.name_scope('mixed_precision'):
grads, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
if self._use_dynamic_loss_scaling:
with self._train_program._optimized_guard(grads):
grads = update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
with framework.name_scope('mixed_precision'):
grads = update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
params_unscaled_grads = []
for pg, new_g in zip(params_grads, grads):
......
......@@ -2063,10 +2063,16 @@ class Operator(object):
% (out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
out_arg_names.append(cpt.to_text(arg.name))
if isinstance(arg, six.string_types):
out_arg_names.append(arg)
else:
out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
if not in_dygraph_mode():
arg.op = self
if isinstance(arg, six.string_types):
block.var(arg).op = self
else:
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
if op_attrs is not None:
......@@ -2801,7 +2807,6 @@ class Block(object):
return var
def _remove_var(self, name):
self._sync_with_cpp()
self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name]
......@@ -2893,7 +2898,6 @@ class Block(object):
Returns:
Operator: the insert Operator.
"""
self._sync_with_cpp()
op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
......@@ -2909,7 +2913,6 @@ class Block(object):
Returns:
None
"""
self._sync_with_cpp()
self.desc._remove_op(index, index + 1)
del self.ops[index]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册