提交 a3ca4c99 编写于 作者: C chengduoZH

fix loss.gradVar

上级 8c6dae77
......@@ -145,14 +145,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
CreateComputationalOps(&result, *op, places_.size());
// user can customize loss@grad if not use_default_grad_scale_
if (use_default_grad_scale_) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;
} else {
if (IsScaleLossGradOp(*op)) continue;
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
......@@ -401,12 +399,6 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 &&
(op.OutputArgumentNames()[0]) == loss_var_name_;
}
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 &&
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
......
......@@ -67,8 +67,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
bool IsScaleLossGradOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
/**
......
......@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
program.current_block_idx = current_block_idx
program.sync_with_cpp()
# FIXME(zcd): prevent loss.grad optimized by mem_opt.
loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
if parameter_list is not None:
parameters = parameter_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册