未验证 提交 91635de3 编写于 作者: C cucuzg 提交者: GitHub

opt matmul and matmul_v2 on kunlun, *test=kunlun (#31326)

* add clip_by_norm on kunlun, *test=kunlun

* opt matmul and matmul_v2 on kunlun, *test=kunlun
上级 e2023409
......@@ -159,23 +159,14 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
"XPU fc_fusion kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
} else {
// batch matmul
int x_stride = mat_dim_a.stride_;
int y_stride = mat_dim_b.stride_;
int out_stride = m * n;
for (int i = 0; i < batch_size; ++i) {
const float *x_data = x->data<T>() + x_stride * i;
const float *y_data = y->data<T>() + y_stride * i;
float *out_data = data_c + out_stride * i;
int r = xpu::fc_fusion<float, float, float, FCT>(
dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU fc_fusion kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
int r = xpu::fc_batched<float, float, float, FCT>(
dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m,
n, k, alpha, x->data<T>(), mat_dim_a.stride_, y->data<T>(),
mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
}
......
......@@ -79,22 +79,14 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
"XPU fc_fusion kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
} else {
// batch matmul
int x_stride = mat_dim_a.stride_;
int y_stride = mat_dim_b.stride_;
int out_stride = m * n;
for (int i = 0; i < batch_size; ++i) {
const float* x_data = x->data<T>() + x_stride * i;
const float* y_data = y->data<T>() + y_stride * i;
float* out_data = data_c + out_stride * i;
int r = xpu::fc<float, float, float, FCT>(
dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU fc_fusion kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
int r = xpu::fc_batched<float, float, float, FCT>(
dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m,
n, k, 1.0, x->data<T>(), mat_dim_a.stride_, y->data<T>(),
mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU fc_batched kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册