Unverified commit 31f0221f, authored by Jacek Czaja, committed by GitHub


[oneDNN] disable caching oneDNN primitives in matmul v2, Reduce grad and elementwise_add grad, expand_v2 (#35132)

* grad caching disabled for matmul_v1
  - compilation fix
  - compilation fix

* reduction removed

* Matmul v2 caching disabled

* Draft of further changes

* workaround for reduce grad

* fixes to UT

* fix to compilation

* another fix

* fix
Parent 8dc050d8
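The pattern this commit removes, repeated across every hunk below: handlers derived from `MKLDNNHandlerT` cached oneDNN primitives in the `MKLDNNDeviceContext` blob map under a key built by `platform::CreateKey` from the input dims plus a per-op unique name, while the new `MKLDNNHandlerNoCachingT` base builds primitives on demand and leans on oneDNN's own internal primitive cache. A schematic, self-contained sketch of the two patterns (stand-in types and names, not Paddle's actual classes):

```cpp
// Schematic illustration only: key-based framework caching vs. per-call
// construction. "Primitive" stands in for a dnnl primitive.
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

struct Primitive {
  explicit Primitive(const std::vector<int64_t>& dims) : dims_(dims) {}
  std::vector<int64_t> dims_;
};

// Old style: cache keyed by dims + a unique name (e.g. ctx.InputName("X")).
std::shared_ptr<Primitive> GetCached(
    std::map<std::string, std::shared_ptr<Primitive>>* blobs,
    const std::vector<int64_t>& dims, const std::string& uniq_name) {
  std::ostringstream key;
  for (auto d : dims) key << d << "-";
  key << uniq_name;
  auto& slot = (*blobs)[key.str()];
  if (!slot) slot = std::make_shared<Primitive>(dims);  // miss: create once
  return slot;                                          // hit: reuse
}

// New style: no framework-side cache; construct on every call.
std::shared_ptr<Primitive> GetNoCaching(const std::vector<int64_t>& dims) {
  return std::make_shared<Primitive>(dims);
}

int main() {
  std::map<std::string, std::shared_ptr<Primitive>> blobs;
  auto a = GetCached(&blobs, {2, 3}, "matmul_x");
  auto b = GetCached(&blobs, {2, 3}, "matmul_x");
  std::cout << "cached: same object? " << (a == b) << "\n";     // 1
  auto c = GetNoCaching({2, 3});
  auto d = GetNoCaching({2, 3});
  std::cout << "no-cache: same object? " << (c == d) << "\n";   // 0
}
```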
@@ -84,10 +84,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     } else {
       // Broadcasting
       platform::ReductionMKLDNNHandler<T> handler_sum(
-          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-          ctx.GetPlace(), dout, dy,
-          ctx.InputName(framework::GradVarName("Out")),
-          CalculateBroadcastedDims(dout, dy));
+          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+          ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
       auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
       auto reduction_p = handler_sum.AcquireForwardPrimitive();
       reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
......
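`CalculateBroadcastedDims` itself is unchanged by this patch. For readers of the hunk above, here is a hedged, standalone sketch of what such a helper plausibly returns (illustrative, not the Paddle implementation): the reduction dst must keep dout's rank, with extent 1 on every broadcast axis, so those axes get summed.

```cpp
// Hedged sketch of a helper like CalculateBroadcastedDims(dout, dy).
// Names and the exact matching rule are assumptions for illustration.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> BroadcastedDims(const std::vector<int64_t>& dout_tz,
                                     const std::vector<int64_t>& dy_tz) {
  std::vector<int64_t> dst_tz(dout_tz.size(), 1);
  size_t j = 0;
  for (size_t i = 0; i < dout_tz.size(); ++i) {
    // keep dy's extent where it matches dout; otherwise leave extent 1 so
    // the reduction primitive sums over that axis
    if (j < dy_tz.size() && dout_tz[i] == dy_tz[j]) dst_tz[i] = dy_tz[j++];
  }
  return dst_tz;
}

int main() {
  // dout: [2, 16, 8], dy: [16] -> reduce axes 0 and 2: dst = [1, 16, 1]
  for (auto d : BroadcastedDims({2, 16, 8}, {16})) std::cout << d << " ";
  std::cout << "\n";
}
```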
@@ -101,10 +101,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
       // Reduction is needed for broadcasting scenario
       if (dout->dims() != dy->dims()) {
         platform::ReductionMKLDNNHandler<T> handler_sum(
-            dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
-            ctx.GetPlace(), dout, dy,
-            ctx.InputName(framework::GradVarName("Out")),
-            CalculateBroadcastedDims(dout, dy));
+            dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
+            ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
         auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
         auto reduction_p = handler_sum.AcquireForwardPrimitive();
         // As source we use mem object with results from binary operation
......
@@ -53,8 +53,8 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
     out->Resize(paddle::framework::make_ddim(out_new_dims));
     out->set_format(x_format_tag);
     paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
-        dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
-        out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
+        dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
+        0.0f, 1.0f, x_vec_dims);
     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto dst_memory_p = handler.AcquireDstMemory(out);
@@ -136,8 +136,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
           paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
     } else {
       paddle::platform::ReductionMKLDNNHandler<T> handler(
-          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-          ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims);
+          dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+          ctx.GetPlace(), dout, dx, dx_vec_dims);
       auto src_memory_p = handler.AcquireSrcMemory(dout);
       auto dst_memory_p = handler.AcquireDstMemory(dx);
......
@@ -83,58 +83,52 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
 template <typename T>
 class MatMulMKLDNNHandler
-    : public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
+    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
  public:
-  MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                      const mkldnn::engine engine,
+  MatMulMKLDNNHandler(const mkldnn::engine engine,
                       paddle::platform::Place cpu_place, Tensor* x,
                       bool trans_x, Tensor* y, bool trans_y, Tensor* out,
-                      float scale, const std::string& uniq_name)
-      : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
-            dev_ctx, engine, cpu_place,
-            paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
-                                        uniq_name)) {
-    if (!this->isCached()) {
-      auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
-          x->dims(), 0, trans_x);
-      auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
-          y->dims(), 0, trans_y);
-
-      memory::dim x_bs = mat_dim_x.batch_size_;
-      memory::dim y_bs = mat_dim_y.batch_size_;
-
-      memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
-      const memory::dim M = mat_dim_x.height_;
-      const memory::dim N = mat_dim_y.width_;
-      const memory::dim K = mat_dim_x.width_;
-
-      memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
-      memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
-      memory::dims out_dims = {out_bs, M, N};
-
-      memory::dims x_strides =
-          !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};
-
-      memory::dims y_strides =
-          !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
-      memory::dims out_strides = memory::dims{M * N, N, 1};
-
-      auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
-      auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
-      auto out_md =
-          memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
-
-      dnnl::primitive_attr attrs;
-      if (scale != 1.0f) attrs.set_output_scales(0, {scale});
-
-      this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
-    }
+                      float scale)
+      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
+                                                                   cpu_place) {
+    auto mat_dim_x =
+        paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
+    auto mat_dim_y =
+        paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);
+
+    memory::dim x_bs = mat_dim_x.batch_size_;
+    memory::dim y_bs = mat_dim_y.batch_size_;
+    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
+    const memory::dim M = mat_dim_x.height_;
+    const memory::dim N = mat_dim_y.width_;
+    const memory::dim K = mat_dim_x.width_;
+
+    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
+    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
+    memory::dims out_dims = {out_bs, M, N};
+
+    memory::dims x_strides =
+        !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};
+    memory::dims y_strides =
+        !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
+    memory::dims out_strides = memory::dims{M * N, N, 1};
+
+    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
+    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
+    auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
+
+    dnnl::primitive_attr attrs;
+    if (scale != 1.0f) attrs.set_output_scales(0, {scale});
+
+    this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }

   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@weights_mem_p");
+                                            to_void_cast<T>(input_data));
   }
 };
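The stride selection in this handler is what lets `trans_x`/`trans_y` avoid a physical transpose: a transposed operand is handed to oneDNN as the same buffer with the two inner strides swapped. A small standalone check of that indexing, with illustrative values:

```cpp
// With dims {1, M, K} and strides {M*K, 1, M}, the same buffer is read as
// the transpose of its row-major K x M contents; no copy is made.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t at(const std::vector<int64_t>& buf, const std::vector<int64_t>& strides,
           int64_t b, int64_t m, int64_t k) {
  return buf[b * strides[0] + m * strides[1] + k * strides[2]];
}

int main() {
  const int64_t M = 2, K = 3;
  // buffer holds a row-major K x M matrix: [[0,1],[2,3],[4,5]]
  std::vector<int64_t> buf = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> trans_strides = {M * K, 1, M};  // logical M x K view
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t k = 0; k < K; ++k)
      std::cout << at(buf, trans_strides, 0, m, k) << " ";
    std::cout << "\n";  // prints 0 2 4 / 1 3 5: the transpose
  }
}
```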
@@ -565,7 +559,7 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
     const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
     const mkldnn::engine& engine, Tensor* x, bool trans_x,
     bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
-    Tensor* out, int execution_number) const {
+    Tensor* out) const {
   // gradient is calculated in a different way when broadcasting is used
   bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
                       out->dims().size() == 2;
@@ -583,10 +577,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
   float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

-  MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
-                                 trans_x, &y_combined, trans_y, out, alpha,
-                                 ctx.InputName(framework::GradVarName("Out")) +
-                                     std::to_string(execution_number));
+  MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
+                                 &y_combined, trans_y, out, alpha);

   const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
   const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
@@ -645,24 +637,24 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {
   if (transpose_x && transpose_y) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
-                            true, false, dx, 0);
+                            true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
-                            true, false, dy, 1);
+                            true, false, dy);
   } else if (transpose_x) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
-                            &dout, true, false, dx, 0);
+                            &dout, true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
-                            &dout, false, true, dy, 1);
+                            &dout, false, true, dy);
   } else if (transpose_y) {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
-                            &y, false, true, dx, 0);
+                            &y, false, true, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
-                            false, true, dy, 1);
+                            false, true, dy);
   } else {
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
-                            &y, true, false, dx, 0);
+                            &y, true, false, dx);
     this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
-                            false, true, dy, 1);
   }

   if (dx) {
......
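The dropped `execution_number` argument only served to build distinct cache keys; the four branches themselves are unchanged. For reference (our annotation, not part of the patch), with out = op(X) op(Y) they compute the standard matrix-product gradients:

```math
\begin{aligned}
&\text{no transpose:} && dX = dOut\,Y^{T}, \quad dY = X^{T}\,dOut \\
&\text{trans\_x:}     && dX = Y\,dOut^{T}, \quad dY = X\,dOut \\
&\text{trans\_y:}     && dX = dOut\,Y,     \quad dY = dOut^{T}\,X \\
&\text{both:}         && dX = Y^{T}\,dOut^{T}, \quad dY = dOut^{T}\,X^{T}
\end{aligned}
```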
@@ -34,8 +34,7 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
                          const MKLDNNDeviceContext& dev_ctx,
                          const mkldnn::engine& engine, Tensor* x, bool trans_x,
                          bool is_fold_init_dims_x, Tensor* y, bool trans_y,
-                         bool is_fold_init_dims_y, Tensor* out,
-                         int execution_number) const;
+                         bool is_fold_init_dims_y, Tensor* out) const;
   void RunKernel(const ExecutionContext& ctx) const;
 };
 }  // namespace operators
......
@@ -31,79 +31,72 @@ using paddle::framework::GradVarName;
 template <typename T>
 class MatMulV2MKLDNNHandler
-    : public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
+    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
  public:
-  MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                        const mkldnn::engine engine,
+  MatMulV2MKLDNNHandler(const mkldnn::engine engine,
                         paddle::platform::Place cpu_place,
                         const std::vector<int64_t>& x_org_dims, bool trans_x,
-                        const std::vector<int64_t>& y_org_dims, bool trans_y,
-                        const std::string& uniq_name)
-      : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
-            dev_ctx, engine, cpu_place,
-            paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
-    if (!this->isCached()) {
+                        const std::vector<int64_t>& y_org_dims, bool trans_y)
+      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
+                                                                   cpu_place) {
     // M X K * K X N
     std::vector<int64_t> x_dims(x_org_dims);
     std::vector<int64_t> y_dims(y_org_dims);

     const int MB_idx = x_dims.size() - 3;
     const int H_idx = x_dims.size() - 2;
     const int W_idx = x_dims.size() - 1;

     if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
     if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

     const memory::dim M = x_dims[H_idx];
     const memory::dim K = x_dims[W_idx];
     const memory::dim N = y_dims[W_idx];

     std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
     std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
     std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
     std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

     x_strides.reserve(x_dims.size());
     y_strides.reserve(x_dims.size());
     out_strides.reserve(x_dims.size());

     if (!trans_x) {
       x_strides.insert(x_strides.end(), {M * K, K, 1});
     } else {
       x_strides.insert(x_strides.end(), {M * K, 1, M});
     }

     if (!trans_y) {
       y_strides.insert(y_strides.end(), {N * K, N, 1});
     } else {
       y_strides.insert(y_strides.end(), {N * K, 1, K});
     }

     out_strides.insert(out_strides.end(), {M * N, N, 1});
     out_ddims.insert(out_ddims.end(),
                      {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

     for (int i = x_dims.size() - 4; i >= 0; --i) {
       out_ddims[i] = std::max(x_dims[i], y_dims[i]);
       x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
       y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
       out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
     }

     auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
     auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
-    auto out_md =
-        memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
+    auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);

     this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
-    }
   }

   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@weights_mem_p");
+                                            to_void_cast<T>(input_data));
   }
 };
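The batch-dims loop in this handler builds strides right-to-left and takes per-axis maxima for the output, which is exactly how matmul_v2's batch broadcasting (e.g. {2,1,M,K} x {1,5,K,N} -> {2,5,M,N}) falls out. The same loop, extracted into a runnable scrap with illustrative shapes:

```cpp
// Standalone re-run of the batched-dims computation above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // x: {2, 1, 4, 3} (M=4, K=3), y: {1, 5, 3, 6} (N=6)
  std::vector<int64_t> x_dims = {2, 1, 4, 3};
  std::vector<int64_t> y_dims = {1, 5, 3, 6};
  const int64_t M = 4, K = 3, N = 6;
  const int MB_idx = x_dims.size() - 3;

  std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
  std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
  std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
  std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

  // innermost 3 dims behave like a plain M x K * K x N matmul
  x_strides.insert(x_strides.end(), {M * K, K, 1});
  y_strides.insert(y_strides.end(), {N * K, N, 1});
  out_strides.insert(out_strides.end(), {M * N, N, 1});
  out_ddims.insert(out_ddims.end(),
                   {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

  // outer batch dims: broadcast extents, accumulate strides right-to-left
  for (int i = x_dims.size() - 4; i >= 0; --i) {
    out_ddims[i] = std::max(x_dims[i], y_dims[i]);
    x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
    y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
    out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
  }

  for (auto d : out_ddims) std::cout << d << " ";  // prints: 2 5 4 6
  std::cout << "\n";
}
```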
@@ -122,9 +115,8 @@ class MatMulV2MKLDNNKernel
                      const Tensor* y, std::vector<int64_t>& y_dims,
                      bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
                      int execution_number = 0) const {
-    MatMulV2MKLDNNHandler<T> handler(
-        dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
-        trans_y, ctx.InputName("X") + std::to_string(execution_number));
+    MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
+                                     trans_x, y_dims, trans_y);

     const auto src_memory_p = handler.AcquireSrcMemory(x);
     const auto weights_memory_p = handler.AcquireWeightsMemory(y);
@@ -251,8 +243,8 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
                   const Tensor* dx_tmp, Tensor* dx,
                   std::vector<int64_t> dx_dims) const {
     paddle::platform::ReductionMKLDNNHandler<T> handler(
-        dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
-        ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
+        dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
+        ctx.GetPlace(), dx_tmp, dx, dx_dims);
     auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
     auto dst_memory_p = handler.AcquireDstMemory(dx);
......
@@ -96,9 +96,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
           platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
               paddle::framework::vectorize<int64_t>(output->dims()))));
     } else {
-      platform::ReductionMKLDNNHandler<T> handler(
-          reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(),
-          input, output, ctx.InputName("X"), output_dims);
+      platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f,
+                                                  onednn_engine, ctx.GetPlace(),
+                                                  input, output, output_dims);
       auto src_memory_p = handler.AcquireSrcMemory(input);
       auto dst_memory_p = handler.AcquireDstMemory(output);
@@ -137,40 +137,28 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
     mkldnn::memory::format_tag x_format_tag;
     auto input_dims =
         CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
+    auto output_dims = framework::vectorize(output_dx->dims());

-    if (input_dims != framework::vectorize(output_dx->dims())) {
-      const std::string key_pd =
-          platform::CreateKey(
-              dev_ctx, framework::vectorize(output_dx->dims()),
-              ctx.InputName("X"),
-              (std::to_string(static_cast<int>(reduction_type)))) +
-          "@fwd_pd";
-      std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
-          std::static_pointer_cast<dnnl::reduction::primitive_desc>(
-              dev_ctx.GetBlob(key_pd));
-
-      PADDLE_ENFORCE_NOT_NULL(
-          fwd_pd, platform::errors::Unavailable(
-                      "Forward primitive descriptor is not available in %s op, "
-                      "cannot deduce memory format tag",
-                      ctx.Type()));
-
-      x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
-
-      PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
-                        platform::errors::InvalidArgument(
-                            "Cannot deduce format tag for %s op", ctx.Type()));
-    } else {  // fwd descriptor not available because reorder was used instead
-              // of reduction
+    if (input_dims != output_dims) {
+      auto input_dy_md = dnnl::memory::desc(
+          framework::vectorize(input_dy->dims()),
+          platform::MKLDNNGetDataType<T>(), input_dy->format());
+      auto input_dy_ex_md = input_dy_md.reshape(input_dims);
+      // TODO(jczaja): once MD is stored in Tensor we no longer need to guess
+      // formats
+      x_format_tag = platform::GetMKLDNNFormat(input_dy_ex_md);
+    } else {
+      // There was no broadcasting then just simple copy is done
+      // same format used for input and output
       x_format_tag = getPlainFormatTag(output_dx);
     }

     output_dx->set_format(x_format_tag);

     platform::BroadcastDataMKLDNNHandler<T> handler(
-        binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
-        input_dy, scale_x, scale_y,
-        ctx.InputName(framework::GradVarName("Out")), input_dims);
+        binary_type, onednn_engine, ctx.GetPlace(), output_dx, input_dy,
+        scale_x, scale_y, input_dims);

     const auto src_memory_p = handler.AcquireSrcMemory(input_dy);
     const auto dst_memory_p = handler.AcquireDstMemory(output_dx);
@@ -184,6 +172,8 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     binary_prim->execute(astream, args);
     astream.wait();
+
+    output_dx->set_layout(framework::DataLayout::kMKLDNN);
   }

  protected:
......
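Instead of fishing the cached forward primitive descriptor out of the blob map to learn dy's memory format, the new code rebuilds a memory descriptor from dy's known dims and format and `reshape()`s it to the broadcast rank, then reads the plain format tag off the result. A minimal sketch of that descriptor manipulation, assuming oneDNN 2.x headers and illustrative shapes:

```cpp
// Hedged sketch: reshape() on dnnl::memory::desc inserts size-1 axes
// without touching the underlying data layout.
#include <iostream>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  // dy known as a plain row-major 2-D tensor
  memory::desc dy_md({16, 8}, memory::data_type::f32, memory::format_tag::ab);
  // extend to dx's rank with size-1 axes on the reduced dims
  memory::desc dy_ex_md = dy_md.reshape({1, 16, 1, 8});
  std::cout << "ndims after reshape: " << dy_ex_md.data.ndims << "\n";  // 4
}
```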
@@ -895,41 +895,34 @@ class BinaryMKLDNNHandler
 template <typename T>
 class BroadcastDataMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::binary> {
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
  public:
   BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
-                             const MKLDNNDeviceContext& dev_ctx,
                              const mkldnn::engine engine,
                              platform::Place cpu_place, const Tensor* out,
                              const Tensor* x, float scale_x, float scale_y,
-                             const std::string& uniq_name,
                              const std::vector<int64_t>& input_dims)
-      : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                uniq_name)) {
-    if (!this->isCached()) {
-      PADDLE_ENFORCE_EQ(
-          x->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument("Wrong layout set for X tensor."));
-      PADDLE_ENFORCE_NE(
-          x->format(), MKLDNNMemoryFormat::undef,
-          platform::errors::InvalidArgument("Wrong format set for X tensor."));
-
-      const auto src0_tz = framework::vectorize(out->dims());
-
-      const auto src0_md = dnnl::memory::desc(
-          src0_tz, platform::MKLDNNGetDataType<T>(), out->format());
-      const auto src1_md = dnnl::memory::desc(
-          input_dims, platform::MKLDNNGetDataType<T>(), out->format());
-
-      dnnl::primitive_attr attributes;
-      attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
-      attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
-
-      this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
-                                              src1_md, src0_md);
-    }
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument("Wrong layout set for X tensor."));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor."));
+
+    const auto src0_tz = framework::vectorize(out->dims());
+
+    const auto src0_md = dnnl::memory::desc(
+        src0_tz, platform::MKLDNNGetDataType<T>(), out->format());
+    const auto src1_md = dnnl::memory::desc(
+        input_dims, platform::MKLDNNGetDataType<T>(), out->format());
+
+    dnnl::primitive_attr attributes;
+    attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
+    attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
+
+    this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
+                                            src0_md);
   }

 template <typename T>
@@ -938,43 +931,35 @@ class BroadcastDataMKLDNNHandler
         this->place_, this->fwd_pd_->dst_desc().get_size());
     ;
     memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr,
-                                            "@dst_mem_p");
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
   }
 };

 template <typename T>
 class ReductionMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::reduction> {
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
  public:
   ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
-                         const float eps, const MKLDNNDeviceContext& dev_ctx,
-                         const mkldnn::engine engine, platform::Place cpu_place,
-                         const Tensor* x, const Tensor* y,
-                         const std::string& uniq_name,
-                         std::vector<int64_t> y_tz)
-      : platform::MKLDNNHandlerT<T, dnnl::reduction>(
-            dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                uniq_name,
-                                (std::to_string(static_cast<int>(algo))))) {
-    if (!this->isCached()) {
-      PADDLE_ENFORCE_EQ(
-          x->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument("Wrong layout set for X tensor."));
-      PADDLE_ENFORCE_NE(
-          x->format(), MKLDNNMemoryFormat::undef,
-          platform::errors::InvalidArgument("Wrong format set for X tensor."));
-
-      const auto x_tz = framework::vectorize(x->dims());
-
-      const auto x_md = dnnl::memory::desc(
-          x_tz, platform::MKLDNNGetDataType<T>(), x->format());
-      const auto y_md =
-          memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
-
-      this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
-    }
+                         const float eps, const mkldnn::engine engine,
+                         platform::Place cpu_place, const Tensor* x,
+                         const Tensor* y, std::vector<int64_t> y_tz)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
+                                                              cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument("Wrong layout set for X tensor."));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor."));
+
+    const auto x_tz = framework::vectorize(x->dims());
+
+    const auto x_md =
+        dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
+    const auto y_md =
+        memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
+
+    this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
   }
 };
......
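For context, what `ReductionMKLDNNHandler` ultimately drives is a oneDNN reduction primitive whose dst has extent 1 on every reduced axis. A hedged, self-contained sketch against the oneDNN 2.x API (the version Paddle used at the time, to the best of our knowledge); shapes and values are illustrative:

```cpp
// reduction_sum over axis 0 of a 2x3 tensor; p and eps are 0.0f for
// reduction_sum, as in the handler calls above. Built fresh on each call,
// mirroring the no-caching handler.
#include <iostream>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  memory::desc src_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc dst_md({1, 3}, memory::data_type::f32, memory::format_tag::ab);

  auto pd = reduction::primitive_desc(
      reduction::desc(algorithm::reduction_sum, src_md, dst_md, 0.0f, 0.0f),
      eng);

  std::vector<float> src_data = {1, 2, 3, 4, 5, 6}, dst_data(3);
  memory src_m(src_md, eng, src_data.data());
  memory dst_m(dst_md, eng, dst_data.data());

  reduction(pd).execute(s, {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_DST, dst_m}});
  s.wait();
  for (float v : dst_data) std::cout << v << " ";  // prints: 5 7 9
}
```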