Unverified · Commit bfd1dd77 authored by kangguangli, committed by GitHub

[CherryPick] [BugFix] wrong match between depend and c_allreduce_sum (#53271)

* fix bug: wrong match between depend and c_allreduce_sum

(cherry picked from commit 327da8035bdfee3ec2f016e8cda29ec8ee89bc95)

* fix codestyle

(cherry picked from commit bdb1483081adc41aa47d3f7df257f63f1cff399b)

* fix bug

(cherry picked from commit 373ba5253c45ac019ffaa8d69d4ce9e02cb9ae79)

* add c_sync_calc_stream back

(cherry picked from commit 9933d7533ae1f307b76f24a33bf0c59e4c8e8f01)

* fix

(cherry picked from commit abc9a31beaa326f6a566c08749419bb33e209672)

* revert

(cherry picked from commit 07bc98dbf7c9df43910fa6e86a6a2698731dffb2)

* use flag to control

(cherry picked from commit 8e5682a4b99759cbe35a49f3f8c9db735dc8fee4)

* fix for code coverage

(cherry picked from commit fe7e61bdef24fbc43e2f4e1cb67f68963c957cf1)
Parent 1e7efd81
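The core defect this patch fixes: when gradients are fused into several segments, the pass located each segment's c_allreduce_sum by op type alone and wired every depend op to a stale fused_var loop variable, so a depend could bind to the wrong allreduce. Below is a minimal runnable sketch of the mismatch; the Op class and the op list are hypothetical stand-ins for block.ops, not Paddle code.

from dataclasses import dataclass, field

@dataclass
class Op:
    type: str
    input_arg_names: list = field(default_factory=list)

# Two fused segments whose allreduce ops do not appear in segment order.
ops = [
    Op('coalesce_tensor'),
    Op('c_allreduce_sum', ['fused_var_1']),
    Op('c_allreduce_sum', ['fused_var_0']),
]
fused_vars = ['fused_var_0', 'fused_var_1']

# Old matching: stop at the first c_allreduce_sum, whatever it reads.
idx = 0
while ops[idx].type != 'c_allreduce_sum':
    idx += 1
print(ops[idx].input_arg_names)  # ['fused_var_1'] -- wrong op for segment 0

# Fixed matching: the op must also consume this segment's fused variable.
i, idx = 0, 0
while (ops[idx].type != 'c_allreduce_sum'
       or fused_vars[i] not in ops[idx].input_arg_names):
    idx += 1
print(ops[idx].input_arg_names)  # ['fused_var_0'] -- correct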
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
+import os
+
 from paddle import static
 from paddle.fluid import core
 from paddle.framework import _global_flags
@@ -62,6 +64,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
         self.calc_comm_same_stream = (
             user_defined_strategy._calc_comm_same_stream
         )
+        self.sync_before_allreduce = os.environ.get(
+            'FLAGS_sync_before_allreduce', None
+        )
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
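Note that the flag is stored as the raw environment string (or None) and the later checks rely on truthiness, so any non-empty value enables the sync path. A quick illustration:

import os

os.environ['FLAGS_sync_before_allreduce'] = '0'
flag = os.environ.get('FLAGS_sync_before_allreduce', None)
print(bool(flag))  # True: even '0' is a non-empty string, so the path is on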
@@ -433,17 +438,28 @@ class RawProgramOptimizer(MetaOptimizerBase):
                     OP_ROLE_KEY: OpRole.Backward,
                 },
             )
+            if not self.calc_comm_same_stream and self.sync_before_allreduce:
+                block._insert_op_without_sync(
+                    after_idx + 1,
+                    type='c_sync_calc_stream',
+                    inputs={'X': fused_var},
+                    outputs={'Out': fused_var},
+                    attrs={OP_ROLE_KEY: OpRole.Backward},
+                )
         idx = 0
-        if not self.calc_comm_same_stream:
+        if not self.calc_comm_same_stream and not self.sync_before_allreduce:
             for i in range(len(grad_param_segments)):
-                while block.ops[idx].type != 'c_allreduce_sum':
+                while (
+                    block.ops[idx].type != 'c_allreduce_sum'
+                    or fused_vars[i].name not in block.ops[idx].input_arg_names
+                ):
                     idx += 1
                 grad_segment, param_segment = grad_param_segments[i]
                 for grad in grad_segment:
                     block._insert_op_without_sync(
                         idx + 1,
                         type='depend',
-                        inputs={'X': grad, 'Dep': fused_var},
+                        inputs={'X': grad, 'Dep': fused_vars[i]},
                         outputs={'Out': grad},
                     )
                     idx += 1
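For one fused segment i, the two branches above should yield op sequences like the following. This is a hedged sketch: the names are real op types from the diff, but the lists are illustrative, not produced by Paddle.

# With FLAGS_sync_before_allreduce set and separate calc/comm streams:
with_sync = [
    'coalesce_tensor',     # packs segment i's grads into fused_vars[i]
    'c_sync_calc_stream',  # waits for the calc stream to finish writing
    'c_allreduce_sum',     # reduces fused_vars[i] across ranks
]

# Without the flag (default): depend ops enforce the ordering instead.
without_sync = [
    'coalesce_tensor',
    'c_allreduce_sum',     # must now be the allreduce reading fused_vars[i]
    'depend',              # one per grad: grad is scheduled after fused_vars[i]
]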
@@ -486,6 +502,21 @@ class RawProgramOptimizer(MetaOptimizerBase):
             },
         )
 
+        if self.calc_comm_same_stream or not self.sync_before_allreduce:
+            block._sync_with_cpp()
+            return
+
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars},
+                    outputs={'Out': fused_vars},
+                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
+                )
+                break
         block._sync_with_cpp()
 
     def __get_ouputs_name_to_idx(self, first_backward_idx, block):
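The insertion above can be modeled in a few lines: scan for the first optimizer op and place a single c_sync_comm_stream in front of it, so every fused allreduce completes before weights are updated. A simplified runnable sketch, where strings stand in for real ops and is_opt for is_optimizer_op:

ops = ['coalesce_tensor', 'c_allreduce_sum', 'c_allreduce_sum', 'sgd', 'sgd']
is_opt = lambda op: op == 'sgd'  # stand-in for is_optimizer_op

for idx, op in enumerate(ops):
    if is_opt(op):
        ops.insert(idx, 'c_sync_comm_stream')
        break  # only one sync op is needed

print(ops)
# ['coalesce_tensor', 'c_allreduce_sum', 'c_allreduce_sum',
#  'c_sync_comm_stream', 'sgd', 'sgd']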
@@ -45,5 +45,22 @@ class TestFleetMetaOptimizerPrecision(TestDistBase):
         )
 
 
+class TestFleetMetaOptimizerPrecisionWithSync(TestFleetMetaOptimizerPrecision):
+    def need_envs(self):
+        return {'FLAGS_sync_before_allreduce': '1'}
+
+    def test_dist_train(self):
+        from paddle import fluid
+
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place(
+                "dist_fleet_raw_program_optimizer.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name + 'with_sync',
+                need_envs=self.need_envs(),
+            )
+
 if __name__ == '__main__':
     unittest.main()
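To exercise only the new variant, something like the following should work; the test module name is a placeholder inferred from the script name in the diff, so adjust it to the actual test file.

import os
import subprocess

env = dict(os.environ, FLAGS_sync_before_allreduce='1')  # also injected by need_envs()
subprocess.run(
    [
        'python', '-m', 'unittest',
        # placeholder module path; the class name comes from the diff above
        'test_dist_fleet_raw_program_optimizer.TestFleetMetaOptimizerPrecisionWithSync',
    ],
    env=env,
)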