Unverified commit e550fc02, authored by W WangXi, committed by GitHub

fleet2.0 add fp16 grad compression (#27480)

Parent c5c13473
......@@ -127,6 +127,7 @@ message DistributedStrategy {
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......
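The new proto field surfaces in Python as the DistributedStrategy.fp16_allreduce property added below. As a minimal sketch of how the flag behaves (the default mirrors the proto default of false):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
print(strategy.fp16_allreduce)  # False, matching the proto default
strategy.fp16_allreduce = True  # enable fp16 gradient compression for allreduce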
......@@ -845,6 +845,29 @@ class DistributedStrategy(object):
check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs")
assign_configs_value(self.strategy.dgc_configs, configs)
@property
def fp16_allreduce(self):
"""
Indicate whether fp16 gradient allreduce is used during distributed training.
Default Value: False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.fp16_allreduce = True # by default this is false
"""
return self.strategy.fp16_allreduce
@fp16_allreduce.setter
@is_strict_auto
def fp16_allreduce(self, flag):
if not isinstance(flag, bool):
raise TypeError('fp16_allreduce must be a bool value')
self.strategy.fp16_allreduce = flag
@property
def gradient_merge(self):
"""
......
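For context, here is a minimal end-to-end sketch of enabling the new flag through fleet.distributed_optimizer. It closely mirrors the unit test added later in this commit; the toy network is illustrative only.

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker

paddle.enable_static()
fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

# Toy network, illustrative only.
input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype="float32")
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype="int64")
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act="tanh")
prediction = paddle.fluid.layers.fc(input=fc_1, size=2, act="softmax")
cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)

# Gradients are cast to fp16 for allreduce, then back to fp32 before the update.
strategy = fleet.DistributedStrategy()
strategy.fp16_allreduce = True

optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)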
......@@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer
from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core, framework, unique_name
from .meta_optimizer_base import MetaOptimizerBase
class FP16AllReduceOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(FP16AllReduceOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# currently we do not allow a meta optimizer to be the inner optimizer
self.meta_optimizers_white_list = [
"LarsOptimizer",
"LambOptimizer",
"RecomputeOptimizer",
"LocalSGDOptimizer",
"GradientMergeOptimizer",
"GraphExecutionOptimizer",
"AdaptiveLocalSGDOptimizer",
]
self.meta_optimizers_black_list = ["DGCOptimizer"]
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(FP16AllReduceOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.user_defined_strategy.fp16_allreduce:
return True
return False
def _disable_strategy(self, dist_strategy):
dist_strategy.fp16_allreduce = False
def _enable_strategy(self, dist_strategy, context=None):
dist_strategy.fp16_allreduce = True
@staticmethod
def fp16_compression(param_and_grads):
"""
Compress fp32 gradients to fp16 during allreduce.
"""
op_maker = core.op_proto_and_checker_maker
new_param_and_grads = [] # param, grad, is_cast
# cast grad from fp32->fp16 before allreduce,
for param, grad in param_and_grads:
if grad is None or grad.dtype != core.VarDesc.VarType.FP32:
new_param_and_grads.append((param, grad, False))
continue
op = grad.op
block = grad.block
var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()]
if param.name not in var_attr:
new_param_and_grads.append((param, grad, False))
continue
# remove (param, grad) from op_role_var
var_attr.remove(param.name)
var_attr.remove(grad.name)
if len(var_attr) > 1:
op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr)
else:
op._remove_attr(op_maker.kOpRoleVarAttrName())
new_grad = block.create_var(
name=unique_name.generate(grad.name + ".cast_fp16"),
dtype=core.VarDesc.VarType.FP16,
persistable=False,
stop_gradient=True)
with block.program._backward_role_guard():
cast_op = block.append_op(
type="cast",
inputs={"X": grad},
outputs={"Out": new_grad},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16
},
stop_gradient=True)
backward = op_maker.OpRole.Backward
cast_op._set_attr(op_maker.kOpRoleAttrName(), backward)
cast_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param.name, new_grad.name])
new_grad.op = cast_op
new_param_and_grads.append((param, new_grad, True))
ret_param_and_grads = []
# cast grad from fp16->fp32 after allreduce.
# NOTE: the fp16 compression is split into two for-loops on purpose;
# if the fp32->fp16 and fp16->fp32 casts are emitted in a single loop,
# the fuse allreduce pass produces wrong results. This looks like a bug
# in the fuse allreduce pass and needs to be fixed in the future.
for param, grad, cast in new_param_and_grads:
if not cast:
ret_param_and_grads.append((param, grad))
continue
block = grad.block
new_grad = block.create_var(
name=unique_name.generate(grad.name + ".cast_fp32"),
dtype=core.VarDesc.VarType.FP32,
persistable=False,
stop_gradient=True)
with block.program._optimized_guard(
[param, grad]), framework.name_scope('fp16_allreduce'):
cast_op = block.append_op(
type="cast",
inputs={"X": grad},
outputs={"Out": new_grad},
attrs={
"in_dtype": core.VarDesc.VarType.FP16,
"out_dtype": core.VarDesc.VarType.FP32
},
stop_gradient=True)
ret_param_and_grads.append((param, new_grad))
return ret_param_and_grads
def apply_optimize(self, loss, startup_program, params_grads):
new_params_grads = self.fp16_compression(params_grads)
return self.inner_opt.apply_optimize(
loss,
startup_program=startup_program,
params_grads=new_params_grads)
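The numerical effect of this pass, independent of Paddle: gradients are cast to fp16 only for the allreduce (halving communication volume), then cast back to fp32 before the optimizer update. The following is a standalone NumPy sketch of that round trip, a simulation rather than the actual pass:

import numpy as np

def simulated_fp16_allreduce(worker_grads):
    # Cast down before communication: each worker sends half the bytes.
    sent = [g.astype(np.float16) for g in worker_grads]
    # The allreduce itself, modeled here as a sum across workers.
    reduced = np.sum(sent, axis=0)
    # Cast back to fp32 before the optimizer consumes the gradient.
    return reduced.astype(np.float32)

grads = [np.random.randn(1024).astype(np.float32) for _ in range(4)]
exact = np.sum(grads, axis=0)
approx = simulated_fp16_allreduce(grads)
print(np.abs(exact - approx).max())  # small fp16 rounding error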
......@@ -45,6 +45,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_fp16_allreduce_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
......@@ -458,6 +459,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers import FP16AllReduceOptimizer as FP16AllReduce
from test_dist_base import TestDistRunnerBase, runtime_main
from dist_mnist import cnn_model
DTYPE = "float32"
paddle.dataset.mnist.fetch()
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestDistMnist2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size_tensor)
inference_program = fluid.default_main_program().clone()
# Optimization
opt = fluid.optimizer.MomentumOptimizer(
learning_rate=0.001, momentum=0.9)
opt = FP16AllReduce(opt)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
opt.minimize(avg_cost)
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
if __name__ == "__main__":
runtime_main(TestDistMnist2x2)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
class TestDistMnist2x2FP16AllReduce(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._nccl2_mode = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist_fp16_allreduce.py", delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -102,6 +102,16 @@ class TestStrategyConfig(unittest.TestCase):
strategy.dgc = "True"
self.assertEqual(strategy.dgc, False)
def test_fp16_allreduce(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.fp16_allreduce = True
self.assertEqual(strategy.fp16_allreduce, True)
strategy.fp16_allreduce = False
self.assertEqual(strategy.fp16_allreduce, False)
with self.assertRaises(TypeError):
strategy.fp16_allreduce = "True"
self.assertEqual(strategy.fp16_allreduce, False)
def test_sync_nccl_allreduce(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sync_nccl_allreduce = True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import unittest
import paddle
import paddle.fluid as fluid
import os
paddle.enable_static()
class TestFleetFP16CompressOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
def net(self, main_prog, startup_prog, dtype='float32'):
with fluid.program_guard(main_prog, startup_prog):
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype=dtype)
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.fp16_allreduce = True
return avg_cost, strategy
def test_fp16_allreduce_optimizer(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy = self.net(train_prog, startup_prog)
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
ops = [op.type for op in avg_cost.block.ops]
cast_out = [
op.output('Out')[0] for op in avg_cost.block.ops
if op.type == 'cast'
]
cast_op_count = 0
for name in ops:
if name == 'cast':
cast_op_count += 1
self.assertIn('cast', ops)
self.assertEqual(cast_op_count, 12)  # 3 fc layers x (weight + bias) = 6 grads, each cast fp32->fp16 and fp16->fp32
for name in cast_out:
self.assertIn('cast_fp16', name)
def test_fp16_allreduce_not_apply_fp16_net(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy = self.net(train_prog, startup_prog, dtype='float16')
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
ops = [op.type for op in avg_cost.block.ops]
self.assertNotIn('cast', ops)
if __name__ == "__main__":
unittest.main()