未验证 提交 fab92824 编写于 作者: S Sylwester Fraczek 提交者: GitHub

refactoring matmul_v2 mkldnn hierarchy (#37622)

* refactoring matmul hierarchy

* review fix

* review fix

* review_FIX-part2
上级 5747fd1e
......@@ -25,9 +25,9 @@ using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;
using paddle::framework::vectorize;
using paddle::framework::make_ddim;
using paddle::framework::GradVarName;
using paddle::framework::make_ddim;
using paddle::framework::vectorize;
template <typename T>
class MatMulV2MKLDNNHandler
......@@ -123,21 +123,29 @@ class MatMulV2MKLDNNHandler
}
};
template <typename T>
class MatMulV2MKLDNNKernel
: public paddle::operators::MatMulGradMKLDNNKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
bool IsOutputFused(const ExecutionContext& ctx) {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
protected:
void ExecuteMatMul(const ExecutionContext& ctx,
float ComputeOutputScale(const ExecutionContext& ctx) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
return scale_out / (scale_x * scale_y);
}
template <typename T>
void ExecuteMatMulV2(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine,
paddle::platform::Place cpu_place, const Tensor* x,
std::vector<int64_t>& x_dims, bool trans_x,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
int execution_number = 0) {
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y,
IsOutputFused(ctx));
......@@ -161,7 +169,12 @@ class MatMulV2MKLDNNKernel
out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(format);
}
}
template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
private:
void CalculateMatrixDims(const ExecutionContext& ctx,
......@@ -207,13 +220,6 @@ class MatMulV2MKLDNNKernel
}
}
bool IsOutputFused(const ExecutionContext& ctx) const {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out =
ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
......@@ -237,13 +243,14 @@ class MatMulV2MKLDNNKernel
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out);
ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
trans_x, y, y_bd_dims, trans_y, out, out_dims);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_bd_dims, trans_x, y, y_bd_dims, trans_y, out,
out_dims);
}
};
template <typename T>
class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
......@@ -316,7 +323,7 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
matmul_v1_grad_mkldnn_kernel.Compute(ctx);
return;
}
......@@ -342,33 +349,29 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
dy_bd_dims);
if (trans_x && trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, true, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
2);
} else if (trans_x) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, false, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, false, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
} else if (trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp,
dy_bd_dims, 2);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims,
2);
} else {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, true, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
}
if (x_dims != dx_bd_dims) {
......@@ -389,8 +392,12 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
}
private:
paddle::operators::MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
};
} // anonymous namespace
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册