未验证 提交 d085f792 编写于 作者: W wangchaochaohu 提交者: GitHub

fix untime fail for output var stop_gradient=True for fusion group (#23317)

上级 b76f3b27
......@@ -99,6 +99,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
input_ids.push_back(-1);
}
}
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
......@@ -106,10 +107,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(
op->Output(name).size(), 1U,
platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type()));
PADDLE_ENFORCE_NE(
var_ids.find(op->Output(name)[0]), var_ids.end(),
platform::errors::InvalidArgument(
......
......@@ -111,6 +111,15 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
}
}
}
auto op = n->Op();
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
if (op->Output(name).size() != 1) return false;
}
return true;
}
return false;
......
......@@ -109,11 +109,9 @@ class FusionGroupPassTest2(FusionGroupPassTest):
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
tmp_3 = layers.mul(tmp_1, tmp_2)
# TODO(wangchaochaohu): support the case when some vars are set
# stop_gradient = True.
self.append_gradients(tmp_3)
self.num_fused_ops = 2
self.fetch_list = [tmp_3]
self.fetch_list = [tmp_3, self.grad(tmp_1)]
class FusionGroupPassTestFP64(FusionGroupPassTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册