未验证 提交 3a0e0b6f 编写于 作者: H houj04 提交者: GitHub

update xpu1 op list, for train ResNet50 using PaddleClas. (#38201)

上级 9075a0fd
...@@ -88,6 +88,8 @@ XPUOpMap& get_kl1_ops() { ...@@ -88,6 +88,8 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_add", {"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_add_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div_grad", {"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div", {"elementwise_div",
...@@ -112,6 +114,7 @@ XPUOpMap& get_kl1_ops() { ...@@ -112,6 +114,7 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub", {"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace())})},
{"expand_as_v2", {"expand_as_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
...@@ -123,6 +126,8 @@ XPUOpMap& get_kl1_ops() { ...@@ -123,6 +126,8 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"fill_any_like",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace())})},
{"fill_constant", {"fill_constant",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
...@@ -186,7 +191,7 @@ XPUOpMap& get_kl1_ops() { ...@@ -186,7 +191,7 @@ XPUOpMap& get_kl1_ops() {
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_grad", {"nearest_interp_grad",
...@@ -212,6 +217,8 @@ XPUOpMap& get_kl1_ops() { ...@@ -212,6 +217,8 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad", {"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -252,6 +259,8 @@ XPUOpMap& get_kl1_ops() { ...@@ -252,6 +259,8 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy", {"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -332,6 +341,7 @@ XPUOpMap& get_kl1_ops() { ...@@ -332,6 +341,7 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"where_index", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})},
// AddMore // AddMore
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册