提交 8c6dae77 编写于 作者: C chengduoZH

fix pe bug

上级 8231960f
...@@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) { } else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1); CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
CreateComputationalOps(&result, *op, places_.size());
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (use_default_grad_scale_) { if (use_default_grad_scale_) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
if (IsScaleLossGradOp(*op)) continue;
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
...@@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, ...@@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 &&
(op.OutputArgumentNames()[0]) == loss_var_name_;
}
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this // FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 && return op.OutputArgumentNames().size() == 1 &&
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
......
...@@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossGradOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册