Unverified commit 7634a18a, authored by houj04, committed by GitHub

[XPU] fc use int_with_ll_t (#53183)

* [XPU] fc use int_with_ll_t

* fix test_unbind_op_xpu
Parent: a3cd9cb9
@@ -33,6 +33,8 @@ void MatMul(const Context& dev_ctx,
     MatMulXPUFunction<T, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+    MatMulXPUFunction<T, int_with_ll_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   }

@@ -68,6 +68,8 @@ void BmmKernel(const Context& dev_ctx,
     MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+    MatMulXPUFunction<T, int_with_ll_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   }

@@ -30,6 +30,7 @@ enum XPUFCCalcType {
   FC_INT16 = 0,
   FC_INT32,
   FC_FLOAT,
+  FC_INT_WITH_LL,
 };
 
 template <typename T>

@@ -41,6 +42,8 @@ XPUFCCalcType FCCalcType() {
     return XPUFCCalcType::FC_INT32;
   } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
     return XPUFCCalcType::FC_FLOAT;
+  } else if (std::getenv("XPU_PADDLE_FC_INT_WITH_LL") != nullptr) {
+    return XPUFCCalcType::FC_INT_WITH_LL;
   }
   return XPUFCCalcType::FC_INT16;
 }

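Note on the hunk above: the new calc type is opt-in through an environment variable, checked after XPU_PADDLE_FC_INT32 and XPU_PADDLE_FC_LOCAL_INT16, with FC_INT16 remaining the default. A minimal sketch of how a user would enable it from Python, assuming a Paddle build with XPU support:

import os

# Must be set before the FC/matmul kernels run; per the diff above,
# XPU_PADDLE_FC_INT32 and XPU_PADDLE_FC_LOCAL_INT16 take precedence if set.
os.environ['XPU_PADDLE_FC_INT_WITH_LL'] = '1'

import paddle

paddle.set_device('xpu')
x = paddle.randn([4, 8], dtype='float32')
y = paddle.randn([8, 16], dtype='float32')
# With the variable set (and the others unset), FCCalcType() returns
# FC_INT_WITH_LL, so this dispatches to MatMulXPUFunction<T, int_with_ll_t>.
out = paddle.matmul(x, y)
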
@@ -387,15 +390,17 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;
   int fccal_type = FCCalcType<XPUType>();
-  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[3] = {
+  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[4] = {
       &xpu_fc_wrapper<XPUType, int16_t>,
       &xpu_fc_wrapper<XPUType, int32_t>,
       &xpu_fc_wrapper<XPUType, float>,
+      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
   };
-  decltype(&xpu_fc_batch_wrapper<XPUType, int16_t>) fc_batch_api_list[3] = {
+  decltype(&xpu_fc_batch_wrapper<XPUType, int16_t>) fc_batch_api_list[4] = {
       &xpu_fc_batch_wrapper<XPUType, int16_t>,
       &xpu_fc_batch_wrapper<XPUType, int32_t>,
       &xpu_fc_batch_wrapper<XPUType, float>,
+      &xpu_fc_batch_wrapper<XPUType, int_with_ll_t>,
   };
 
   auto fc_api = fc_api_list[fccal_type];

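The two wrapper tables are indexed directly by the XPUFCCalcType value, so their entry order must mirror the enum and their length must track it; that is why both array sizes grow from 3 to 4 together with the enum in this commit. A hypothetical Python mirror of the invariant (all names below are invented for illustration, not Paddle API):

from enum import IntEnum

class XPUFCCalcType(IntEnum):
    FC_INT16 = 0
    FC_INT32 = 1
    FC_FLOAT = 2
    FC_INT_WITH_LL = 3

# Stand-ins for the xpu_fc_wrapper<XPUType, ...> instantiations.
def fc_int16(*args): pass
def fc_int32(*args): pass
def fc_float(*args): pass
def fc_int_with_ll(*args): pass

# Entry order must match the enum, exactly as fc_api_list[4] must in C++.
fc_api_list = [fc_int16, fc_int32, fc_float, fc_int_with_ll]
assert len(fc_api_list) == len(XPUFCCalcType)

fc_api = fc_api_list[XPUFCCalcType.FC_INT_WITH_LL]

In the C++ code, a static_assert on the array length against the number of enum values would catch a future calc type added without a matching wrapper entry.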
@@ -41,7 +41,7 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
             x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
             [out_0, out_1] = tensor.unbind(input=x_1, axis=0)
             input_1 = np.random.random([2, 3]).astype(self.dtype)
-            axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+            axis = paddle.static.data(shape=[], dtype='int32', name='axis')
             exe = fluid.Executor(place=self.place)
             [res_1, res_2] = exe.run(

@@ -80,7 +80,7 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
             x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
             [out_0, out_1] = paddle.unbind(input=x_1, axis=0)
             input_1 = np.random.random([2, 3]).astype(self.dtype)
-            axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+            axis = paddle.static.data(shape=[], dtype='int32', name='axis')
             exe = fluid.Executor(place=self.place)
             [res_1, res_2] = exe.run(

@@ -196,6 +196,11 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
             self.assertRaises(TypeError, test_table_Variable)
 
+            def test_invalid_axis():
+                tensor.unbind(input=x, axis=2)
+
+            self.assertRaises(ValueError, test_invalid_axis)
+
 support_types = get_xpu_op_support_types('unbind')
 for stype in support_types:
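The test fix is independent of the calc-type change: unbind takes its axis as a scalar, and in Paddle's static graph shape=[] declares a 0-D (scalar) tensor, whereas the old shape=[1] declared a 1-D tensor holding one element. A small sketch of the distinction:

import paddle

paddle.enable_static()

scalar_axis = paddle.static.data(shape=[], dtype='int32', name='scalar_axis')
vector_axis = paddle.static.data(shape=[1], dtype='int32', name='vector_axis')
print(scalar_axis.shape)  # () -- 0-D, what unbind expects for axis
print(vector_axis.shape)  # (1,) -- 1-D with a single element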