Commit 78b4479a, authored by: F fary86

Fix bug in ApplyRMSProp's check

Parent 1df8ea29
@@ -1586,9 +1586,11 @@ class ApplyRMSProp(PrimitiveWithInfer):
         args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
-        args = {"learning_rate": learning_rate_dtype, "decay": decay_dtype,
-                'momentum': momentum_dtype, "epsilon": epsilon_dtype}
-        validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        valid_types = [mstype.float16, mstype.float32]
+        args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
+        validator.check_type_same(args_decay, valid_types, self.name)
+        args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
+        validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
         return var_dtype
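Reading the patch: the old code pushed learning_rate, decay, momentum, and epsilon through one scalar-or-tensor same-type check, which rejected otherwise valid inputs. The fix splits the check in two: decay/momentum/epsilon must share a single dtype from valid_types, while learning_rate is checked separately as a scalar or tensor that may mix forms with decay (allow_mix=True). The sketch below re-expresses that split with simplified stand-in helpers; they are not MindSpore's real validator API, and the exact semantics of allow_mix are my reading of the diff.

```python
# Simplified stand-ins for MindSpore's validator helpers, modeling the
# patched checks. Dtypes are plain strings here; for the scalar-or-tensor
# check each value is a ('scalar'|'tensor', element_type) pair.
VALID_TYPES = ("float16", "float32")

def check_type_same(args, valid_types, prim_name):
    """decay/momentum/epsilon path: every entry must share one dtype
    drawn from valid_types."""
    dtypes = set(args.values())
    if len(dtypes) != 1 or not dtypes.issubset(valid_types):
        raise TypeError(f"For '{prim_name}', {sorted(args)} must share one "
                        f"dtype in {valid_types}, got {args}")

def check_scalar_or_tensor_type_same(args, valid_types, prim_name, allow_mix=False):
    """learning_rate path: element types must be identical and valid;
    with allow_mix=True a tensor entry may sit next to a scalar one."""
    elem_types = {elem for _, elem in args.values()}
    kinds = {kind for kind, _ in args.values()}
    if len(elem_types) != 1 or not elem_types.issubset(valid_types):
        raise TypeError(f"For '{prim_name}', element types must be one of "
                        f"{valid_types} and identical, got {args}")
    if not allow_mix and len(kinds) != 1:
        raise TypeError(f"For '{prim_name}', mixing scalar and tensor "
                        f"arguments is not allowed here: {args}")

# The patched checks, re-expressed with the stand-ins: a float32 scalar
# decay next to a float32 learning-rate tensor now passes.
check_type_same({"decay": "float32", "momentum": "float32", "epsilon": "float32"},
                VALID_TYPES, "ApplyRMSProp")
check_scalar_or_tensor_type_same(
    {"learning_rate": ("tensor", "float32"), "decay": ("scalar", "float32")},
    VALID_TYPES, "ApplyRMSProp", allow_mix=True)
```

Under this reading, the pre-patch bug was that a tensor learning_rate alongside scalar hyperparameters failed the single mixed check even though all element types were valid.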