Unverified commit 134f9c3e authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed
Parent 4fba3d5e
@@ -636,12 +636,12 @@ class Engine:
         Evaluate the loss and metrics of the model on evaluation data.

         Args:
-            eval_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
-            eval_sample_split (int, optional): Each sample of the eval dataset is assumed
+            valid_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
+            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                 to be a (input, label) pair by default and has two items. If each sample has
-                more than two items, eval_sample_split specifies how to split these items into
+                more than two items, valid_sample_split specifies how to split these items into
                 input and label. The items before it are input and the left are label. Default: None.
-            batch_size (int, optional): The batch size of eval_data. The user's data will
+            batch_size (int, optional): The batch size of valid_data. The user's data will
                 be used directly without batching if set to None. Default: 1.
             steps (int, optional): It is the total number of steps (batches of samples) to draw before
                 stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
......
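For context, the renamed keywords above read like this at the call site. This is a minimal, hypothetical sketch: the auto-parallel Engine construction is assumed and omitted (it is not part of this commit), and only the evaluate() keyword names — valid_data, valid_sample_split, batch_size, steps — come from the docstring in this hunk.

import paddle
from paddle.io import Dataset

class RandomPairs(Dataset):
    # Synthetic (input, label) pairs, so the default valid_sample_split
    # (two items per sample) applies without any custom split.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return paddle.rand([16]).numpy(), paddle.randint(0, 2, [1]).numpy()

# `engine` stands for an already-built auto-parallel Engine (model, loss and
# optimizer wiring omitted here, since that setup is outside this commit):
#
#   results = engine.evaluate(
#       valid_data=RandomPairs(),    # renamed from eval_data
#       valid_sample_split=None,     # (input, label) pairs need no custom split
#       batch_size=4,                # None would pass samples through unbatched
#       steps=None)                  # run until valid_data is exhausted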
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import paddle
+from paddle.fluid import core
 from .process_mesh import ProcessMesh
 from .process_mesh import get_current_process_mesh
......
@@ -111,14 +111,9 @@ class DataParallelOptimizationPass(PassBase):
         if not self._could_be_fuse():
             return []

-        with open('./before_program.txt.' + str(paddle.distributed.get_rank()),
-                  'w') as f:
-            f.write(str(default_main_program()))
         grad_group = self._group_grads()
         self._update_program(grad_group)
-        with open('./after_program.txt.' + str(paddle.distributed.get_rank()),
-                  'w') as f:
-            f.write(str(default_main_program()))

         return grad_group

     def _analyze_program(self):
@@ -569,6 +564,11 @@ class GradientsGroup(object):
                 self.remove_scale_op_indices.append(i + 1)

         if len(self.gradients) == 1:
+            # TODO Remove this is a temporary hack for Tensor Parallel. the logic
+            # for find grad_op should be more general.
+            if self.ops[grad_op_idx].type == "c_allreduce_sum":
+                grad_op_idx -= 1
+
             grad_op = self.ops[grad_op_idx]
             assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
                 grad_var.name, str(grad_op))
......
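The added lines above are the actual MP fix: under tensor/model parallel a c_allreduce_sum can sit at the recorded grad_op_idx (it is inserted right after the op that produces the gradient), so the pass steps back one op before asserting that the op at that index outputs the gradient. A self-contained sketch of that lookup pattern follows, using a stand-in Op class rather than real Paddle operators; every name below is illustrative only.

from dataclasses import dataclass, field
from typing import List

@dataclass
class Op:
    # Stand-in for a Paddle operator: an op type plus the names it outputs.
    type: str
    output_arg_names: List[str] = field(default_factory=list)

def find_grad_op(ops: List[Op], grad_op_idx: int, grad_name: str) -> Op:
    # Under tensor/model parallel an allreduce is appended right after the
    # grad-producing op, so the recorded index can point at the allreduce
    # instead of the producer; stepping back one op recovers the producer.
    if ops[grad_op_idx].type == "c_allreduce_sum":
        grad_op_idx -= 1
    grad_op = ops[grad_op_idx]
    assert grad_name in grad_op.output_arg_names, \
        "grad [{}] should be output of {}".format(grad_name, grad_op)
    return grad_op

# Toy op list for a column-parallel linear backward: the MP allreduce acts on
# the activation gradient, while the parameter gradient comes from the op
# just before it. Without the step-back, the assert above would trip.
ops = [
    Op("matmul_v2_grad", output_arg_names=["w@GRAD", "x@GRAD"]),
    Op("c_allreduce_sum", output_arg_names=["x@GRAD"]),
]
print(find_grad_op(ops, grad_op_idx=1, grad_name="w@GRAD").type)  # matmul_v2_grad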