Patch note: changes the builder chains (ctx->NewBuilder() ... .Build()) shown
in the "sgd_update" and "momentum_update" hunks of
oneflow/user/ops/model_update_ops.cpp.
  Before: all inputs were Split along `axis`, then "learning_rate" was Broadcast.
  After:  all inputs are Broadcast, then "model" and "model_diff" (plus
          "momentum" in momentum_update) are Split along `axis`.
NOTE(review): this presumes a later Split/Broadcast call on the same builder
overrides the earlier setting for that argument -- confirm against the builder API.
NOTE(review): the line breaks, the indentation inside the hunks, and the
`<void>` argument in `Maybe<void>::Ok()` were reconstructed from a copy of this
patch that had been collapsed onto a single line and read `Maybe::Ok()` --
confirm against the blobs named on the index line before applying.

diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp
index ea9efe1acd53c0902507308497a19bf73c2ea3dd..5f040fa4b07277739e3c8560db1876bb0ec79b9e 100644
--- a/oneflow/user/ops/model_update_ops.cpp
+++ b/oneflow/user/ops/model_update_ops.cpp
@@ -238,8 +238,9 @@ REGISTER_USER_OP("sgd_update")
       const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
       FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
         ctx->NewBuilder()
-            .Split(ctx->inputs(), axis)
-            .Broadcast(user_op::OpArg("learning_rate", 0))
+            .Broadcast(ctx->inputs())
+            .Split(user_op::OpArg("model", 0), axis)
+            .Split(user_op::OpArg("model_diff", 0), axis)
             .Build();
       }
       return Maybe<void>::Ok();
@@ -299,8 +300,10 @@ REGISTER_USER_OP("momentum_update")
       const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
       FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
         ctx->NewBuilder()
-            .Split(ctx->inputs(), axis)
-            .Broadcast(user_op::OpArg("learning_rate", 0))
+            .Broadcast(ctx->inputs())
+            .Split(user_op::OpArg("model", 0), axis)
+            .Split(user_op::OpArg("model_diff", 0), axis)
+            .Split(user_op::OpArg("momentum", 0), axis)
             .Build();
       }
       return Maybe<void>::Ok();