From d84b8e83500c1bda8ec82abbccdda87a7f97371f Mon Sep 17 00:00:00 2001 From: taixiurong Date: Mon, 28 Dec 2020 19:24:34 +0800 Subject: [PATCH] support some shape for matmul and cast in xpu place (#29900) (#29907) * support some shape in matmul and cast * modify matmul --- paddle/fluid/operators/cast_op_xpu.cc | 6 ++++++ paddle/fluid/operators/matmul_op_xpu.cc | 28 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc index a2791cb262..bbd43274a0 100644 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -49,6 +49,12 @@ class CastXPUKernel : public framework::OpKernel { auto* out_data = out->mutable_data(context.GetPlace()); r = xpu::cast_v2(dev_ctx.x_context(), in_data, out_data, numel); + } else if ((out_type == framework::proto::VarType::BOOL) && + (in_type == framework::proto::VarType::FP32)) { + auto* out_data = out->mutable_data(context.GetPlace()); + r = xpu::cast_v2( + dev_ctx.x_context(), (const float*)in_data, + reinterpret_cast(out_data), numel); } else { PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d", in_type, out_type)); diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 4dc458460e..103ac9add1 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -111,6 +111,20 @@ class MatMulXPUKernel : public framework::OpKernel { auto mat_dim_b = math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0, context.Attr("transpose_Y")); + + const auto &x_dims = x->dims(); + const auto &y_dims = y->dims(); + if (x_dims.size() == 3 && y_dims.size() <= 2) { + // if transpose_X is true, the transpose cost much time + if (!context.Attr("transpose_X")) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } else { + mat_dim_b.batch_size_ = mat_dim_a.batch_size_; + mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; + } + } + PADDLE_ENFORCE_EQ( mat_dim_a.width_, mat_dim_b.height_, platform::errors::InvalidArgument("Shape mistake in matmul_op")); @@ -224,12 +238,26 @@ class MatMulGradXPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + const auto &a_dims = a.dims(); + const auto &b_dims = b.dims(); + if (a_dims.size() == 3 && b_dims.size() <= 2) { + // if transpose_X is true, the transpose cost much time + if (!context.Attr("transpose_X")) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } else { + mat_dim_b.batch_size_ = mat_dim_a.batch_size_; + mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_; + } + } + PADDLE_ENFORCE_EQ( mat_dim_a.width_, mat_dim_b.height_, platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); PADDLE_ENFORCE_EQ( mat_dim_a.batch_size_, mat_dim_b.batch_size_, platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); + T alpha = static_cast(context.Attr("alpha")); auto &dev_ctx = context.template device_context(); -- GitLab