Unverified commit 6c51e493, authored by zhaoyingli, committed by GitHub

[AutoParallel]fp16 pass support assign op (#47649)

* fp16 pass support assign op

* choose assign op exec mode

* add unittest

* add cmakelist
Parent c65f0565
@@ -156,6 +156,7 @@ class FP16State(object):
             list
         )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
         self.is_train = False
+        self.out_var_op_deps = {}

     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)
@@ -169,6 +170,14 @@ class FP16State(object):
         # assume all backward block are behind forward blocks
         for block in self.program.blocks:
             for op in block.ops:
+                for name in op.output_arg_names:
+                    if name not in self.out_var_op_deps:
+                        self.out_var_op_deps[name] = [op.desc.original_id()]
+                    else:
+                        self.out_var_op_deps[name].extend(
+                            [op.desc.original_id()]
+                        )
                 self._mark_op(op)

         # set forward tensor dtype
@@ -192,6 +201,18 @@ class FP16State(object):
         if op.type == "assign" and "array_" in op.input_arg_names[0]:
             self._op_fp16_dict[op.desc.original_id()] = False
             return
+        # If the assign op is in-place, its exec mode should match the op that created its output var.
+        if op.type == "assign":
+            out_name = op.output_arg_names[0]
+            if len(self.out_var_op_deps[out_name]) > 1:
+                if not self._op_fp16_dict[
+                    self.out_var_op_deps[out_name][0]
+                ]:
+                    self._op_fp16_dict[op.desc.original_id()] = False
+                else:
+                    self._op_fp16_dict[op.desc.original_id()] = True
+                return
         if _need_keep_fp32(
             op, self.amp_list.unsupported_list, self.use_fp16_guard
         ):
...
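To make the rule above easier to follow in isolation, here is a minimal standalone sketch of the same idea (the Op tuple and mark_assign_ops are made-up names for illustration, not the pass's own code): record which ops write each output variable in program order, and when an assign rewrites a variable that an earlier op already produced, let it inherit that first producer's fp16/fp32 decision.

from collections import namedtuple

Op = namedtuple("Op", ["op_id", "op_type", "outputs"])


def mark_assign_ops(ops, op_fp16_dict):
    # Record, in program order, the ids of the ops that write each output var.
    out_var_op_deps = {}
    for op in ops:
        for name in op.outputs:
            out_var_op_deps.setdefault(name, []).append(op.op_id)

    # An assign whose output var already has an earlier producer is in-place,
    # so it copies the fp16/fp32 decision of that first producer.
    for op in ops:
        if op.op_type != "assign":
            continue
        deps = out_var_op_deps[op.outputs[0]]
        if len(deps) > 1:
            op_fp16_dict[op.op_id] = op_fp16_dict[deps[0]]
    return op_fp16_dict


# 'where' (op 1) is kept in fp32 by the black list, so the in-place assign
# (op 2) that rewrites its output is forced to fp32 as well.
ops = [Op(1, "where", ["where_0"]), Op(2, "assign", ["where_0"])]
print(mark_assign_ops(ops, {1: False}))  # -> {1: False, 2: False}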
@@ -115,5 +115,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_conditional_block_reshard MODULES
                   test_conditional_block_reshard)
   py_test_modules(test_engine_api_error MODULES test_engine_api_error)
+  py_test_modules(test_fp16_assign MODULES test_fp16_assign)
 endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import copy

import paddle
from paddle.distributed.fleet import auto
from paddle.distributed.passes import new_pass

paddle.enable_static()
def make_program():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 6, 8], dtype='float32')
        y = paddle.static.data(name='y', shape=[4, 6, 6], dtype='float32')
        z = paddle.static.data(name='z', shape=[4, 6, 6], dtype='float32')
        auto.shard_tensor(x, auto.ProcessMesh([0], ['d0']), [None, None, None])
        out0 = paddle.static.nn.fc(
            x,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_0 = paddle.where(y > 1, y, out0)
        out1 = paddle.static.nn.fc(
            out0,
            size=6,
            num_flatten_dims=2,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.5)
            ),
            bias_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)
            ),
        )
        where_1 = paddle.where(y > 1, y, out1)
        # in-place assign: copies where_1 into where_0's output variable
        paddle.fluid.layers.assign(where_1, where_0)
    return main_program, start_program
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    strategy = auto.Strategy()
    amp = strategy.amp
    amp.enable = True
    amp.use_pure_fp16 = True
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    # keep 'where' in fp32 via the custom black list
    amp.custom_black_list = ['where']

    config = copy.deepcopy(strategy.amp.to_dict())
    config["dist_context"] = dist_context
    config["params_grads"] = []
    config["loss"] = None
    config["base_opt"] = None
    auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
    auto_parallel_fp16_pass.apply([main_program], [start_program], None)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context
class TestFp16Assign(unittest.TestCase):
    def assert_fp32_dtype(self, block, op):
        for slot in op.input_names:
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float32

    def assert_fp16_dtype(self, block, op):
        for slot in op.input_names:
            if slot == "Condition":
                continue
            for name in op.input(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16
        for slot in op.output_names:
            for name in op.output(slot):
                if block.vars[name].dtype == paddle.bool:
                    continue
                assert block.vars[name].dtype == paddle.float16

    def test_fp16_assign(self):
        dist_main_prog, dist_context = parallelizer(make_program, 0)
        block = dist_main_prog.global_block()
        for op in block.ops:
            if op.type == "cast":
                continue
            if op.type == "where":
                # 'where' is black-listed, so it must stay in fp32
                self.assert_fp32_dtype(block, op)
            elif op.type == "assign":
                # the in-place assign writes into where's output, so it must stay fp32 too
                self.assert_fp32_dtype(block, op)
            else:
                self.assert_fp16_dtype(block, op)


if __name__ == "__main__":
    unittest.main()
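If the dtype assertions above ever fail, a small helper along these lines (dump_op_dtypes is a hypothetical debugging aid, not part of the test) can print the per-op input and output dtypes of the partitioned program and show which ops the pass left in fp32:

def dump_op_dtypes(block):
    # Print each op together with the dtypes of its inputs and outputs.
    for op in block.ops:
        in_dtypes = [
            str(block.vars[name].dtype)
            for slot in op.input_names
            for name in op.input(slot)
            if name in block.vars
        ]
        out_dtypes = [
            str(block.vars[name].dtype)
            for slot in op.output_names
            for name in op.output(slot)
            if name in block.vars
        ]
        print(op.type, in_dtypes, out_dtypes)


# usage: dump_op_dtypes(dist_main_prog.global_block())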