未验证 提交 efd7a229 编写于 作者: T TTerror 提交者: GitHub

add some op to xpu2 op list && format xpu op list (#37832)

* format xpu op list

* format xpu op list

* update xpu1 op list
上级 79c25979
......@@ -29,40 +29,35 @@ using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
XPUOpMap& get_kl1_ops() {
// KL1支持的op,通过op_name, data_type, place来索引
static XPUOpMap s_xpu1_kernels{
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_switch_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"leaky_relu_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"abs", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"affine_channel",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"affine_channel_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"affine_channel",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
......@@ -72,204 +67,220 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"c_reduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_allreduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"deformable_conv",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"deformable_conv_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub",
{"c_allreduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub_grad",
{"c_reduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_add_grad",
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div_grad",
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_pow",
{"elementwise_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_floordiv",
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_mul",
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_mul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_max",
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_max_grad",
{"elementwise_pow",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min",
{"elementwise_sub_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min_grad",
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_as_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"fill_constant",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gaussian_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2_grad",
{"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_switch_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_v2",
{"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_v2_grad",
{"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad",
{"leaky_relu_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"load", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
{"reduce_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max_grad",
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"squeeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
......@@ -278,25 +289,20 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose_grad",
{"transpose2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose2_grad",
{"transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"truncated_gaussian_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"uniform_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze_grad",
{"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
......@@ -311,7 +317,7 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2_grad",
{"unsqueeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
......@@ -319,20 +325,12 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_as_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
{"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
// AddMore
};
......
......@@ -29,267 +29,316 @@ using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
XPUOpMap& get_kl2_ops() {
// KL1支持的op,通过op_name, data_type, place来索引
static XPUOpMap s_xpu2_kernels{
{"label_smooth",
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign_value",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_add_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div",
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_pow",
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mul_grad",
{"elementwise_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_max_grad",
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_min_grad",
{"elementwise_mul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"layer_norm_grad",
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy_grad",
{"elementwise_pow",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose_grad",
{"elementwise_sub_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose2_grad",
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
{"expand_as_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"less_than", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
{"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"less_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"fill_any_like",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"greater_than",
{"fill_constant",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"greater_equal",
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BF16, XPUPlace()),
pOpKernelType(vartype::COMPLEX64, XPUPlace()),
pOpKernelType(vartype::COMPLEX128, XPUPlace())})},
{"flatten2_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten2", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"fill_any_like",
{"flatten_contiguous_range_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"flatten_contiguous_range",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten2", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten2_grad",
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather_nd", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gaussian_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2_grad",
{"greater_than",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign_value",
{"label_smooth",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout_grad",
{"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
{"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"less_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"less_than", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div_grad",
{"lookup_table_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad",
{"masked_select",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad",
{"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"reduce_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2",
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2_grad",
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_contiguous_range",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_contiguous_range_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"softmax_with_cross_entropy_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fill_constant",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
{"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BF16, XPUPlace()),
pOpKernelType(vartype::COMPLEX64, XPUPlace()),
pOpKernelType(vartype::COMPLEX128, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather_nd", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
{"transpose2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"masked_select",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
{"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
{"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"expand_as_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
{"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// AddMore
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册