提交 0cf99823 编写于 作者: G guo ran 提交者: GitHub

Fix fuse scalar mul by tensor sbp (#3692)

* fix fuse_scalar_mul_by_tensor sbp

* refine
Co-authored-by: Noneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
Former-commit-id: 1d32adf4
上级 642b9d57
......@@ -238,8 +238,9 @@ REGISTER_USER_OP("sgd_update")
const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
ctx->NewBuilder()
.Split(ctx->inputs(), axis)
.Broadcast(user_op::OpArg("learning_rate", 0))
.Broadcast(ctx->inputs())
.Split(user_op::OpArg("model", 0), axis)
.Split(user_op::OpArg("model_diff", 0), axis)
.Build();
}
return Maybe<void>::Ok();
......@@ -299,8 +300,10 @@ REGISTER_USER_OP("momentum_update")
const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
ctx->NewBuilder()
.Split(ctx->inputs(), axis)
.Broadcast(user_op::OpArg("learning_rate", 0))
.Broadcast(ctx->inputs())
.Split(user_op::OpArg("model", 0), axis)
.Split(user_op::OpArg("model_diff", 0), axis)
.Split(user_op::OpArg("momentum", 0), axis)
.Build();
}
return Maybe<void>::Ok();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册