diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 679e2a27ac26e69eaa03173bbf517eff038b26f0..cdb5f4221237f720809fe200e38fd1d7d568109a 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -156,6 +156,7 @@ class FP16State(object):
             list
         )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
         self.is_train = False
+        self.out_var_op_deps = {}  # {out_var_name: [op_id, ...]}, ops that write the var, in program order
 
     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)
@@ -169,6 +170,14 @@ class FP16State(object):
         # assume all backward block are behind forward blocks
         for block in self.program.blocks:
             for op in block.ops:
+                for name in op.output_arg_names:
+                    if name not in self.out_var_op_deps:
+                        self.out_var_op_deps[name] = [op.desc.original_id()]
+                    else:
+                        self.out_var_op_deps[name].extend(
+                            [op.desc.original_id()]
+                        )
+
                 self._mark_op(op)
 
         # set forward tensor dtype
@@ -192,6 +201,18 @@ class FP16State(object):
             if op.type == "assign" and "array_" in op.input_arg_names[0]:
                 self._op_fp16_dict[op.desc.original_id()] = False
                 return
+            # If an assign op is an in-place operation, its exec mode should match that of the op that created the output var.
+            if op.type == "assign":
+                out_name = op.output_arg_names[0]
+                if len(self.out_var_op_deps[out_name]) > 1:
+                    if not self._op_fp16_dict[
+                        self.out_var_op_deps[out_name][0]
+                    ]:
+                        self._op_fp16_dict[op.desc.original_id()] = False
+                    else:
+                        self._op_fp16_dict[op.desc.original_id()] = True
+                    return
+
             if _need_keep_fp32(
                 op, self.amp_list.unsupported_list, self.use_fp16_guard
             ):
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index c538ae126f016ba980e0872872c6f5ebb1690bdd..b2935a0b175b3d0f2b55b5aee74711efec9f9b6c 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -115,5 +115,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_conditional_block_reshard MODULES
                   test_conditional_block_reshard)
   py_test_modules(test_engine_api_error MODULES test_engine_api_error)
+  py_test_modules(test_fp16_assign MODULES test_fp16_assign)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_fp16_assign.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_fp16_assign.py
new file mode 100644
index 0000000000000000000000000000000000000000..da385173dca3292d466614f785c556d14a4f2727
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_fp16_assign.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import copy
+
+import paddle
+from paddle.distributed.fleet import auto
+from paddle.distributed.passes import new_pass
+
+paddle.enable_static()
+
+
+def make_program():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 6, 8], dtype='float32')
+        y = paddle.static.data(name='y', shape=[4, 6, 6], dtype='float32')
+        z = paddle.static.data(name='z', shape=[4, 6, 6], dtype='float32')
+
+        auto.shard_tensor(x, auto.ProcessMesh([0], ['d0']), [None, None, None])
+
+        out0 = paddle.static.nn.fc(
+            x,
+            size=6,
+            num_flatten_dims=2,
+            weight_attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=1.0)
+            ),
+        )
+        where_0 = paddle.where(y > 1, y, out0)
+
+        out1 = paddle.static.nn.fc(
+            out0,
+            size=6,
+            num_flatten_dims=2,
+            weight_attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=1.0)
+            ),
+        )
+        where_1 = paddle.where(y > 1, y, out1)
+
+        paddle.fluid.layers.assign(where_1, where_0)
+
+    return main_program, start_program
+
+
+def parallelizer(program_func, rank):
+    from paddle.distributed.auto_parallel.completion import Completer
+    from paddle.distributed.auto_parallel.partitioner import Partitioner
+    from paddle.distributed.auto_parallel.dist_context import DistributedContext
+
+    main_program, start_program = program_func()
+
+    dist_context = DistributedContext()
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation(main_program)
+    dist_context.block_state.parse_forward_blocks(main_program)
+
+    strategy = auto.Strategy()
+    amp = strategy.amp
+    amp.enable = True
+    amp.use_pure_fp16 = True
+    amp.init_loss_scaling = 32768
+    amp.use_fp16_guard = False
+    amp.custom_black_list = ['where']
+
+    config = copy.deepcopy(strategy.amp.to_dict())
+    config["dist_context"] = dist_context
+    config["params_grads"] = []
+    config["loss"] = None
+    config["base_opt"] = None
+    auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
+    auto_parallel_fp16_pass.apply([main_program], [start_program], None)
+
+    partitioner = Partitioner(dist_context, rank)
+    dist_main_prog, _, _ = partitioner.partition(
+        main_program, start_program, []
+    )
+
+    return dist_main_prog, dist_context
+
+
+class TestFp16Assign(unittest.TestCase):
+    def assert_fp32_dtype(self, block, op):
+        for slot in op.input_names:
+            for name in op.input(slot):
+                if block.vars[name].dtype == paddle.bool:
+                    continue
+                assert block.vars[name].dtype == paddle.float32
+        for slot in op.output_names:
+            for name in op.output(slot):
+                if block.vars[name].dtype == paddle.bool:
+                    continue
+                assert block.vars[name].dtype == paddle.float32
+
+    def assert_fp16_dtype(self, block, op):
+        for slot in op.input_names:
+            if slot == "Condition":
+                continue
+            for name in op.input(slot):
+                if block.vars[name].dtype == paddle.bool:
+                    continue
+                assert block.vars[name].dtype == paddle.float16
+        for slot in op.output_names:
+            for name in op.output(slot):
+                if block.vars[name].dtype == paddle.bool:
+                    continue
+                assert block.vars[name].dtype == paddle.float16
+
+    def test_fp16_assign(self):
+
+        dist_main_prog, dist_context = parallelizer(make_program, 0)
+        block = dist_main_prog.global_block()
+        for op in block.ops:
+            if op.type == "cast":
+                continue
+            if op.type == "where":
+                self.assert_fp32_dtype(block, op)
+            elif op.type == "assign":
+                self.assert_fp32_dtype(block, op)
+            else:
+                self.assert_fp16_dtype(block, op)
+
+
+if __name__ == "__main__":
+    unittest.main()
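
Note on the pass change above: FP16State now records, for every output variable name, the ops that write it (out_var_op_deps). When an assign op writes into a variable that already has an earlier producer, i.e. an in-place assign such as paddle.fluid.layers.assign(where_1, where_0) in the test, the assign inherits the fp16/fp32 decision of that first producer instead of being classified on its own, so it stays in fp32 when its target was produced by a black-listed op such as where. Below is a minimal standalone sketch of that decision rule; SimpleOp and mark_ops are hypothetical and do not use Paddle's op descs, only the rule itself mirrors the diff.

# Standalone sketch, not Paddle API: SimpleOp and mark_ops are hypothetical;
# only the decision rule mirrors the auto_parallel_fp16 change above.
from collections import namedtuple

SimpleOp = namedtuple("SimpleOp", ["op_id", "op_type", "outputs", "runs_fp16"])


def mark_ops(ops):
    out_var_op_deps = {}  # var name -> ids of ops that write it, in program order
    op_fp16 = {}  # op id -> True (fp16) / False (fp32)
    for op in ops:
        for name in op.outputs:
            out_var_op_deps.setdefault(name, []).append(op.op_id)
        if op.op_type == "assign" and len(out_var_op_deps[op.outputs[0]]) > 1:
            # In-place assign: inherit the mode of the output var's first producer.
            op_fp16[op.op_id] = op_fp16[out_var_op_deps[op.outputs[0]][0]]
        else:
            op_fp16[op.op_id] = op.runs_fp16
    return op_fp16


# where is black-listed (fp32); the assign that writes into where_0 therefore
# also stays fp32, while the fc ops run in fp16.
ops = [
    SimpleOp(0, "fc", ["out0"], True),
    SimpleOp(1, "where", ["where_0"], False),
    SimpleOp(2, "fc", ["out1"], True),
    SimpleOp(3, "where", ["where_1"], False),
    SimpleOp(4, "assign", ["where_0"], True),  # writes where_1 into where_0
]
assert mark_ops(ops)[4] is False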