From 52b05baca349d1bbfcbb6ed78b289d6c66dbec3e Mon Sep 17 00:00:00 2001 From: taixiurong Date: Wed, 31 Mar 2021 10:57:46 +0800 Subject: [PATCH] fix some bug in transformer training in xpu (#31918) --- cmake/external/xpu.cmake | 2 +- paddle/fluid/memory/memcpy.cc | 6 +- paddle/fluid/operators/cast_op_xpu.cc | 40 +++- paddle/fluid/operators/matmul_op_xpu.cc | 77 +++++-- paddle/fluid/operators/matmul_v2_op_xpu.cc | 62 ++++-- .../fluid/operators/optimizers/adam_op_xpu.cc | 22 +- paddle/fluid/operators/reshape_op.cc | 28 +-- .../softmax_with_cross_entropy_op_xpu.cc | 18 +- .../fluid/tests/unittests/test_matmul_op.py | 36 +++ .../tests/unittests/xpu/test_cast_op_xpu.py | 8 +- .../tests/unittests/xpu/test_matmul_op_xpu.py | 58 +++-- .../unittests/xpu/test_matmul_v2_op_xpu.py | 205 +++++++++--------- 12 files changed, 354 insertions(+), 208 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index b5a3f015474..16c69a7b503 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT) elseif(WITH_SUNWAY) SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE) else() - SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_27.tar.gz" CACHE STRING "" FORCE) + SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_03_30.tar.gz" CACHE STRING "" FORCE) endif() SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 7f871fab5a1..6f252e1bd0d 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -40,7 +40,7 @@ void Copy(platform::XPUPlace dst_place, platform::CPUPlace src_place, const void* src, size_t num) { if (num <= 0) { - VLOG(0) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")"; + VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")"; return; } int dev_id = -1; @@ -86,7 +86,7 @@ void Copy(platform::CPUPlace dst_place, platform::XPUPlace src_place, const void* src, size_t num) { if (num <= 0) { - VLOG(0) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")"; + VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")"; return; } int dev_id = -1; @@ -132,7 +132,7 @@ void Copy(platform::XPUPlace dst_place, platform::XPUPlace src_place, const void* src, size_t num) { if (num <= 0) { - VLOG(0) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")"; + VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")"; return; } int dev_id = -1; diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc index bbd43274a00..ca15858cf67 100644 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -23,8 +23,22 @@ limitations under the License. */ namespace paddle { namespace operators { +template +class XPUFPTypeTrait { + public: + using Type = T; +}; + +template <> +class XPUFPTypeTrait { + public: + using Type = float16; +}; + template class CastXPUKernel : public framework::OpKernel { + using XPUInTDType = typename XPUFPTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); @@ -34,27 +48,39 @@ class CastXPUKernel : public framework::OpKernel { auto out_type = static_cast( context.Attr("out_dtype")); auto* in_data = in->data(); + + // using XPUOutTDType = typename XPUFPTypeTrait::Type; auto numel = in->numel(); auto& dev_ctx = context.template device_context(); int r = -1; if (out_type == framework::proto::VarType::FP32) { auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2(dev_ctx.x_context(), in_data, out_data, - numel); + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out_data, numel); } else if (out_type == framework::proto::VarType::INT32) { auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2(dev_ctx.x_context(), in_data, out_data, - numel); + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out_data, numel); } else if (out_type == framework::proto::VarType::INT64) { auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2(dev_ctx.x_context(), in_data, out_data, - numel); + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out_data, numel); } else if ((out_type == framework::proto::VarType::BOOL) && (in_type == framework::proto::VarType::FP32)) { auto* out_data = out->mutable_data(context.GetPlace()); r = xpu::cast_v2( dev_ctx.x_context(), (const float*)in_data, reinterpret_cast(out_data), numel); + } else if (out_type == framework::proto::VarType::FP16) { + auto* out_data = + out->mutable_data(context.GetPlace()); + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + reinterpret_cast(out_data), numel); + } else { PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d", in_type, out_type)); @@ -75,5 +101,7 @@ namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( cast, ops::CastXPUKernel, ops::CastXPUKernel, + ops::CastXPUKernel, ops::CastXPUKernel); #endif diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index f92cff2f6cd..6fa96aca4be 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -23,7 +23,6 @@ limitations under the License. */ namespace paddle { namespace operators { - using framework::Tensor; static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) { @@ -123,34 +122,47 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; } } - PADDLE_ENFORCE_EQ( - mat_dim_a.width_, mat_dim_b.height_, - platform::errors::InvalidArgument("Shape mistake in matmul_op, the " - "first tensor width must be same as " - "second tensor height, but received " - "width:%d, height:%d", - mat_dim_a.width_, mat_dim_b.height_)); + + if (mat_dim_a.width_ == mat_dim_b.height_) { + if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) { + mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; + } + if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) { + mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0; + } + } + + PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_, + platform::errors::InvalidArgument( + "Shape mistake in matmul_op, the " + "first tensor width must be same as " + "second tensor height, but received " + "width:%d, height:%d x_dims = %s , y_dims = %s", + mat_dim_a.width_, mat_dim_b.height_, + x_dims.to_str().c_str(), y_dims.to_str().c_str())); PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_, platform::errors::InvalidArgument( "Shape mistake in matmul_op, the two input" "tensor batch_size must be same, but received first " "tensor batch_size:%d, second " - "tensor batch_size:%d", - mat_dim_a.batch_size_, mat_dim_b.batch_size_)); + "tensor batch_size:%d, x_dims = %s , y_dims = %s", + mat_dim_a.batch_size_, mat_dim_b.batch_size_, + x_dims.to_str().c_str(), y_dims.to_str().c_str())); - T alpha = static_cast(ctx.Attr("alpha")); + float alpha = static_cast(ctx.Attr("alpha")); - float *data_c = out->data(); + T *data_c = out->data(); int m = mat_dim_a.height_; int n = mat_dim_b.width_; int k = mat_dim_a.width_; + int batch_size = mat_dim_a.batch_size_; + int ldx = mat_dim_a.trans_ ? m : k; int ldy = mat_dim_b.trans_ ? k : n; int ldout = n; - int batch_size = mat_dim_a.batch_size_; - - if (batch_size == 0) { - int r = xpu::fc_fusion( + if (batch_size <= 1) { + int r = 0; + r = xpu::fc_fusion( dev_ctx.x_context(), x->data(), y->data(), data_c, m, n, k, mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR); @@ -159,14 +171,32 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, "XPU fc_fusion kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } else { - int r = xpu::fc_batched( - dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m, - n, k, alpha, x->data(), mat_dim_a.stride_, y->data(), - mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr); + // batch matmul + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( - "XPU fc_batched kernel return wrong value[%d %s]", r, - XPUAPIErrorMsg[r])); + "XPU fc_batched kernel return wrong value[%d %s] " + "x_dims = %s , y_dims = %s", + r, XPUAPIErrorMsg[r], x_dims.to_str().c_str(), + y_dims.to_str().c_str())); } } @@ -206,9 +236,8 @@ static framework::Tensor XPUFoldHeadAndLastDims( static_cast(in_dims[1]), static_cast(in_dims[2])}; std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host.data(), axis_host.data(), /*ndims=*/3); + in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index dbb1d7bfb0a..d992ef847db 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -57,32 +57,55 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_, platform::errors::InvalidArgument( - "Shape mistake in matmul_v2_op xdims = %s ydims = %s", - x_dims.to_str(), y_dims.to_str())); + "Shape mistake in matmul_v2_op xdims = %s ydims = %s " + "x_trans = %d y_trans = %d", + x_dims.to_str(), y_dims.to_str(), mat_dim_a.trans_, + mat_dim_b.trans_)); PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_, platform::errors::InvalidArgument( - "Shape mistake in matmul_v2_op xdims = %s ydims = %s", - x_dims.to_str(), y_dims.to_str())); + "Shape mistake in matmul_v2_op xdims = %s ydims = %s " + "x_trans = %d y_trans = %d", + x_dims.to_str(), y_dims.to_str(), mat_dim_a.trans_, + mat_dim_b.trans_)); - float* data_c = out->data(); + T* data_c = out->data(); int m = mat_dim_a.height_; int n = mat_dim_b.width_; int k = mat_dim_a.width_; int batch_size = mat_dim_a.batch_size_; - - if (batch_size == 0) { - int r = xpu::fc( - dev_ctx.x_context(), x->data(), y->data(), data_c, m, n, k, - mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr); - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, - platform::errors::External( - "XPU fc_fusion kernel return wrong value[%d %s]", r, - XPUAPIErrorMsg[r])); + if (batch_size <= 1) { + int r = 0; + r = xpu::fc(dev_ctx.x_context(), x->data(), y->data(), + data_c, m, n, k, mat_dim_a.trans_, + mat_dim_b.trans_, nullptr, nullptr, nullptr); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU fc_fusion kernel return wrong value[%d %s] , m = %d, n = " + "%d, " + "k = %d, a_tr = %d, b_tr = %d", + r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_)); } else { - int r = xpu::fc_batched( - dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m, - n, k, 1.0, x->data(), mat_dim_a.stride_, y->data(), - mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr); + // batch matmul + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + 1.0, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU fc_batched kernel return wrong value[%d %s]", r, @@ -125,7 +148,7 @@ static framework::Tensor XPUFoldHeadAndLastDims( std::vector axis_host = {1, 0, 2}; int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host.data(), axis_host.data(), /*ndims=*/3); + in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, @@ -189,6 +212,7 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { auto* dx = context.Output(framework::GradVarName("X")); auto* dy = context.Output(framework::GradVarName("Y")); ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; if (dx) { dx_dims = dx->dims(); diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc index 1740f2982b6..3baba424e8f 100644 --- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc @@ -121,19 +121,25 @@ class AdamOpXPUKernel : public framework::OpKernel { } else { T cpu_beta1_pow_out_data; T cpu_beta2_pow_out_data; - xpu_memcpy(&cpu_beta1_pow_out_data, beta1_pow_ptr, sizeof(T), - XPU_DEVICE_TO_HOST); + memory::Copy(platform::CPUPlace(), &cpu_beta1_pow_out_data, + BOOST_GET_CONST(platform::XPUPlace, beta1_pow.place()), + beta1_pow_ptr, sizeof(T)); + cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1; - xpu_memcpy(&cpu_beta2_pow_out_data, beta2_pow_ptr, sizeof(T), - XPU_DEVICE_TO_HOST); + memory::Copy(platform::CPUPlace(), &cpu_beta2_pow_out_data, + BOOST_GET_CONST(platform::XPUPlace, beta2_pow.place()), + beta2_pow_ptr, sizeof(T)); + cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2; T* beta1_pow_out_p = beta1_pow_out->mutable_data(ctx.GetPlace()); T* beta2_pow_out_p = beta2_pow_out->mutable_data(ctx.GetPlace()); - xpu_memcpy(beta1_pow_out_p, &cpu_beta1_pow_out_data, sizeof(T), - XPU_HOST_TO_DEVICE); - xpu_memcpy(beta2_pow_out_p, &cpu_beta2_pow_out_data, sizeof(T), - XPU_HOST_TO_DEVICE); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), + beta1_pow_out_p, platform::CPUPlace(), + &cpu_beta1_pow_out_data, sizeof(T)); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), + beta2_pow_out_p, platform::CPUPlace(), + &cpu_beta2_pow_out_data, sizeof(T)); } PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 94efa70e467..e119a21caa2 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -377,31 +377,9 @@ class ReshapeKernel { out->Resize(out_dims); out->mutable_data(ctx.GetPlace(), in->type()); - -#ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - void *out_ptr = out->data(); - const void *in_ptr = in->data(); - if ((out_ptr != nullptr) && (in_ptr != nullptr) && - (paddle::framework::SizeOfType(in->type()) > 0)) { - auto &dev_ctx = - ctx.template device_context(); - int r = xpu::memcpy_device( - dev_ctx.x_context(), out_ptr, in_ptr, - in->numel() * paddle::framework::SizeOfType(in->type())); - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, - platform::errors::External( - "XPU memcpy_device return wrong value[%d %s]", r, - XPUAPIErrorMsg[r])); - } - } else { -#endif - framework::TensorCopy( - *in, ctx.GetPlace(), - ctx.template device_context(), out); -#ifdef PADDLE_WITH_XPU - } -#endif + framework::TensorCopy( + *in, ctx.GetPlace(), + ctx.template device_context(), out); out->Resize(out_dims); } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 346ed965d06..8635def2ecf 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -45,11 +45,25 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { const int n = SizeToAxis(axis, logits->dims()); const int d = SizeFromAxis(axis, logits->dims()); std::vector logits_dims = framework::vectorize(logits->dims()); + // softmax auto& dev_ctx = context.template device_context(); - int r = xpu::softmax(dev_ctx.x_context(), logits->data(), - softmax->data(), logits_dims, axis); + int r = XPU_SUCCESS; + Tensor clip_logits; + int len = logits->numel(); + T* clip_logits_data = + clip_logits.mutable_data(context.GetPlace(), len * sizeof(T)); + r = xpu::clip(dev_ctx.x_context(), logits->data(), clip_logits_data, + len, -1e30, 1e30); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU kernel error. clip " + "execution not succeed, error code=%d", + r)); + + r = xpu::softmax(dev_ctx.x_context(), clip_logits_data, + softmax->data(), logits_dims, axis); PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py index 2d5f098a7fe..b936567d5b5 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py @@ -206,6 +206,42 @@ for dim_X in (1, 2, 3): api_test(dim_X, dim_Y, transose_x, transose_y) +# Test case more batch_size and N, M, K +def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y, + batch_size): + BATCH_SIZE = 2 + M = 3 + N = 4 + K = 5 + if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y): + K = 1 + if dim_X == 1: + if transpose_X: + shape_X = [M] + else: + shape_X = [K] + if dim_Y == 1: + if transpose_Y: + shape_Y = [N] + else: + shape_Y = [K] + if dim_X >= 2: + if transpose_X: + shape_X = [K, M] + else: + shape_X = [M, K] + if dim_X == 3: + shape_X = [BATCH_SIZE] + shape_X + if dim_Y >= 2: + if transpose_Y: + shape_Y = [N, K] + else: + shape_Y = [K, N] + if dim_Y == 3: + shape_Y = [BATCH_SIZE] + shape_Y + return shape_X, shape_Y + + # Test case n-dim def generate_compatible_shapes(dim, transpose_X, transpose_Y): M = 2 diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index cb64cb90e8c..f1ba8828f2b 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -51,10 +51,10 @@ class TestCastOp2(op_test.OpTest): def setUp(self): ipt = np.random.random(size=[10, 10]) self.inputs = {'X': ipt.astype('float32')} - self.outputs = {'Out': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float16')} self.attrs = { 'in_dtype': int(core.VarDesc.VarType.FP32), - 'out_dtype': int(core.VarDesc.VarType.FP32) + 'out_dtype': int(core.VarDesc.VarType.FP16) } self.op_type = 'cast' @@ -68,10 +68,10 @@ class TestCastOp2(op_test.OpTest): class TestCastOp3(op_test.OpTest): def setUp(self): ipt = np.random.random(size=[10, 10]) - self.inputs = {'X': ipt.astype('float32')} + self.inputs = {'X': ipt.astype('float16')} self.outputs = {'Out': ipt.astype('float32')} self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP32), + 'in_dtype': int(core.VarDesc.VarType.FP16), 'out_dtype': int(core.VarDesc.VarType.FP32) } self.op_type = 'cast' diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py index fa0feb02f43..54dc46cd0ec 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op_xpu.py @@ -27,8 +27,12 @@ from paddle.fluid import Program, program_guard paddle.enable_static() -def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y): +def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y, + batch_size): BATCH_SIZE = 2 + if batch_size != None: + BATCH_SIZE = batch_size + M = 3 N = 4 K = 5 @@ -58,6 +62,13 @@ def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y): shape_Y = [K, N] if dim_Y == 3: shape_Y = [BATCH_SIZE] + shape_Y + + if dim_Y == 3 and dim_X == 2: + if transpose_X == False: + shape_X[1] = shape_X[1] * BATCH_SIZE + else: + shape_X[0] = shape_X[0] * BATCH_SIZE + return shape_X, shape_Y @@ -77,11 +88,19 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): if transpose_Y: if Y.ndim == 1: Y = Y.reshape((1, Y.size)) + elif Y.ndim == 2: + Y = Y.T else: dim = [i for i in range(len(Y.shape))] dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] Y = np.transpose(Y, tuple(dim)) + if X.ndim == 3 and Y.ndim == 2: + x_dims = X.shape + X = X.reshape((x_dims[0] * x_dims[1], x_dims[2])) + if Y.ndim == 3 and X.ndim == 2: + y_dims = Y.shape + Y = Y.reshape((y_dims[0] * y_dims[1], y_dims[2])) Out = np.matmul(X, Y) if not Out.shape: # We do not support 0-dimensional Tensors (scalars). So where @@ -203,11 +222,11 @@ def test_negative_dims_program(obj): # Generate program api cases for all negative possibilities -def api_test(dim_x, dim_y, trans_x, trans_y): +def api_test(dim_x, dim_y, trans_x, trans_y, batch_size): test_name = ('TestMatMulAPI_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( dim_x, dim_y, trans_x, trans_y)) shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x, - trans_y) + trans_y, batch_size) globals()[test_name] = type(test_name, (unittest.TestCase, ), { 'shape_X': shape_x, 'shape_Y': shape_y, @@ -218,29 +237,35 @@ def api_test(dim_x, dim_y, trans_x, trans_y): # Generate operators cases for all possibilities -def inject_test(dim_x, dim_y, trans_x, trans_y): - test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( - dim_x, dim_y, trans_x, trans_y)) +def inject_test(dim_x, dim_y, trans_x, trans_y, batch_size): + test_name = ( + 'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}_batch_{}'.format( + dim_x, dim_y, trans_x, trans_y, batch)) shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x, - trans_y) + trans_y, batch_size) globals()[test_name] = type(test_name, (Generator, XPUOpTest), { 'shape_X': shape_x, 'shape_Y': shape_y, 'transpose_X': trans_x, 'transpose_Y': trans_y, + 'op_type': "matmul" }) -for dim_X in (1, 2, 3): - for dim_Y in (1, 2, 3): - transose_x = False - transose_y = False - if dim_X == 3 and dim_Y == 3: - inject_test(dim_X, dim_Y, transose_x, transose_y) - api_test(dim_X, dim_Y, transose_x, transose_y) +xpu_support_dims_list = [[1, 1], [2, 2], [3, 3]] +batch_size = [2, 4, 5, 10, 50, 100, 300] +for dims in xpu_support_dims_list: + dim_X = dims[0] + dim_Y = dims[1] + for transose_x in (False, True): + for transose_y in (False, True): + for batch in batch_size: + inject_test(dim_X, dim_Y, transose_x, transose_y, batch) + # xpu not support all negative possibilities + # api_test(dim_X, dim_Y, False, False, 10) -# Test case n-dim + # Test case n-dim def generate_compatible_shapes(dim, transpose_X, transpose_Y): M = 2 N = 4 @@ -261,7 +286,7 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y): return shape_X, shape_Y -# # Test case n-dim +# Test case n-dim for dim in [4]: for transpose_X in [False, True]: for transpose_Y in [False, True]: @@ -275,6 +300,7 @@ for dim in [4]: 'shape_Y': shape_Y, 'transpose_X': transpose_X, 'transpose_Y': transpose_Y, + 'op_type': "matmul" }) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py index 531e9488d60..435026220c2 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py @@ -45,7 +45,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): dim = [i for i in range(len(Y.shape))] dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] Y = np.transpose(Y, tuple(dim)) - Out = np.matmul(X, Y) if not Out.shape: # We do not support 0-dimensional Tensors (scalars). So where @@ -98,16 +97,16 @@ class TestMatMulV2Op(XPUOpTest): self.check_grad_with_place(place, ['X', 'Y'], 'Out') -# class TestMatMuklOp2(TestMatMulV2Op): -# """ -# case 2 -# """ +class TestMatMuklOp2(TestMatMulV2Op): + """ + case 2 + """ -# def config(self): -# self.x_shape = (100, ) -# self.y_shape = (1, 3, 2, 100) -# self.trans_x = False -# self.trans_y = True + def config(self): + self.x_shape = (100) + self.y_shape = (100, 3) + self.trans_x = False + self.trans_y = False class TestMatMuklOp3(TestMatMulV2Op): @@ -122,16 +121,16 @@ class TestMatMuklOp3(TestMatMulV2Op): self.trans_y = False -# class TestMatMuklOp4(TestMatMulV2Op): -# """ -# case 4 -# """ +class TestMatMuklOp4(TestMatMulV2Op): + """ + case 4 + """ -# def config(self): -# self.x_shape = (100, ) -# self.y_shape = (1, 2, 100, 2) -# self.trans_x = False -# self.trans_y = False + def config(self): + self.x_shape = (1, 1, 100, 1) + self.y_shape = (1, 100) + self.trans_x = False + self.trans_y = False class TestMatMuklOp5(TestMatMulV2Op): @@ -146,27 +145,28 @@ class TestMatMuklOp5(TestMatMulV2Op): self.trans_y = False -# class TestMatMuklOp6(TestMatMulV2Op): -# """ -# case 6 -# """ +class TestMatMuklOp6(TestMatMulV2Op): + """ + case 6 + """ -# def config(self): -# self.x_shape = (1, 2, 102, 1) -# self.y_shape = (102, ) -# self.trans_x = True -# self.trans_y = False + def config(self): + self.x_shape = (1, 2, 102, 10) + self.y_shape = (2, 10, 111) + self.trans_x = False + self.trans_y = False -# class TestMatMuklOp7(TestMatMulV2Op): -# """ -# case 7 -# """ -# def config(self): -# self.x_shape = (1, 2, 1, 100) -# self.y_shape = (100, ) -# self.trans_x = False -# self.trans_y = False +class TestMatMuklOp7(TestMatMulV2Op): + """ + case 7 + """ + + def config(self): + self.x_shape = (1, 2, 100, 1) + self.y_shape = (2, 100, 12) + self.trans_x = True + self.trans_y = False class TestMatMuklOp8(TestMatMulV2Op): @@ -181,49 +181,52 @@ class TestMatMuklOp8(TestMatMulV2Op): self.trans_y = False -# class TestMatMuklOp9(TestMatMulV2Op): -# """ -# case 9 -# """ +class TestMatMuklOp9(TestMatMulV2Op): + """ + case 9 + """ -# def config(self): -# self.x_shape = (1, 1, 1, 100) -# self.y_shape = (2, 1, 2, 100) -# self.trans_x = False -# self.trans_y = True + def config(self): + self.x_shape = (100, 20, 100) + self.y_shape = (100, 100, 100) + self.trans_x = False + self.trans_y = True -# class TestMatMuklOp10(TestMatMulV2Op): -# """ -# case 10 -# """ -# def config(self): -# self.x_shape = (1, 1, 25, 4) -# self.y_shape = (1, 2, 4, 25) -# self.trans_x = False -# self.trans_y = False +class TestMatMuklOp10(TestMatMulV2Op): + """ + case 10 + """ -# class TestMatMuklOp11(TestMatMulV2Op): -# """ -# case 11 -# """ + def config(self): + self.x_shape = (100, 20, 100) + self.y_shape = (100, 20, 100) + self.trans_x = True + self.trans_y = False -# def config(self): -# self.x_shape = (2, 1, 2, 100) -# self.y_shape = (1, 1, 100, 2) -# self.trans_x = False -# self.trans_y = False -# class TestMatMuklOp12(TestMatMulV2Op): -# """ -# case 12 -# """ +class TestMatMuklOp11(TestMatMulV2Op): + """ + case 11 + """ -# def config(self): -# self.x_shape = (2, 1, 4, 25) -# self.y_shape = (1, 1, 4, 25) -# self.trans_x = True -# self.trans_y = False + def config(self): + self.x_shape = (2, 20, 100) + self.y_shape = (100, 30) + self.trans_x = False + self.trans_y = False + + +class TestMatMuklOp12(TestMatMulV2Op): + """ + case 12 + """ + + def config(self): + self.x_shape = (1, 20, 100) + self.y_shape = (100, ) + self.trans_x = False + self.trans_y = False class TestMatMuklOp13(TestMatMulV2Op): @@ -238,38 +241,40 @@ class TestMatMuklOp13(TestMatMulV2Op): self.trans_y = False -# class TestMatMuklOp14(TestMatMulV2Op): -# """ -# case 14_1 -# """ +class TestMatMuklOp14(TestMatMulV2Op): + """ + case 14_1 + """ -# def config(self): -# self.x_shape = (3, 1, 6, 6) -# self.y_shape = (1, 2, 6, 9) -# self.trans_x = True -# self.trans_y = False + def config(self): + self.x_shape = (100, 2, 100, 10) + self.y_shape = (100, 2, 10, 90) + self.trans_x = False + self.trans_y = False -# class TestMatMuklOp15(TestMatMulV2Op): -# """ -# case 14_2 -# """ -# def config(self): -# self.x_shape = (3, 1, 6, 6) -# self.y_shape = (1, 2, 6, 9) -# self.trans_x = False -# self.trans_y = False +class TestMatMuklOp15(TestMatMulV2Op): + """ + case 14_2 + """ -# class TestMatMuklOp16(TestMatMulV2Op): -# """ -# case 16 : to check the gradient for special case -# """ + def config(self): + self.x_shape = (100, 2, 100, 10) + self.y_shape = (100, 2, 100, 10) + self.trans_x = False + self.trans_y = True -# def config(self): -# self.x_shape = (100) -# self.y_shape = (1, 2, 2, 100, 2) -# self.trans_x = False -# self.trans_y = False + +class TestMatMuklOp16(TestMatMulV2Op): + """ + case 16 : to check the big data + """ + + def config(self): + self.x_shape = (1000, 2, 100, 100) + self.y_shape = (1000, 2, 100, 900) + self.trans_x = False + self.trans_y = False class TestMatMuklOp17(TestMatMulV2Op): -- GitLab