未验证 提交 c7acad9f 编写于 作者: T taixiurong 提交者: GitHub

support some shape for matmul and cast in xpu place (#29900)

* support some shape in matmul and cast

* modify matmul
上级 80eb7778
......@@ -49,6 +49,12 @@ class CastXPUKernel : public framework::OpKernel<InT> {
auto* out_data = out->mutable_data<int64_t>(context.GetPlace());
r = xpu::cast_v2<InT, int64_t>(dev_ctx.x_context(), in_data, out_data,
numel);
} else if ((out_type == framework::proto::VarType::BOOL) &&
(in_type == framework::proto::VarType::FP32)) {
auto* out_data = out->mutable_data<bool>(context.GetPlace());
r = xpu::cast_v2<float, int8_t>(
dev_ctx.x_context(), (const float*)in_data,
reinterpret_cast<int8_t*>(out_data), numel);
} else {
PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d",
in_type, out_type));
......
......@@ -111,6 +111,20 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
auto mat_dim_b =
math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0,
context.Attr<bool>("transpose_Y"));
const auto &x_dims = x->dims();
const auto &y_dims = y->dims();
if (x_dims.size() == 3 && y_dims.size() <= 2) {
// if transpose_X is true, the transpose cost much time
if (!context.Attr<bool>("transpose_X")) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
PADDLE_ENFORCE_EQ(
mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument("Shape mistake in matmul_op"));
......@@ -224,12 +238,26 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace());
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
const auto &a_dims = a.dims();
const auto &b_dims = b.dims();
if (a_dims.size() == 3 && b_dims.size() <= 2) {
// if transpose_X is true, the transpose cost much time
if (!context.Attr<bool>("transpose_X")) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
} else {
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
}
}
PADDLE_ENFORCE_EQ(
mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
PADDLE_ENFORCE_EQ(
mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto &dev_ctx = context.template device_context<DeviceContext>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册