Unverified Commit 0f741880 authored by Zeng Jinle, committed by GitHub

Fix RawProgramOptimizer bug (#35704)

* fix raw optimizer gm

* update

* update ut
Parent 83932715
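The patch below fixes two problems in RawProgramOptimizer._insert_allreduce_ops_for_gm. First, the old loop assigned to last_backward_idx (note the missing _op), so last_backward_op_idx stayed None and always fell back to 0, inserting the allreduce ops at the very top of the gradient-merge block. Second, even without the typo, enumerate(reversed(gm_block.ops)) yields positions counted from the end of the reversed sequence, not indices into the block, so the value would have been wrong as an insertion point anyway. The fix enumerates first and then reverses, so i is the real index of the last backward op and i + 1 is the position of the first optimizer op. A minimal sketch of the index difference, with a plain Python list standing in for gm_block.ops:

ops = ["forward", "backward", "backward", "optimize"]

# Old pattern: i counts positions in the reversed sequence.
for i, op in enumerate(reversed(ops)):
    if op == "backward":
        break
print(i)      # 1 -- an index into the reversed sequence, not into ops

# Fixed pattern: enumerate first, then reverse, so i is the real index.
for i, op in reversed(list(enumerate(ops))):
    if op == "backward":
        break
print(i + 1)  # 3 -- the position right after the last backward op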
......
@@ -164,13 +164,13 @@ class RawProgramOptimizer(MetaOptimizerBase):
     def _insert_allreduce_ops_for_gm(self, gm_block):
         block = self.main_program.global_block()
 
-        last_backward_op_idx = None
-        for i, op in enumerate(reversed(gm_block.ops)):
-            if is_backward_op(op) and last_backward_op_idx is None:
-                last_backward_idx = i
+        first_optimize_op_idx = None
+        for i, op in reversed(list(enumerate(gm_block.ops))):
+            if is_backward_op(op) and first_optimize_op_idx is None:
+                first_optimize_op_idx = i + 1
                 break
-        if last_backward_op_idx is None:
-            last_backward_op_idx = 0
+        if first_optimize_op_idx is None:
+            first_optimize_op_idx = 0
 
         param_vars = []
         grad_vars = []
......
@@ -191,7 +191,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
             return
 
         gm_block._insert_op(
-            last_backward_op_idx,
+            first_optimize_op_idx,
             type="c_sync_calc_stream",
             inputs={'X': grad_vars[0]},
             outputs={'Out': grad_vars[0]},
......
@@ -203,7 +203,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
         # NOTE: can perform fuse allreduce inside the loop in the future
         for i, (p, g) in enumerate(zip(param_vars, grad_vars)):
             gm_block._insert_op(
-                last_backward_op_idx + insert_op_num,
+                first_optimize_op_idx + insert_op_num,
                 type="c_allreduce_sum",
                 inputs={'X': g},
                 outputs={'Out': g},
......
@@ -214,7 +214,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
             insert_op_num += 1
 
         gm_block._insert_op(
-            last_backward_op_idx + insert_op_num,
+            first_optimize_op_idx + insert_op_num,
             type="c_sync_comm_stream",
             inputs={'X': grad_vars[-1]},
             outputs={'Out': grad_vars[-1]},
......
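Taken together, the three _insert_op call sites above place the communication ops between the last backward op and the first optimizer op of the gradient-merge block: one c_sync_calc_stream, then one c_allreduce_sum per gradient, then one c_sync_comm_stream. A rough sketch of the insertion arithmetic on a plain list, assuming (as the elided context suggests) that insert_op_num starts at 1 after the first insert; the op names here are stand-ins, not Paddle API calls:

ops = ["backward_0", "backward_1", "adam_0"]  # toy gradient-merge block
first_optimize_op_idx = 2  # position right after the last backward op
grads = ["g0", "g1"]

ops.insert(first_optimize_op_idx, "c_sync_calc_stream")
insert_op_num = 1
for g in grads:
    ops.insert(first_optimize_op_idx + insert_op_num, "c_allreduce_sum(%s)" % g)
    insert_op_num += 1
ops.insert(first_optimize_op_idx + insert_op_num, "c_sync_comm_stream")

print(ops)
# ['backward_0', 'backward_1', 'c_sync_calc_stream',
#  'c_allreduce_sum(g0)', 'c_allreduce_sum(g1)',
#  'c_sync_comm_stream', 'adam_0']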
......
@@ -789,7 +789,7 @@ endif()
 if (WITH_DISTRIBUTE AND NOT APPLE)
     if(WITH_GPU OR WITH_ROCM)
         set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT 120)
-        set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 160)
     endif()
 endif()
......
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import paddle
 import paddle.nn as nn
 import paddle.fluid as fluid
......
@@ -44,9 +45,10 @@ class TestDistMnistGradientMergeRawOptimizer(TestDistRunnerBase):
         strategy.build_strategy = build_strategy
         strategy.gradient_merge = True
+        avg = os.environ['enable_gm_avg'] == "True"
         strategy.gradient_merge_configs = {
             "k_steps": 2,
-            "avg": False,
+            "avg": avg,
         }
         strategy.without_graph_optimization = True
......
@@ -65,9 +67,25 @@ class TestDistMnistGradientMergeRawOptimizer(TestDistRunnerBase):
                 optimizer,
                 k_steps=strategy.gradient_merge_configs["k_steps"],
                 avg=strategy.gradient_merge_configs["avg"])
+            world_size = 1
         else:
             optimizer = fleet.distributed_optimizer(optimizer)
+            world_size = fleet.world_size()
         optimizer.minimize(cost)
+        if world_size > 1:
+            assert paddle.static.default_main_program().num_blocks == 2
+            gm_block = paddle.static.default_main_program().block(1)
+            start_allreduce_idx = None
+            for i, op in enumerate(gm_block.ops):
+                if op.type == "c_allreduce_sum":
+                    start_allreduce_idx = i
+                    break
+            # the magic number 1 below means skip the c_sync_calc_stream op
+            if avg:
+                assert start_allreduce_idx > 1
+            else:
+                assert start_allreduce_idx == 1
+
         train_reader = paddle.batch(
             paddle.dataset.mnist.test(), batch_size=batch_size)
         test_reader = paddle.batch(
......
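A note on the assertions just added: with avg=False the gradient-merge block contains no backward-role ops of its own, so the inserted c_sync_calc_stream lands at index 0 and the first c_allreduce_sum exactly at index 1; with avg=True the merged gradients are first scaled by 1/k_steps, and those extra ops push the first c_allreduce_sum past index 1. (The reading of the avg=True case is an inference from the assertion, not spelled out in the patch.)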
......
@@ -52,21 +52,35 @@ class TestDistMnistGradMergeNoFuse(TestDistBase):
             log_name=flag_name + "_no_fuse")
 
 
-class TestDistMnistGradMergeRawOptimizer(TestDistBase):
+class TestDistMnistGradMergeRawOptimizerBase(TestDistBase):
     def _setup_config(self):
         self._use_reader_alloc = False
         self._nccl2_mode = True
         self._use_fleet_api = True
         self._use_fleet_api_20 = True
 
+    def enable_avg(self):
+        return False
+
     def test_dist_train(self):
         if fluid.core.is_compiled_with_cuda():
+            avg = str(self.enable_avg())
+            log_name = flag_name + "_raw_optimizer_gm_avg_" + avg
             self.check_with_place(
                 "dist_mnist_gradient_merge_raw_optimizer.py",
                 delta=1e-5,
                 check_error_log=True,
-                log_name=flag_name + "_raw_optimizer",
-                need_envs={'FLAGS_apply_pass_to_program': '1'})
+                log_name=log_name,
+                need_envs={
+                    'FLAGS_apply_pass_to_program': '1',
+                    'enable_gm_avg': avg,
+                })
+
+
+class TestDistMnistGradMergeRawOptimizerAvg(
+        TestDistMnistGradMergeRawOptimizerBase):
+    def enable_avg(self):
+        return True
 
 
 if __name__ == "__main__":
......