提交 86c2c362 编写于 作者: L Liufang Sang 提交者: Bai Yifan

fix fuse_reduce_op quantization bug (#20306)

* fix fuse_reduce_op quantization bug test=develop

* close fuse_all_reduce_ops in PaddleSlim, test=develop
上级 b1218d05
...@@ -480,9 +480,12 @@ class Compressor(object): ...@@ -480,9 +480,12 @@ class Compressor(object):
executor = SlimGraphExecutor(self.place) executor = SlimGraphExecutor(self.place)
if context.optimize_graph.compiled_graph is None: if context.optimize_graph.compiled_graph is None:
build_strategy = compiler.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
context.optimize_graph.compiled_graph = compiler.CompiledProgram( context.optimize_graph.compiled_graph = compiler.CompiledProgram(
context.optimize_graph.program).with_data_parallel( context.optimize_graph.program).with_data_parallel(
loss_name=context.optimize_graph.out_nodes['loss']) loss_name=context.optimize_graph.out_nodes['loss'],
build_strategy=build_strategy)
if isinstance(context.train_reader, Variable) or ( if isinstance(context.train_reader, Variable) or (
isinstance(context.train_reader, isinstance(context.train_reader,
......
...@@ -263,6 +263,7 @@ class GraphWrapper(object): ...@@ -263,6 +263,7 @@ class GraphWrapper(object):
build_strategy = compiler.BuildStrategy() build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = mem_opt build_strategy.enable_inplace = mem_opt
build_strategy.memory_optimize = mem_opt build_strategy.memory_optimize = mem_opt
build_strategy.fuse_all_reduce_ops = False
# build_strategy.async_mode = False # build_strategy.async_mode = False
self.compiled_graph = compiler.CompiledProgram( self.compiled_graph = compiler.CompiledProgram(
target).with_data_parallel( target).with_data_parallel(
......
...@@ -138,6 +138,7 @@ class QuantizationStrategy(Strategy): ...@@ -138,6 +138,7 @@ class QuantizationStrategy(Strategy):
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
build_strategy.fuse_all_reduce_ops = False
# for quantization training # for quantization training
context.optimize_graph.compiled_graph = CompiledProgram( context.optimize_graph.compiled_graph = CompiledProgram(
train_ir_graph.graph).with_data_parallel( train_ir_graph.graph).with_data_parallel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册