From 5920d69df509f5feddf6d7be8190bfb21c874310 Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Wed, 25 Sep 2019 13:07:02 +0800 Subject: [PATCH] Avoid treating broadcast as initialization operation (#19857) * treat broadcast as non-initial, test=develop * rename the class name * rename the class name, test=develop --- python/paddle/fluid/framework.py | 5 ++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../test_avoid_twice_initialization.py | 50 +++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_avoid_twice_initialization.py diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 650c78a9583..c9fb957656c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1833,6 +1833,11 @@ class Block(object): init_ops = [] for op in block.ops: if var.name in op.output_arg_names: + # In startup_program, "c_broadcast" and "c_sync_comm_stream" + # are treated as initialization ops that cause error. + # Think of "c_broadcast" and "c_sync_comm_stream" as a special case here. + if op.type in ["c_broadcast", "c_sync_comm_stream"]: + continue init_ops.append(op) return init_ops diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6b1971b1561..6809f88da03 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -39,6 +39,7 @@ if(WIN32) LIST(REMOVE_ITEM TEST_OPS test_boxps) LIST(REMOVE_ITEM TEST_OPS test_trainer_desc) LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception) + LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization) endif() LIST(REMOVE_ITEM TEST_OPS test_launch) diff --git a/python/paddle/fluid/tests/unittests/test_avoid_twice_initialization.py b/python/paddle/fluid/tests/unittests/test_avoid_twice_initialization.py new file mode 100644 index 00000000000..8572572f146 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_avoid_twice_initialization.py @@ -0,0 +1,50 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + + +class TestAvoidTwiceInitialization(unittest.TestCase): + def test_avoid_twice_initialization(self): + cur_program = fluid.Program() + cur_block = cur_program.current_block() + var = cur_block.create_parameter( + initializer=fluid.initializer.Constant(value=0.01), + shape=[2, 2], + dtype='float32', + name='var_a') + cur_block.append_op( + type="c_broadcast", + inputs={"X": [var]}, + outputs={"Out": [var]}, + attrs={'root': 0, + 'ring_id': 0, + 'use_calc_stream': False}) + cur_block.append_op( + type="c_sync_comm_stream", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={'ring_id': 0}) + var2 = cur_block.create_parameter( + initializer=fluid.initializer.Constant(value=0.01), + shape=[2, 2], + dtype='float32', + name='var_a') + + +if __name__ == '__main__': + unittest.main() -- GitLab