未验证 提交 93b7ccf5 编写于 作者: Q QingshuChen 提交者: GitHub

Update KL1 op list and optimize matmul unit test for Kunlun (#48775)

*test=kunlun
上级 87fbc5e4
......@@ -21,10 +21,284 @@ namespace xpu {
// Accessor for the KL1 op table.
//
// Builds (once, as a function-local static) an XPUOpMap whose keys are op
// names and whose values are XPUKernelSet objects listing the phi::DataType
// values given for that op, then returns a reference to that map.
//
// NOTE(review): PD_THROW is invoked on every call before the return
// statement. If PD_THROW raises unconditionally (as its name suggests), the
// return is never reached and callers never receive the table -- confirm
// whether that is intended.
XPUOpMap& get_kl1_ops() {
// Ops supported on KL1, listed as op_name -> data_type set.
static XPUOpMap s_xpu1_kernels{
{"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32})},
{"affine_channel_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"affine_channel", XPUKernelSet({phi::DataType::FLOAT32})},
{"arg_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"assign",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"broadcast",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
{"cast",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32})},
{"concat", XPUKernelSet({phi::DataType::FLOAT32})},
{"concat_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"dropout", XPUKernelSet({phi::DataType::FLOAT32})},
{"dropout_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_div_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
{"equal", XPUKernelSet({phi::DataType::INT64})},
{"expand_as_v2",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"expand_v2",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fill_any_like", XPUKernelSet({phi::DataType::INT64})},
{"fill_constant",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT64,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"gather_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"gather", XPUKernelSet({phi::DataType::FLOAT32})},
{"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"gelu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_switch_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})},
{"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
{"lamb", XPUKernelSet({phi::DataType::FLOAT32})},
{"layer_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"layer_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"load",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT8,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
// NOTE(review): the next three keys are spelled without an underscore
// ("logicaland", not "logical_and") -- confirm they match the names used at
// lookup time.
{"logicaland",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"logicalnot",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"logicalor",
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"log_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"log_loss", XPUKernelSet({phi::DataType::FLOAT32})},
{"logsumexp", XPUKernelSet({phi::DataType::FLOAT32})},
{"log", XPUKernelSet({phi::DataType::FLOAT32})},
{"lookup_table_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"lookup_table_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"matmul_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"matmul_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"matmul_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"matmul", XPUKernelSet({phi::DataType::FLOAT32})},
{"mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"mean", XPUKernelSet({phi::DataType::FLOAT32})},
{"momentum", XPUKernelSet({phi::DataType::FLOAT32})},
{"mul_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"mul", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp", XPUKernelSet({phi::DataType::FLOAT32})},
{"one_hot_v2",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"pool2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pool2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"range",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32})},
{"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"reshape2_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"reshape2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"rmsprop", XPUKernelSet({phi::DataType::FLOAT32})},
{"rnn_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"rnn", XPUKernelSet({phi::DataType::FLOAT32})},
{"roi_align_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"roi_align", XPUKernelSet({phi::DataType::FLOAT32})},
{"scale", XPUKernelSet({phi::DataType::FLOAT32})},
{"sgd", XPUKernelSet({phi::DataType::FLOAT32})},
{"shape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"sign", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax", XPUKernelSet({phi::DataType::FLOAT32})},
{"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square", XPUKernelSet({phi::DataType::FLOAT32})},
{"squeeze2_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"stack", XPUKernelSet({phi::DataType::FLOAT32})},
{"stack_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"tanh_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"tanh", XPUKernelSet({phi::DataType::FLOAT32})},
{"top_k", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose", XPUKernelSet({phi::DataType::FLOAT32})},
{"truncated_gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"unsqueeze2_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"unsqueeze2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"unsqueeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"unsqueeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"where_index", XPUKernelSet({phi::DataType::BOOL})},
// Add further op entries above this line.
};
// NOTE(review): this runs unconditionally, ahead of the return below; verify
// whether the table above is meant to be reachable by callers.
PD_THROW("get_kl1_ops unsupported");
return s_xpu1_kernels;
}
......
......@@ -2,10 +2,11 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or
agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
......@@ -94,7 +95,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"coalesce_tensor", XPUKernelSet({phi::DataType::FLOAT32})},
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat",
......@@ -525,6 +527,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"transpose2",
......@@ -557,15 +560,6 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"unsqueeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"unsqueeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
......
......@@ -190,8 +190,8 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (100, 20, 100)
self.y_shape = (100, 100, 100)
self.x_shape = (5, 20, 7)
self.y_shape = (5, 7, 7)
self.trans_x = False
self.trans_y = True
......@@ -201,8 +201,8 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (100, 20, 100)
self.y_shape = (100, 20, 100)
self.x_shape = (3, 20, 8)
self.y_shape = (3, 20, 8)
self.trans_x = True
self.trans_y = False
......@@ -212,8 +212,8 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (2, 20, 100)
self.y_shape = (100, 30)
self.x_shape = (2, 20, 11)
self.y_shape = (11, 30)
self.trans_x = False
self.trans_y = False
......@@ -245,8 +245,8 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (100, 2, 100, 10)
self.y_shape = (100, 2, 10, 90)
self.x_shape = (7, 2, 100, 10)
self.y_shape = (7, 2, 10, 90)
self.trans_x = False
self.trans_y = False
......@@ -256,22 +256,11 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (100, 2, 100, 10)
self.y_shape = (100, 2, 100, 10)
self.x_shape = (3, 2, 4, 10)
self.y_shape = (3, 2, 4, 10)
self.trans_x = False
self.trans_y = True
class TestMatMulOp16(TestMatMulV2Op):
"""
case 16 : to check the big data
"""
def config(self):
self.x_shape = (1000, 2, 100, 100)
self.y_shape = (1000, 2, 100, 900)
self.trans_x = False
self.trans_y = False
class TestMatMulOp17(TestMatMulV2Op):
"""
case 17 : to check the gradient for special case
......@@ -289,7 +278,7 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
"""
def config(self):
self.x_shape = (8, 111, 4, 17)
self.x_shape = (8, 11, 4, 17)
self.y_shape = 17
self.trans_x = False
self.trans_y = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册