提交 50d89921 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!253 Change the number of Adam operator outputs to adapt to TBE

Merge pull request !253 from zhaoting/add-YOLOv3-infer-scipt-and-change-dataset-to-MindRecord
......@@ -88,7 +88,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta1_power",
"need_compile": false,
......@@ -101,7 +102,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta2_power",
"need_compile": false,
......@@ -114,7 +116,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "lr",
"need_compile": false,
......@@ -127,7 +130,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta1",
"need_compile": false,
......@@ -140,7 +144,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta2",
"need_compile": false,
......@@ -153,7 +158,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "epsilon",
"need_compile": false,
......@@ -161,7 +167,7 @@ from mindspore.ops.op_info_register import op_info_register
"shape": "all"
},
{
"index": 8,
"index": 9,
"dtype": [
"float16","float16","float16","float16","float","float","float", "float"
],
......@@ -187,6 +193,32 @@ from mindspore.ops.op_info_register import op_info_register
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
],
"name": "m",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
],
"name": "v",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
......
......@@ -2149,7 +2149,7 @@ class Adam(PrimitiveWithInfer):
validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape)
validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape)
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
return var_shape
return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
......@@ -2159,7 +2159,7 @@ class Adam(PrimitiveWithInfer):
args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype,
"beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype}
validator.check_type_same(args, [mstype.float16, mstype.float32])
return var_dtype
return var_dtype, m_dtype, v_dtype
class BinaryCrossEntropy(PrimitiveWithInfer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册