Commit e3334f3e authored by mapingshuo

add zero

Parent 43240a1b
@@ -24,6 +24,14 @@ enum Mode {
 message RecomputeConfig { repeated string checkpoints = 1; }
 
+message ZeROConfig {
+  optional bool amp = 1 [ default = true ];
+  optional int32 nrings = 2 [ default = 3 ];
+  optional float fuse_broadcast_MB_bytes = 3 [ default = 64.0 ];
+  repeated string checkpoints = 4;
+  optional bool allreduce = 5 [ default = false ];
+}
+
 message AMPConfig {
   optional float init_loss_scaling = 1 [ default = 32768.0 ];
   optional int32 incr_every_n_steps = 2 [ default = 1000 ];
@@ -127,6 +135,7 @@ message DistributedStrategy {
   optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
   optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
   optional bool adaptive_localsgd = 24 [ default = false ];
+  optional bool zero = 25 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
@@ -138,6 +147,7 @@ message DistributedStrategy {
   optional LarsConfig lars_configs = 108;
   optional LambConfig lamb_configs = 109;
   optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
+  optional ZeROConfig zero_configs = 111;
 
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
 }
......
@@ -55,8 +55,10 @@ class CSyncCommStreamOp : public framework::OperatorBase {
 class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddInput("X", "(Tensor) Dependency of the variable need to sync");
-    AddOutput("Out", "(Tensor) Dependency of the variable need to sync");
+    AddInput("X", "(Tensor) Dependency of the variable need to sync")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor) Dependency of the variable need to sync")
+        .AsDuplicable();
     AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
     AddComment(R"DOC(
 CSyncCommStream Operator
......
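Marking X and Out as duplicable lets one c_sync_comm_stream op reference a list of tensors instead of exactly one, which is convenient for a sharding/ZeRO pass that groups communication on a ring. A minimal sketch of the Python-side effect, assuming a Paddle build in which the collective ops are registered (GPU with NCCL); the gradient variable names are made up:

    import paddle.fluid as fluid

    prog = fluid.Program()
    block = prog.global_block()
    # Hypothetical gradient variables; any two tensors work for the sketch.
    g0 = block.create_var(name='grad_0', shape=[8], dtype='float32')
    g1 = block.create_var(name='grad_1', shape=[8], dtype='float32')

    # With AsDuplicable(), one sync op may cover several tensors at once.
    block.append_op(
        type='c_sync_comm_stream',
        inputs={'X': [g0, g1]},
        outputs={'Out': [g0, g1]},
        attrs={'ring_id': 0})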
@@ -611,6 +611,39 @@ class DistributedStrategy(object):
                           "checkpoint_configs")
         assign_configs_value(self.strategy.recompute_configs, configs)
 
+    @property
+    def zero(self):
+        """
+        Indicating whether we are using Zero Redundancy Optimizer for memory
+        optimization
+        Default value: False
+        Examples:
+          .. code-block:: python
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.zero = True
+        """
+        return self.strategy.zero
+
+    @zero.setter
+    def zero(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.zero = flag
+        else:
+            print("WARNING: zero should have value of bool type")
+
+    @property
+    def zero_configs(self):
+        """
+        Set zero configurations.
+        """
+        return get_msg_dict(self.strategy.zero_configs)
+
+    @zero_configs.setter
+    def zero_configs(self, configs):
+        check_configs_key(self.strategy.zero_configs, configs, "zero_configs")
+        assign_configs_value(self.strategy.zero_configs, configs)
+
     @property
     def pipeline(self):
         """
......
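For reference, a sketch of how the two new knobs would be used together. The keys of zero_configs mirror the ZeROConfig fields added in the proto above; the values and the checkpoint name here are only illustrative:

    import paddle.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.zero = True
    strategy.zero_configs = {
        "amp": True,                        # proto default: true
        "nrings": 2,                        # proto default: 3
        "fuse_broadcast_MB_bytes": 32.0,    # proto default: 64.0
        "checkpoints": ["fc_2.tmp_0"],      # hypothetical checkpoint name
        "allreduce": False,                 # proto default: false
    }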
@@ -1086,6 +1086,9 @@ class Fleet(object):
         context["program_optimize_ops"] = optimize_ops
         context["program_params_grads"] = params_grads
 
+        if self.user_defined_strategy.zero:
+            graph_optimizer = None
+
         if graph_optimizer:
             optimize_ops, params_grads = graph_optimizer.minimize(
                 loss,
......
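In other words, once strategy.zero is on, Fleet.minimize skips the graph optimizer so the ZeRO meta optimizer can take over. A sketch of the intended end-to-end flow, following the import path used in the docstring above and assuming the script is started with Paddle's distributed launcher so fleet.init(is_collective=True) can set up roles; the tiny network is only illustrative:

    import paddle.fleet as fleet
    import paddle.fluid as fluid

    fleet.init(is_collective=True)

    x = fluid.data(name='x', shape=[None, 32], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='int64')
    pred = fluid.layers.fc(input=x, size=2, act='softmax')
    loss = fluid.layers.mean(fluid.layers.cross_entropy(input=pred, label=y))

    strategy = fleet.DistributedStrategy()
    strategy.zero = True    # Fleet.minimize will then skip the graph optimizer

    opt = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    opt = fleet.distributed_optimizer(opt, strategy=strategy)
    opt.minimize(loss)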
@@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer
 from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
 from .dgc_optimizer import DGCOptimizer
 from .lamb_optimizer import LambOptimizer
+from .zero_optimizer import ZeroOptimizer
@@ -847,7 +847,7 @@ def append_gradient_clip_ops(param_grads):
         if g is None:
             continue
         with p.block.program._optimized_guard(
-                [p, g]), framework.name_scope('graident_clip_@CLIP'):
+                [p, g]), framework.name_scope('gradient_clip_@CLIP'):
             param, new_grad = clip_attr._create_operators(param=p, grad=g)
             param_new_grad_name_dict[param.name] = new_grad.name
             res.append([param, new_grad])
......
@@ -16,6 +16,7 @@ from ... import default_main_program
 from ... import default_startup_program
 from ... import layers
 from ... import unique_name
+from ... import framework
 from . import fp16_utils
 from .fp16_utils import rewrite_program
 from .fp16_utils import update_role_var_grad
@@ -132,7 +133,8 @@ class OptimizerWithMixedPrecision(object):
             gradient respectively, and the scaled loss.
         """
         rewrite_program(self._train_program, self._amp_lists)
-        self._scaled_loss = loss * self._loss_scaling
+        with framework.name_scope('mixed_precision'):
+            self._scaled_loss = loss * self._loss_scaling
         self._params_grads = self._optimizer.backward(
             self._scaled_loss, startup_program, parameter_list, no_grad_set,
             callbacks)
@@ -156,22 +158,24 @@ class OptimizerWithMixedPrecision(object):
         grads = [g for _, g in params_grads]
         with self._train_program._optimized_guard(grads):
-            grads, found_inf = check_finite_and_unscale(
-                grads, self._loss_scaling, name="find_infinite_scale")
+            with framework.name_scope('mixed_precision'):
+                grads, found_inf = check_finite_and_unscale(
+                    grads, self._loss_scaling, name="find_infinite_scale")
 
         if self._use_dynamic_loss_scaling:
             with self._train_program._optimized_guard(grads):
-                grads = update_loss_scaling(
-                    grads,
-                    found_inf,
-                    self._loss_scaling,
-                    self._num_good_steps,
-                    self._num_bad_steps,
-                    self._incr_every_n_steps,
-                    self._decr_every_n_nan_or_inf,
-                    self._incr_ratio,
-                    self._decr_ratio,
-                    name="update_loss_scaling")
+                with framework.name_scope('mixed_precision'):
+                    grads = update_loss_scaling(
+                        grads,
+                        found_inf,
+                        self._loss_scaling,
+                        self._num_good_steps,
+                        self._num_bad_steps,
+                        self._incr_every_n_steps,
+                        self._decr_every_n_nan_or_inf,
+                        self._incr_ratio,
+                        self._decr_ratio,
+                        name="update_loss_scaling")
 
         params_unscaled_grads = []
         for pg, new_g in zip(params_grads, grads):
......
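The name scope presumably tags the loss-scaling ops so that later passes (for example the new ZeroOptimizer) can pick them out by name. A rough sketch of how the scope is recorded on ops, assuming the op_namescope attribute behaves here as in this Paddle release; the scale op merely stands in for the real loss-scaling logic:

    import paddle.fluid as fluid
    from paddle.fluid import framework

    prog = fluid.Program()
    with fluid.program_guard(prog):
        x = fluid.data(name='x', shape=[None, 1], dtype='float32')
        with framework.name_scope('mixed_precision'):
            y = fluid.layers.scale(x, scale=1024.0)  # stand-in for loss scaling

    for op in prog.global_block().ops:
        # Expected to print something like: scale /mixed_precision/
        print(op.type, op.attr("op_namescope"))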
@@ -2063,10 +2063,16 @@ class Operator(object):
                             % (out_proto.name, len(out_args)))
                 out_arg_names = []
                 for arg in out_args:
-                    out_arg_names.append(cpt.to_text(arg.name))
+                    if isinstance(arg, six.string_types):
+                        out_arg_names.append(arg)
+                    else:
+                        out_arg_names.append(cpt.to_text(arg.name))
                     # TODO(minqiyang): could we remove variable's op in static mode?
                     if not in_dygraph_mode():
-                        arg.op = self
+                        if isinstance(arg, six.string_types):
+                            block.var(arg).op = self
+                        else:
+                            arg.op = self
                 self.desc.set_output(out_proto.name, out_arg_names)
 
         if op_attrs is not None:
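With this change an output argument may be a plain variable name, not only a Variable object, which is handy for passes that rewrite programs by name (the input side already accepted strings). A small sketch; the scale op and the variable names are only for illustration:

    import paddle.fluid as fluid

    prog = fluid.Program()
    block = prog.global_block()
    block.create_var(name='x', shape=[4], dtype='float32')
    block.create_var(name='y', shape=[4], dtype='float32')

    # The output is given as a name; Operator now resolves it via block.var().
    block.append_op(
        type='scale',
        inputs={'X': ['x']},
        outputs={'Out': ['y']},
        attrs={'scale': 2.0})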
@@ -2801,7 +2807,6 @@ class Block(object):
         return var
 
     def _remove_var(self, name):
-        self._sync_with_cpp()
         self.desc._remove_var(cpt.to_bytes(name))
         del self.vars[name]
@@ -2893,7 +2898,6 @@ class Block(object):
         Returns:
             Operator: the insert Operator.
         """
-        self._sync_with_cpp()
         op_desc = self.desc._insert_op(index)
         op = Operator(block=self, desc=op_desc, *args, **kwargs)
         self.ops.insert(index, op)
@@ -2909,7 +2913,6 @@ class Block(object):
         Returns:
             None
         """
-        self._sync_with_cpp()
         self.desc._remove_op(index, index + 1)
         del self.ops[index]
......
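Dropping the per-call _sync_with_cpp() presumably makes bulk block surgery cheaper: a pass that edits many ops is then expected to synchronize once itself. A sketch of that pattern on a throwaway program (the two scale ops are only there to have something to delete):

    import paddle.fluid as fluid

    prog = fluid.Program()
    with fluid.program_guard(prog):
        x = fluid.data(name='x', shape=[None, 1], dtype='float32')
        y = fluid.layers.scale(x, scale=2.0)
        z = fluid.layers.scale(y, scale=3.0)

    block = prog.global_block()
    # Delete ops back to front, then sync the Python view with the C++ desc
    # once at the end instead of once per _remove_op call.
    for idx in reversed(range(len(block.ops))):
        block._remove_op(idx)
    block._sync_with_cpp()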