未验证 提交 770ce7cf 编写于 作者: T taixiurong 提交者: GitHub

xpu mul unittest *test=kunlun (#41140)

上级 1ed1a97b
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
......@@ -28,6 +30,8 @@ using framework::Tensor;
template <typename DeviceContext, typename T>
class MulXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
......@@ -62,14 +66,15 @@ class MulXPUKernel : public framework::OpKernel<T> {
const T* data_b = y_matrix.data<T>();
T* data_c = z->data<T>();
auto& dev_ctx = context.template device_context<DeviceContext>();
int ret = xpu::fc_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k,
alpha, data_a, data_b, beta, data_c);
PADDLE_ENFORCE_EQ(
ret, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
ret));
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b, nullptr,
nullptr, nullptr, k, n, n, alpha, beta, nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
......@@ -78,6 +83,8 @@ class MulXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class MulGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
......@@ -126,14 +133,14 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
const T* data_a = dout->data<T>();
const T* data_b = y_matrix.data<T>();
T* data_c = dx_matrix.data<T>();
int ret =
xpu::gemm_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k, alpha,
data_a, lda, data_b, ldb, beta, data_c, ldc);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
ret));
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b,
nullptr, nullptr, nullptr, lda, ldb, ldc, alpha, beta, nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
}
if (dy) {
......@@ -159,14 +166,14 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
const T* data_a = x_matrix.data<T>();
const T* data_b = dout->data<T>();
T* data_c = dy_matrix.data<T>();
int ret =
xpu::gemm_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k, alpha,
data_a, lda, data_b, ldb, beta, data_c, ldc);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
ret));
int ret = xpu_fc_wrapper<XPUType, int16_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
reinterpret_cast<const XPUType*>(data_b),
reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b,
nullptr, nullptr, nullptr, lda, ldb, ldc, alpha, beta, nullptr,
xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
}
}
};
......@@ -175,9 +182,12 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
mul, ops::MulXPUKernel<paddle::platform::XPUDeviceContext, float>);
mul, ops::MulXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MulXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
REGISTER_OP_XPU_KERNEL(
mul_grad, ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, float>)
mul_grad, ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>)
#endif
......@@ -70,8 +70,10 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"dropout_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_add_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......@@ -249,6 +251,8 @@ XPUOpMap& get_kl2_ops() {
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"nearest_interp_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"nearest_interp_v2_grad",
......
......@@ -27,104 +27,120 @@ import time
paddle.enable_static()
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestMulOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of mul_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, fluid.layers.mul, x1, x2)
# The input dtype of mul_op must be float32 or float64.
# The input dtype of mul_op must be float32.
x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUMulOp1(XPUOpTest):
def setUp(self):
self.op_type = "mul"
self.dtype = np.float32
self.use_xpu = True
self.init_dtype_type()
self.inputs = {
'X': np.random.random((3, 4, 2, 9)).astype(self.dtype),
'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.dtype)
}
self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
}
result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
result = result.reshape(3, 4, 1, 2, 3)
self.outputs = {'Out': result}
def init_dtype_type(self):
pass
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=0.01)
def test_check_grad_normal(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUMulOp2(XPUOpTest):
def setUp(self):
self.op_type = "mul"
self.use_xpu = True
self.dtype = np.float32
self.init_dtype_type()
self.inputs = {
'X': np.random.random((20, 5)).astype(self.dtype),
'Y': np.random.random((5, 21)).astype(self.dtype)
}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
def init_dtype_type(self):
self.dtype = np.float32
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=0.01)
def test_check_grad_normal(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class XPUTestMulOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'mul'
self.use_dynamic_create_class = False
class TestXPUMulOp1(XPUOpTest):
def setUp(self):
self.op_type = "mul"
self.dtype = self.in_type
self.inputs = {
'X': np.random.random((3, 4, 2, 9)).astype(self.in_type_str),
'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.in_type_str)
}
self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
}
result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
result = result.reshape(3, 4, 1, 2, 3)
self.outputs = {'Out': result}
def test_check_output(self):
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=0.01)
def test_check_grad_normal(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['Y'],
'Out',
max_relative_error=0.1,
no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['X'],
'Out',
max_relative_error=0.1,
no_grad_set=set('Y'))
class TestXPUMulOp2(XPUOpTest):
def setUp(self):
self.op_type = "mul"
self.use_xpu = True
self.dtype = self.in_type
self.inputs = {
'X': np.random.random((20, 5)).astype(self.in_type_str),
'Y': np.random.random((5, 21)).astype(self.in_type_str)
}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_output_with_place(place, atol=0.01)
def test_check_grad_normal(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['Y'],
'Out',
max_relative_error=0.1,
no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
self.check_grad_with_place(
place, ['X'],
'Out',
max_relative_error=0.1,
no_grad_set=set('Y'))
support_types = get_xpu_op_support_types('mul')
for stype in support_types:
create_test_class(globals(), XPUTestMulOp, stype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册