From efd7a2293af9540bbf6ba75e64b2017476f2da22 Mon Sep 17 00:00:00 2001 From: TTerror Date: Tue, 7 Dec 2021 13:21:03 +0800 Subject: [PATCH] add some op to xpu2 op list && format xpu op list (#37832) * format xpu op list * format xpu op list * update xpu1 op list --- .../fluid/platform/device/xpu/xpu1_op_list.h | 284 ++++++------ .../fluid/platform/device/xpu/xpu2_op_list.h | 403 ++++++++++-------- 2 files changed, 367 insertions(+), 320 deletions(-) diff --git a/paddle/fluid/platform/device/xpu/xpu1_op_list.h b/paddle/fluid/platform/device/xpu/xpu1_op_list.h index 1cc7bba132..d6b466ff92 100644 --- a/paddle/fluid/platform/device/xpu/xpu1_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu1_op_list.h @@ -29,40 +29,35 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl1_ops() { // KL1支持的op,通过op_name, data_type, place来索引 static XPUOpMap s_xpu1_kernels{ - {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sigmoid_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_switch_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"leaky_relu_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"abs", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"affine_channel", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"affine_channel_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"affine_channel", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bilinear_interp", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bilinear_interp_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bilinear_interp_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"bilinear_interp_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, @@ -72,188 +67,197 @@ XPUOpMap& get_kl1_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, - {"c_reduce_sum", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_allreduce_sum", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::INT16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::INT16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::INT16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"depthwise_conv2d", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"depthwise_conv2d_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"deformable_conv", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"deformable_conv_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"depthwise_conv2d", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"depthwise_conv2d_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"dropout_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_sub", + {"c_allreduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_sub_grad", + {"c_reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"elementwise_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_add_grad", + {"elementwise_div_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"elementwise_div", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_div_grad", + {"elementwise_floordiv", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_pow", + {"elementwise_max_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_floordiv", + {"elementwise_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_mul", + {"elementwise_min_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"elementwise_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"elementwise_mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_max", + {"elementwise_mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_max_grad", + {"elementwise_pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_min", + {"elementwise_sub_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_min_grad", + {"elementwise_sub", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"expand_as_v2", + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"fill_constant", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gaussian_random", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"bilinear_interp", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"bilinear_interp_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"nearest_interp", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"nearest_interp_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"bilinear_interp_v2", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"bilinear_interp_v2_grad", + {"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"hard_switch_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"nearest_interp_v2", + {"hard_switch", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"iou_similarity", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"nearest_interp_v2_grad", + {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"layer_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"layer_norm_grad", + {"leaky_relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"load", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_loss_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"lookup_table_v2", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"lookup_table_v2_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lookup_table_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_v2_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, + {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"nearest_interp_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"nearest_interp_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"nearest_interp_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"nearest_interp", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace())})}, - {"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"range", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_sum_grad", + {"reduce_max_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logsumexp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_max_grad", + {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reshape2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"shape", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, - {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_with_cross_entropy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::UINT8, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"squeeze_grad", + {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"squeeze2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), @@ -268,7 +272,7 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"squeeze2_grad", + {"squeeze_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), @@ -276,27 +280,29 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"transpose_grad", + {"transpose2_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"transpose2_grad", + {"transpose_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"truncated_gaussian_random", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"uniform_random", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::UINT8, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"unsqueeze_grad", + {"unsqueeze2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), @@ -311,7 +317,7 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"unsqueeze2_grad", + {"unsqueeze_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), @@ -319,21 +325,13 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"iou_similarity", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + {"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"expand_as_v2", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, // AddMore }; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 636b27e051..74f519c7a8 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -29,141 +29,109 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl2_ops() { // KL1支持的op,通过op_name, data_type, place来索引 static XPUOpMap s_xpu2_kernels{ - {"label_smooth", + {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"assign_value", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_sub", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_sub_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_add", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, + {"batch_norm_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"depthwise_conv2d_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"depthwise_conv2d", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"dropout_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"elementwise_add_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_div", + {"elementwise_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"elementwise_div_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"elementwise_div_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_pow", + {"elementwise_div", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"elementwise_div", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"elementwise_floordiv", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_mul", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_mul_grad", + {"elementwise_max_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"elementwise_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_max_grad", + {"elementwise_min_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"elementwise_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"elementwise_min_grad", + {"elementwise_mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"batch_norm_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"layer_norm_grad", + {"elementwise_mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_sum_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softmax_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softmax_with_cross_entropy", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softmax_with_cross_entropy_grad", + {"elementwise_pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"transpose_grad", + {"elementwise_sub_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"transpose2_grad", + {"elementwise_sub", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"iou_similarity", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reduce_mean_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace())})}, - {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace())})}, {"equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"less_than", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"less_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"greater_than", - XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), + {"expand_as_v2", + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"greater_equal", + {"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"fill_any_like", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace())})}, - {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace())})}, - {"fill_any_like", + {"fill_constant", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"flatten_grad", + pOpKernelType(vartype::BF16, XPUPlace()), + pOpKernelType(vartype::COMPLEX64, XPUPlace()), + pOpKernelType(vartype::COMPLEX128, XPUPlace())})}, + {"flatten2_grad", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), @@ -172,124 +140,205 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"flatten2_grad", + {"flatten_contiguous_range_grad", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_v2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"assign_value", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"dropout_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_div", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_div_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, - {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reshape2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, - {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, - {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"layer_norm_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"lookup_table_v2", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"lookup_table_v2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"flatten_contiguous_range", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"flatten_contiguous_range_grad", + {"flatten_grad", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace())})}, - {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + {"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"gather_nd", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"gaussian_random", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"fill_constant", + {"greater_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT16, XPUPlace()), - pOpKernelType(vartype::INT8, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP64, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::BF16, XPUPlace()), - pOpKernelType(vartype::COMPLEX64, XPUPlace()), - pOpKernelType(vartype::COMPLEX128, XPUPlace())})}, - {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, - {"softmax_grad", + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"greater_than", + XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"iou_similarity", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"label_smooth", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"layer_norm_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"layer_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, - {"gather_nd", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), + {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"less_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"less_than", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, - {"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, + {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lookup_table_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lookup_table_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"masked_select", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"expand_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::INT64, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), + {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_v2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"matmul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"mean_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"expand_as_v2", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"reduce_max_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_mean_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reshape2_grad", + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, - {"depthwise_conv2d", + {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"softmax_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"depthwise_conv2d_grad", + {"softmax_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"softmax_with_cross_entropy_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"softmax_with_cross_entropy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"squeeze2_grad", + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"transpose2_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"transpose2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"transpose_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, + {"unsqueeze2_grad", + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, // AddMore }; -- GitLab