Unverified commit bfd1dd77 authored by kangguangli and committed by GitHub

[CherryPick] [BugFix] wrong match between depend and c_allreduce_sum (#53271)

* fix bug: wrong match between depend and c_allreduce_sum

(cherry picked from commit 327da8035bdfee3ec2f016e8cda29ec8ee89bc95)

* fix codestyle

(cherry picked from commit bdb1483081adc41aa47d3f7df257f63f1cff399b)

* fix bug

(cherry picked from commit 373ba5253c45ac019ffaa8d69d4ce9e02cb9ae79)

* add c_sync_calc_stream back

(cherry picked from commit 9933d7533ae1f307b76f24a33bf0c59e4c8e8f01)

* fix

(cherry picked from commit abc9a31beaa326f6a566c08749419bb33e209672)

* revert

(cherry picked from commit 07bc98dbf7c9df43910fa6e86a6a2698731dffb2)

* use flag to control

(cherry picked from commit 8e5682a4b99759cbe35a49f3f8c9db735dc8fee4)

* fix for code coverage

(cherry picked from commit fe7e61bdef24fbc43e2f4e1cb67f68963c957cf1)
Parent 1e7efd81
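The core of the fix sits in the depend-insertion loop shown in the diff below: previously the scan stopped at the first op whose type was `c_allreduce_sum`, so with several fused gradient buffers the `depend` ops of one segment could be attached to another segment's allreduce. The patched loop also requires the matched op to actually consume the corresponding fused var. A minimal standalone sketch of that matching rule (using a hypothetical `Op` stand-in, not Paddle's block ops):

```python
from collections import namedtuple

# Hypothetical stand-in for a block op: just a type and its input names.
Op = namedtuple('Op', ['type', 'input_arg_names'])


def find_allreduce_for(ops, start_idx, fused_var_name):
    """Return the index of the c_allreduce_sum that consumes `fused_var_name`.

    Matching on op type alone (the old behavior) would stop at the first
    allreduce regardless of which fused buffer it reduces.
    """
    idx = start_idx
    while (
        ops[idx].type != 'c_allreduce_sum'
        or fused_var_name not in ops[idx].input_arg_names
    ):
        idx += 1
    return idx


if __name__ == '__main__':
    ops = [
        Op('sum', ['grad0']),
        Op('c_allreduce_sum', ['fused_0']),
        Op('c_allreduce_sum', ['fused_1']),
    ]
    # Type-only matching would pair both buffers with index 1; matching on
    # the fused var name pairs fused_1 with its own allreduce at index 2.
    assert find_allreduce_for(ops, 0, 'fused_1') == 2
```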
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import os
from paddle import static
from paddle.fluid import core
from paddle.framework import _global_flags
@@ -62,6 +64,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
self.calc_comm_same_stream = (
user_defined_strategy._calc_comm_same_stream
)
self.sync_before_allreduce = os.environ.get(
'FLAGS_sync_before_allreduce', None
)
def _can_apply(self):
if not self.role_maker._is_collective:
@@ -433,17 +438,28 @@ class RawProgramOptimizer(MetaOptimizerBase):
OP_ROLE_KEY: OpRole.Backward,
},
)
if not self.calc_comm_same_stream and self.sync_before_allreduce:
block._insert_op_without_sync(
after_idx + 1,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: OpRole.Backward},
)
idx = 0
if not self.calc_comm_same_stream:
if not self.calc_comm_same_stream and not self.sync_before_allreduce:
for i in range(len(grad_param_segments)):
while block.ops[idx].type != 'c_allreduce_sum':
while (
block.ops[idx].type != 'c_allreduce_sum'
or fused_vars[i].name not in block.ops[idx].input_arg_names
):
idx += 1
grad_segment, param_segment = grad_param_segments[i]
for grad in grad_segment:
block._insert_op_without_sync(
idx + 1,
type='depend',
inputs={'X': grad, 'Dep': fused_var},
inputs={'X': grad, 'Dep': fused_vars[i]},
outputs={'Out': grad},
)
idx += 1
@@ -486,6 +502,21 @@ class RawProgramOptimizer(MetaOptimizerBase):
},
)
if self.calc_comm_same_stream or not self.sync_before_allreduce:
block._sync_with_cpp()
return
# insert the sync comm op
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op_without_sync(
idx,
type='c_sync_comm_stream',
inputs={'X': fused_vars},
outputs={'Out': fused_vars},
attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
)
break
block._sync_with_cpp()
def __get_ouputs_name_to_idx(self, first_backward_idx, block):
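The hunks above also introduce `FLAGS_sync_before_allreduce`: when the flag is set (and calc and comm do not share a stream), the optimizer inserts explicit `c_sync_calc_stream`/`c_sync_comm_stream` ops instead of per-gradient `depend` ops. A hedged sketch of the resulting three-way choice, noting that the flag is read as a raw string, so any non-empty value (even `'0'`) enables the synchronized path; the helper name is illustrative, not a Paddle API:

```python
import os


def choose_sync_strategy(calc_comm_same_stream):
    # Read as a raw string: presence of any non-empty value turns it on.
    sync_before_allreduce = os.environ.get('FLAGS_sync_before_allreduce', None)
    if calc_comm_same_stream:
        return 'none'          # single stream: no extra ordering ops needed
    if sync_before_allreduce:
        return 'sync_ops'      # c_sync_calc_stream / c_sync_comm_stream
    return 'depend_ops'        # per-gradient depend on the matching fused var


if __name__ == '__main__':
    os.environ['FLAGS_sync_before_allreduce'] = '1'
    assert choose_sync_strategy(calc_comm_same_stream=False) == 'sync_ops'
    del os.environ['FLAGS_sync_before_allreduce']
    assert choose_sync_strategy(calc_comm_same_stream=False) == 'depend_ops'
```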
@@ -45,5 +45,22 @@ class TestFleetMetaOptimizerPrecision(TestDistBase):
)
class TestFleetMetaOptimizerPrecisionWithSync(TestFleetMetaOptimizerPrecision):
def need_envs(self):
return {'FLAGS_sync_before_allreduce': '1'}
def test_dist_train(self):
from paddle import fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"dist_fleet_raw_program_optimizer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name + 'with_sync',
need_envs=self.need_envs(),
)
if __name__ == '__main__':
unittest.main()
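For completeness, a hedged sketch of running the precision test with the flag exported to the spawned trainer processes; the unittest module name is inferred from the script referenced in the diff and may differ in an actual checkout:

```python
import os
import subprocess
import sys

# Export the flag so the launched test (and its trainer processes) see it.
env = dict(os.environ, FLAGS_sync_before_allreduce='1')

# Module name is an assumption based on the diff; adjust to your checkout.
subprocess.run(
    [sys.executable, '-m', 'unittest', 'test_dist_fleet_raw_program_optimizer'],
    env=env,
    check=False,
)
```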