提交 a3ca4c99 编写于 作者: C chengduoZH

fix loss.gradVar

上级 8c6dae77
...@@ -145,14 +145,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -145,14 +145,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) { } else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1); CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
CreateComputationalOps(&result, *op, places_.size());
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (use_default_grad_scale_) { if (use_default_grad_scale_) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
if (IsScaleLossGradOp(*op)) continue;
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
...@@ -401,12 +399,6 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, ...@@ -401,12 +399,6 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 &&
(op.OutputArgumentNames()[0]) == loss_var_name_;
}
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this // FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 && return op.OutputArgumentNames().size() == 1 &&
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
......
...@@ -67,8 +67,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -67,8 +67,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossGradOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
/** /**
......
...@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
program.current_block_idx = current_block_idx program.current_block_idx = current_block_idx
program.sync_with_cpp() program.sync_with_cpp()
# FIXME(zcd): prevent loss.grad from being optimized by mem_opt.
loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
if parameter_list is not None: if parameter_list is not None:
parameters = parameter_list parameters = parameter_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册