Unverified commit fab92824, authored by Sylwester Fraczek, committed by GitHub

refactoring matmul_v2 mkldnn hierarchy (#37622)

* refactoring matmul hierarchy

* review fix

* review fix

* review_FIX-part2
Parent 5747fd1e
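
The heart of the change is visible in the hunks below: MatMulV2MKLDNNKernel no longer inherits from paddle::operators::MatMulGradMKLDNNKernel, shared execution logic moves out of the class into the free function ExecuteMatMulV2, and MatMulV2GradMKLDNNKernel reaches the v1 grad path through a member object (matmul_v1_grad_mkldnn_kernel) rather than a base class. A minimal toy sketch of that composition-over-inheritance pattern (illustrative names only, not Paddle code):

#include <iostream>

struct ExecutionContext {};  // stand-in for paddle's ExecutionContext

// Shared logic extracted into a free function instead of a base-class method.
void ExecuteMatMul(const ExecutionContext&) { std::cout << "matmul\n"; }

struct MatMulV1GradKernel {
  void Compute(const ExecutionContext& ctx) const { ExecuteMatMul(ctx); }
};

struct MatMulV2GradKernel {
  void Compute(const ExecutionContext& ctx, bool is_broadcast) const {
    if (!is_broadcast) {
      v1_grad_.Compute(ctx);  // delegate to the member, not to a base class
      return;
    }
    ExecuteMatMul(ctx);  // broadcast path calls the shared free function
  }

 private:
  MatMulV1GradKernel v1_grad_;  // composition replaces inheritance
};

int main() {
  MatMulV2GradKernel kernel;
  kernel.Compute(ExecutionContext{}, false);  // v1 fast path
  kernel.Compute(ExecutionContext{}, true);   // broadcast path
}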
@@ -25,9 +25,9 @@ using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::to_void_cast;
 using Tensor = paddle::framework::Tensor;
-using paddle::framework::vectorize;
-using paddle::framework::make_ddim;
 using paddle::framework::GradVarName;
+using paddle::framework::make_ddim;
+using paddle::framework::vectorize;
 
 template <typename T>
 class MatMulV2MKLDNNHandler
@@ -123,45 +123,58 @@ class MatMulV2MKLDNNHandler
   }
 };
 
-template <typename T>
-class MatMulV2MKLDNNKernel
-    : public paddle::operators::MatMulGradMKLDNNKernel<T> {
- public:
-  void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
+bool IsOutputFused(const ExecutionContext& ctx) {
+  auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
+  auto& fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
+  return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
+}
+
+float ComputeOutputScale(const ExecutionContext& ctx) {
+  float scale_x = ctx.Attr<float>("Scale_x");
+  float scale_y = ctx.Attr<float>("Scale_y");
+  bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
+  float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
+  return scale_out / (scale_x * scale_y);
+}
 
- protected:
-  void ExecuteMatMul(const ExecutionContext& ctx,
+template <typename T>
+void ExecuteMatMulV2(const ExecutionContext& ctx,
                      const MKLDNNDeviceContext& dev_ctx,
                      const dnnl::engine onednn_engine,
                      paddle::platform::Place cpu_place, const Tensor* x,
                      std::vector<int64_t>& x_dims, bool trans_x,
                      const Tensor* y, std::vector<int64_t>& y_dims,
                      bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
-                     int execution_number = 0) const {
+                     int execution_number = 0) {
   MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
                                    trans_x, y_dims, trans_y,
                                    IsOutputFused(ctx));
 
   const auto src_memory_p = handler.AcquireSrcMemory(x);
   const auto weights_memory_p = handler.AcquireWeightsMemory(y);
   const auto dst_memory_p = handler.AcquireDstMemory(out);
 
   auto matmul_p = handler.AcquireForwardPrimitive();
 
   std::unordered_map<int, memory> matmul_args = {
       {DNNL_ARG_SRC, *src_memory_p},
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
 
   auto& astream = MKLDNNDeviceContext::tls().get_stream();
   matmul_p->execute(astream, matmul_args);
   astream.wait();
 
   auto format = paddle::platform::MKLDNNFormatForSize(
       out->dims().size(), dnnl::memory::format_tag::nchw);
 
   out->set_layout(paddle::framework::DataLayout::kMKLDNN);
   out->set_format(format);
 }
+
+template <typename T>
+class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
 
  private:
   void CalculateMatrixDims(const ExecutionContext& ctx,
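
Among the extracted helpers, ComputeOutputScale folds the three quantization attributes into the single multiplier scale_out / (scale_x * scale_y), with force_fp32_output pinning the numerator to 1.f. A standalone restatement of that arithmetic with toy values (values chosen here for illustration, not taken from the PR):

#include <cassert>

// Mirrors the new helper's formula with plain floats instead of ctx.Attr<>().
float ComputeOutputScale(float scale_x, float scale_y, float scale_out,
                         bool force_fp32_out) {
  return (force_fp32_out ? 1.f : scale_out) / (scale_x * scale_y);
}

int main() {
  assert(ComputeOutputScale(0.5f, 2.0f, 8.0f, false) == 8.0f);  // 8 / (0.5 * 2)
  assert(ComputeOutputScale(0.5f, 2.0f, 8.0f, true) == 1.0f);   // forced fp32: 1 / (0.5 * 2)
}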
@@ -207,13 +220,6 @@ class MatMulV2MKLDNNKernel
     }
   }
 
-  bool IsOutputFused(const ExecutionContext& ctx) const {
-    auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
-    auto& fused_transpose_Out =
-        ctx.Attr<std::vector<int>>("fused_transpose_Out");
-    return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
-  }
-
   void RunKernel(const ExecutionContext& ctx) const {
     const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     const auto& onednn_engine = dev_ctx.GetEngine();
@@ -237,13 +243,14 @@ class MatMulV2MKLDNNKernel
     CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
                         out);
 
-    ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
-                  trans_x, y, y_bd_dims, trans_y, out, out_dims);
+    ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                       x_bd_dims, trans_x, y, y_bd_dims, trans_y, out,
+                       out_dims);
   }
 };
 
 template <typename T>
-class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
+class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
@@ -316,7 +323,7 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
     // if no broadcasting is needed, we can simply use matmul's grad and avoid
     // using reduce_sum
     if (!is_broadcast) {
-      paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
+      matmul_v1_grad_mkldnn_kernel.Compute(ctx);
       return;
     }
@@ -342,33 +349,29 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
                         dy_bd_dims);
 
     if (trans_x && trans_y) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
-                          y_dims, true, dout, dout_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
-                          2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
+                         true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
+                         2);
     } else if (trans_x) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
-                          y_dims, false, dout, dout_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                          x_dims, false, dout, dout_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
+                         false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
+                         false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
     } else if (trans_y) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, false, y, y_dims, false, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, true, x, x_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, false, y, y_dims, false, &dx_tmp,
+                         dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims,
+                         2);
     } else {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, false, y, y_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                          x_dims, true, dout, dout_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims,
+                         1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
+                         true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
     }
 
     if (x_dims != dx_bd_dims) {
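
The four branches above dispatch the standard matmul gradient identities; writing Z for the forward output and dZ for dout, they are (my summary for orientation, not text from the PR):

\begin{aligned}
Z = X\,Y:               \quad & dX = dZ\,Y^{\top},        & dY &= X^{\top}\,dZ \\
Z = X^{\top}\,Y:        \quad & dX = Y\,dZ^{\top},        & dY &= X\,dZ \\
Z = X\,Y^{\top}:        \quad & dX = dZ\,Y,               & dY &= dZ^{\top}\,X \\
Z = X^{\top}\,Y^{\top}: \quad & dX = Y^{\top}\,dZ^{\top}, & dY &= dZ^{\top}\,X^{\top}
\end{aligned}

When batch dims were broadcast, dx_tmp/dy_tmp are then reduced back to the original shapes, which is what the x_dims != dx_bd_dims check in the next hunk guards.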
@@ -389,8 +392,12 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
       dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
       dy->set_format(y->format());
     }
   }
+
+ private:
+  paddle::operators::MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
 };
 }  // anonymous namespace
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
...