未验证 提交 31f0221f 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] disable caching oneDNN primitives in matmul v2, Reduce grad and...

[oneDNN] disable caching oneDNN primitives in  matmul v2, Reduce grad and elementwise_add grad, expand_v2 (#35132)

* - grad caching disabled of matmul_v1

- compilation fix

- compilation fix

* - reduction removed

* - Matmul v2 disabled caching

* Draft of further changes

* - workaround for reducegrad

* - fixes to UT

* - fix to compilation

* - another fix

* - fix
上级 8dc050d8
......@@ -84,10 +84,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
......
......@@ -101,10 +101,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
......
......@@ -53,8 +53,8 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
out->Resize(paddle::framework::make_ddim(out_new_dims));
out->set_format(x_format_tag);
paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
0.0f, 1.0f, x_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
......@@ -136,8 +136,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
} else {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dx, dx_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(dout);
auto dst_memory_p = handler.AcquireDstMemory(dx);
......
......@@ -83,22 +83,18 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
template <typename T>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulMKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place, Tensor* x,
bool trans_x, Tensor* y, bool trans_y, Tensor* out,
float scale, const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
x->dims(), 0, trans_x);
auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
y->dims(), 0, trans_y);
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);
memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;
......@@ -128,13 +124,11 @@ class MatMulMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};
......@@ -565,7 +559,7 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
Tensor* out, int execution_number) const {
Tensor* out) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
......@@ -583,10 +577,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out, alpha,
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
&y_combined, trans_y, out, alpha);
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
......@@ -645,24 +637,24 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
true, false, dx, 0);
true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
true, false, dy, 1);
true, false, dy);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
&dout, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
&dout, false, true, dy);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
&y, false, true, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
false, true, dy, 1);
false, true, dy);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
&y, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
false, true, dy, 1);
false, true, dy);
}
if (dx) {
......
......@@ -34,8 +34,7 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
bool is_fold_init_dims_y, Tensor* out) const;
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
......
......@@ -31,18 +31,14 @@ using paddle::framework::GradVarName;
template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulV2MKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) {
const std::vector<int64_t>& y_org_dims, bool trans_y)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
......@@ -92,18 +88,15 @@ class MatMulV2MKLDNNHandler
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md =
memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};
......@@ -122,9 +115,8 @@ class MatMulV2MKLDNNKernel
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......@@ -251,8 +243,8 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
......
......@@ -96,9 +96,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
paddle::framework::vectorize<int64_t>(output->dims()))));
} else {
platform::ReductionMKLDNNHandler<T> handler(
reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(),
input, output, ctx.InputName("X"), output_dims);
platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f,
onednn_engine, ctx.GetPlace(),
input, output, output_dims);
auto src_memory_p = handler.AcquireSrcMemory(input);
auto dst_memory_p = handler.AcquireDstMemory(output);
......@@ -137,40 +137,28 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
mkldnn::memory::format_tag x_format_tag;
auto input_dims =
CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
auto output_dims = framework::vectorize(output_dx->dims());
if (input_dims != framework::vectorize(output_dx->dims())) {
const std::string key_pd =
platform::CreateKey(
dev_ctx, framework::vectorize(output_dx->dims()),
ctx.InputName("X"),
(std::to_string(static_cast<int>(reduction_type)))) +
"@fwd_pd";
std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
std::static_pointer_cast<dnnl::reduction::primitive_desc>(
dev_ctx.GetBlob(key_pd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd, platform::errors::Unavailable(
"Forward primitive descriptor is not available in %s op, "
"cannot deduce memory format tag",
ctx.Type()));
x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
platform::errors::InvalidArgument(
"Cannot deduce format tag for %s op", ctx.Type()));
} else { // fwd descriptor not available because reorder was used instead
// of reduction
if (input_dims != output_dims) {
auto input_dy_md = dnnl::memory::desc(
framework::vectorize(input_dy->dims()),
platform::MKLDNNGetDataType<T>(), input_dy->format());
auto input_dy_ex_md = input_dy_md.reshape(input_dims);
// TODO(jczaja): once MD is stored in Tensor we no longer need to guess
// formats
x_format_tag = platform::GetMKLDNNFormat(input_dy_ex_md);
} else {
// There was no broadcasting then just simple copy is done
// same format used for input and output
x_format_tag = getPlainFormatTag(output_dx);
}
output_dx->set_format(x_format_tag);
platform::BroadcastDataMKLDNNHandler<T> handler(
binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
input_dy, scale_x, scale_y,
ctx.InputName(framework::GradVarName("Out")), input_dims);
binary_type, onednn_engine, ctx.GetPlace(), output_dx, input_dy,
scale_x, scale_y, input_dims);
const auto src_memory_p = handler.AcquireSrcMemory(input_dy);
const auto dst_memory_p = handler.AcquireDstMemory(output_dx);
......@@ -184,6 +172,8 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
binary_prim->execute(astream, args);
astream.wait();
output_dx->set_layout(framework::DataLayout::kMKLDNN);
}
protected:
......
......@@ -895,20 +895,14 @@ class BinaryMKLDNNHandler
template <typename T>
class BroadcastDataMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::binary> {
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
platform::Place cpu_place, const Tensor* out,
const Tensor* x, float scale_x, float scale_y,
const std::string& uniq_name,
const std::vector<int64_t>& input_dims)
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
......@@ -927,9 +921,8 @@ class BroadcastDataMKLDNNHandler
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
src1_md, src0_md);
}
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
src0_md);
}
template <typename T_out = T>
......@@ -938,27 +931,20 @@ class BroadcastDataMKLDNNHandler
this->place_, this->fwd_pd_->dst_desc().get_size());
;
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr,
"@dst_mem_p");
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
template <typename T>
class ReductionMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::reduction> {
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
public:
ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
const float eps, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y,
const std::string& uniq_name,
std::vector<int64_t> y_tz)
: platform::MKLDNNHandlerT<T, dnnl::reduction>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name,
(std::to_string(static_cast<int>(algo))))) {
if (!this->isCached()) {
const float eps, const mkldnn::engine engine,
platform::Place cpu_place, const Tensor* x,
const Tensor* y, std::vector<int64_t> y_tz)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
......@@ -968,14 +954,13 @@ class ReductionMKLDNNHandler
const auto x_tz = framework::vectorize(x->dims());
const auto x_md = dnnl::memory::desc(
x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto x_md =
dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto y_md =
memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
}
}
};
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册