未验证 提交 9e107c6b 编写于 作者: L Luyang 提交者: GitHub
上级 6f38134c
......@@ -180,13 +180,13 @@ Maybe<void> BroadcastMatmul::Apply(const MatmulInterpState* ctx, const TensorTup
JUST(attrs_b.SetAttr<double>("alpha", ctx->alpha));
in_grads->resize(2);
if (ctx->requires_grad_b) {
if (ctx->requires_grad_a) {
const auto& input_b = ctx->SavedTensors().at(ctx->b_index);
in_grads->at(0) =
JUST(OpInterpUtil::Dispatch<Tensor>(*grad_a_op_, {out_grads.at(0), input_b}, attrs_a));
}
if (ctx->requires_grad_a) {
if (ctx->requires_grad_b) {
const auto& input_a = ctx->SavedTensors().at(ctx->a_index);
if (!ctx->transpose_b) {
in_grads->at(1) =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册