Unverified commit f1c1d9e0, authored by Jacek Czaja, committed by GitHub

[oneDNN] disabling more ops caching (#34830)

* disabled caching of layer norm, transpose, and sum (with assorted
  compilation fixes along the way)

* LRN with disabled cache

* lint fixes
Parent commit: 7b3295a4
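All four diffs below apply the same refactor: each handler stops deriving from the key-cached platform::MKLDNNHandlerT (which stored primitive descriptors and memory objects in the device context's blob map under platform::CreateKey(...) keys) and derives from platform::MKLDNNHandlerNoCachingT instead, so everything is rebuilt on every kernel invocation. A minimal before/after sketch of the pattern follows; FooHandler is hypothetical and the lrn_forward template argument is only a placeholder, but the constructor shapes mirror the real diffs:

// Sketch only: FooHandler is invented; the per-op handlers below follow
// exactly this shape.
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

// Before: descriptors were cached in the device context under a string key,
// so the setup code ran only on a cache miss.
template <typename T>
class FooHandlerCached
    : public platform::MKLDNNHandlerT<T, dnnl::lrn_forward> {
 public:
  FooHandlerCached(const platform::MKLDNNDeviceContext& dev_ctx,
                   platform::Place place, const std::vector<int64_t>& dims,
                   const std::string& uniq_name)
      : platform::MKLDNNHandlerT<T, dnnl::lrn_forward>(
            dev_ctx, dev_ctx.GetEngine(), place,
            platform::CreateKey(dev_ctx, dims, uniq_name)) {
    if (!this->isCached()) {
      // build memory descs, call this->AcquireForwardPrimitiveDescriptor(...)
    }
  }
};

// After: no key and no cache probe; the handler needs only an engine and a
// place, and the constructor body runs unconditionally.
template <typename T>
class FooHandlerNoCaching
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::lrn_forward> {
 public:
  FooHandlerNoCaching(const mkldnn::engine engine, platform::Place place)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::lrn_forward>(engine,
                                                                place) {
    // build memory descs, call this->AcquireForwardPrimitiveDescriptor(...)
  }
};

}  // namespace operators
}  // namespace paddle

With the no-caching base, the isCached()/isBwdCached() guards and the "@..." key-suffix arguments to the Acquire*() helpers disappear; that mechanical change repeats in every file below.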
layer_norm_mkldnn_op.cc

@@ -19,45 +19,36 @@ namespace paddle {
 namespace operators {

 template <typename T>
-class LayerNormMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
+class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
+                                   T, dnnl::layer_normalization_forward> {
  public:
   LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
                          const dnnl::normalization_flags& flags,
                          const bool& is_test, const MKLDNNMemoryFormat fmt,
-                         const platform::MKLDNNDeviceContext& dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, uniq_name)) {
-    if (!this->isCached()) {
-      auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-      if (!is_test) {
-        // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
-        auto stats_md = dnnl::memory::desc(
-            {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
-            platform::MKLDNNFormatForSize(dims.size() - 1,
-                                          MKLDNNMemoryFormat::nchw));
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
-      } else {
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_inference, md, epsilon, flags);
-      }
+                         const mkldnn::engine engine, platform::Place cpu_place)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
+            engine, cpu_place) {
+    auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+    if (!is_test) {
+      // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
+      auto stats_md = dnnl::memory::desc(
+          {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
+          platform::MKLDNNFormatForSize(dims.size() - 1,
+                                        MKLDNNMemoryFormat::nchw));
+      this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
+                                              md, stats_md, epsilon, flags);
+    } else {
+      this->AcquireForwardPrimitiveDescriptor(
+          dnnl::prop_kind::forward_inference, md, epsilon, flags);
     }
   }

-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
-    return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
-  }
-
   std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
       std::vector<float>& scaleshift_data) {
     // scaleshift_data comes from temporary buffer so we need to copy it into
     // created memory primitive
-    auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
+    auto scaleshift_mem =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
     auto data_ptr = scaleshift_mem->get_data_handle();
     std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
     std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);

@@ -68,7 +59,7 @@ class LayerNormMKLDNNHandler
     T* mean_data = mean->mutable_data<T>(this->place_,
                                          this->fwd_pd_->mean_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
-                                            mean_data, "@mean_mem_p");
+                                            mean_data);
   }

   std::shared_ptr<dnnl::memory> AcquireVarianceMemory(

@@ -76,7 +67,7 @@ class LayerNormMKLDNNHandler
     T* variance_data = variance->mutable_data<T>(
         this->place_, this->fwd_pd_->variance_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            variance_data, "@variance_mem_p");
+                                            variance_data);
   }
 };

@@ -95,6 +86,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();

     auto src_tz = paddle::framework::vectorize(x->dims());
     PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),

@@ -112,8 +104,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }

     LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
-                                      x->format(), dev_ctx, ctx.GetPlace(),
-                                      ctx.OutputName("Y"));
+                                      x->format(), mkldnn_engine,
+                                      ctx.GetPlace());

     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(y);

@@ -139,24 +131,22 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }

-    auto scaleshift_memory = handler.AcquireScaleShiftMemory();
+    std::shared_ptr<mkldnn::memory> scaleshift_memory;
     if (with_scaleshift) {
-      if (scaleshift_memory == nullptr || !is_test) {
-        auto scale_tz = paddle::framework::vectorize(scale->dims());
-        const unsigned int C = scale_tz[0];
-
-        // MKLDNN requires a single piece of memory for scale and shift/bias
-        // data
-        std::vector<float> scaleshift_data;
-        scaleshift_data.reserve(2 * C);
-        scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                               scale->data<float>() + C);
-
-        scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                               bias->data<float>() + C);
-
-        scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
-      }
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      const unsigned int C = scale_tz[0];
+
+      // MKLDNN requires a single piece of memory for scale and shift/bias
+      // data
+      std::vector<float> scaleshift_data;
+      scaleshift_data.reserve(2 * C);
+      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
+                             scale->data<float>() + C);
+
+      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
+                             bias->data<float>() + C);
+
+      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }
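One behavioral consequence in the kernel above: the cached zero-argument AcquireScaleShiftMemory() overload is gone, so the combined scale/shift buffer is repacked on every call instead of being reused from the blob cache. A standalone sketch of that packing; C and the constants are invented for the example (oneDNN's layer normalization takes scale followed by shift in one contiguous SCALE_SHIFT buffer):

#include <vector>

int main() {
  const unsigned int C = 4;  // hypothetical channel count
  const std::vector<float> scale(C, 2.0f);
  const std::vector<float> bias(C, 0.5f);

  // One contiguous buffer: [scale_0 .. scale_{C-1}, shift_0 .. shift_{C-1}],
  // mirroring what the kernel now rebuilds on every invocation.
  std::vector<float> scaleshift_data;
  scaleshift_data.reserve(2 * C);
  scaleshift_data.insert(scaleshift_data.end(), scale.begin(), scale.end());
  scaleshift_data.insert(scaleshift_data.end(), bias.begin(), bias.end());

  return scaleshift_data.size() == 2 * C ? 0 : 1;
}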
lrn_mkldnn_op.cc

@@ -21,86 +21,78 @@ using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;

 template <typename T>
-class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
-                                                          mkldnn::lrn_backward> {
+class LRNMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                               mkldnn::lrn_backward> {
  public:
   LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
-                   const MKLDNNDeviceContext& dev_ctx,
                    const mkldnn::engine mkldnn_engine,
-                   platform::Place cpu_place, const Tensor* input,
-                   const std::string& unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
-            dev_ctx, mkldnn_engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
-                                unique_name)) {
-    if (!this->isCached()) {
-      const int n = ctx.Attr<int>("n");
-      // MKL-DNN implements LRN in a caffe way:
-      // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
-      // Where sum of squares is divided by size of normalization window
-      // this is not the case for PaddlePaddle LRN.
-      // Hence we need to compensate for this difference by
-      // multiplying alpha by size of window(n)
-      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-      const float beta = ctx.Attr<float>("beta");
-      const float k = ctx.Attr<float>("k");
-      bool is_test = ctx.Attr<bool>("is_test");
-
-      auto dims = framework::vectorize(input->dims());
-
-      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
-                                         input->format());
-
-      this->AcquireForwardPrimitiveDescriptor(
-          is_test ? mkldnn::prop_kind::forward_inference
-                  : mkldnn::prop_kind::forward_training,
-          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
-    }
+                   platform::Place cpu_place, const Tensor* input)
+      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                          mkldnn::lrn_backward>(mkldnn_engine,
+                                                                cpu_place) {
+    const int n = ctx.Attr<int>("n");
+    // MKL-DNN implements LRN in a caffe way:
+    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+    // Where sum of squares is divided by size of normalization window
+    // this is not the case for PaddlePaddle LRN.
+    // Hence we need to compensate for this difference by
+    // multiplying alpha by size of window(n)
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+    bool is_test = ctx.Attr<bool>("is_test");
+
+    auto dims = framework::vectorize(input->dims());
+
+    auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                       input->format());
+
+    this->AcquireForwardPrimitiveDescriptor(
+        is_test ? mkldnn::prop_kind::forward_inference
+                : mkldnn::prop_kind::forward_training,
+        mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
   }

   LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
-                   const MKLDNNDeviceContext& dev_ctx,
+                   const mkldnn::engine mkldnn_engine,
                    platform::Place cpu_place, const Tensor* in_x,
-                   const Tensor* out_grad, Tensor* in_x_grad,
-                   const std::string& unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
-                                unique_name)) {
-    if (!this->isBwdCached()) {
-      PADDLE_ENFORCE_EQ(
-          ctx.Attr<bool>("is_test"), false,
-          platform::errors::PreconditionNotMet(
-              "is_test attribute should be set to False in training phase."));
-
-      const int n = ctx.Attr<int>("n");
-      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-      const float beta = ctx.Attr<float>("beta");
-      const float k = ctx.Attr<float>("k");
-
-      auto dims = framework::vectorize<int64_t>(in_x->dims());
-
-      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
-                                         in_x->format());
-      auto diff_md = mkldnn::memory::desc(
-          dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
-
-      this->AcquireForwardPrimitiveDescriptor(
-          mkldnn::prop_kind::forward_training,
-          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
-
-      this->AcquireBackwardPrimitiveDescriptor(
-          mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
-          beta, k);
-    }
+                   const Tensor* out_grad, Tensor* in_x_grad)
+      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                          mkldnn::lrn_backward>(mkldnn_engine,
+                                                                cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        ctx.Attr<bool>("is_test"), false,
+        platform::errors::PreconditionNotMet(
+            "is_test attribute should be set to False in training phase."));
+
+    const int n = ctx.Attr<int>("n");
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+
+    auto dims = framework::vectorize<int64_t>(in_x->dims());
+
+    auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                       in_x->format());
+    auto diff_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                        out_grad->format());
+
+    this->AcquireForwardPrimitiveDescriptor(
+        mkldnn::prop_kind::forward_training,
+        mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
+
+    this->AcquireBackwardPrimitiveDescriptor(
+        mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta,
+        k);
   }

   std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(Tensor* workspace) {
     T* ptr = workspace->mutable_data<T>(
         this->place_, this->fwd_pd_->workspace_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
-                                            ptr, "@wrk_mem_p");
+                                            ptr);
   }

   std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(

@@ -108,7 +100,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
     const T* workspace_data = workspace->data<T>();
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->workspace_desc(),
-        platform::to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
+        platform::to_void_cast<T>(workspace_data));
   }
 };

@@ -131,8 +123,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto out = ctx.Output<Tensor>("Out");
     auto mid = ctx.Output<Tensor>("MidOut");

-    LRNMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x,
-                                ctx.OutputName("Out"));
+    LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), x);

     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(out);

@@ -178,9 +169,10 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();

-    LRNMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x, out_grad,
-                                in_x_grad, ctx.InputName("Out"));
+    LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), in_x,
+                                out_grad, in_x_grad);

     auto src_memory = handler.AcquireSrcMemory(in_x);
     auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
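The alpha scaling that both LRN constructors keep deserves a concrete check: oneDNN's across-channel LRN divides the sum of squares by the window size n (the Caffe convention linked in the comment), while PaddlePaddle's definition does not, so the kernel passes alpha * n to oneDNN. A small self-check with invented numbers:

#include <cassert>
#include <cmath>

int main() {
  // oneDNN (Caffe-style): denom = k + (alpha_dnnl / n) * sum_sq
  // PaddlePaddle:         denom = k + alpha_paddle * sum_sq
  // so alpha_dnnl = alpha_paddle * n makes both denominators agree.
  const int n = 5;                   // normalization window size (invented)
  const float alpha_paddle = 1e-4f;  // invented
  const float alpha_dnnl = alpha_paddle * static_cast<float>(n);

  const float k = 2.0f, sum_sq = 10.0f;  // invented values
  const float paddle_denom = k + alpha_paddle * sum_sq;
  const float dnnl_denom = k + (alpha_dnnl / n) * sum_sq;

  assert(std::fabs(paddle_denom - dnnl_denom) < 1e-6f);
  return 0;
}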
sum_mkldnn_op.cc

@@ -45,44 +45,35 @@ using paddle::platform::MKLDNNDeviceContext;
 using platform::to_void_cast;

 template <typename T>
-class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
+class SumMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
  public:
-  SumMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                   platform::Place cpu_place,
+  SumMKLDNNHandler(mkldnn::engine engine, platform::Place cpu_place,
                    const std::vector<framework::Variable*>& in_vars,
-                   framework::LoDTensor* z, const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::sum>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(z->dims()),
-                                uniq_name)),
+                   framework::LoDTensor* z)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
         num_inputs_(0) {
-    for (size_t i = 0; i < in_vars.size(); i++) {
-      srcs_suffix_.push_back(std::string("-") + std::to_string(i));
-    }
-
-    if (!this->isCached()) {
-      auto dst_tz = framework::vectorize<int64_t>(z->dims());
-      auto src_tz = dst_tz;
-
-      std::vector<mkldnn::memory::desc> srcs_md;
-      for (size_t i = 0; i < in_vars.size(); i++) {
-        auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
-        if (input_it.numel() == 0) {
-          continue;
-        }
-        MKLDNNMemoryFormat input_format = input_it.format();
-        srcs_md.push_back(mkldnn::memory::desc(
-            src_tz, platform::MKLDNNGetDataType<T>(), input_format));
-        ++num_inputs_;
+    auto dst_tz = framework::vectorize<int64_t>(z->dims());
+    auto src_tz = dst_tz;
+
+    std::vector<mkldnn::memory::desc> srcs_md;
+    for (size_t i = 0; i < in_vars.size(); i++) {
+      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
+      if (input_it.numel() == 0) {
+        continue;
       }
-      std::vector<float> scales(num_inputs_, 1.0);
+      MKLDNNMemoryFormat input_format = input_it.format();
+      srcs_md.push_back(mkldnn::memory::desc(
+          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
+      ++num_inputs_;
+    }
+    std::vector<float> scales(num_inputs_, 1.0);

-      auto dst_md = mkldnn::memory::desc(
-          dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+    auto dst_md = mkldnn::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
+                                       MKLDNNMemoryFormat::any);

-      this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
-    }
+    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
   }

   // (jczaja) sum oneDNN prim is not having .desc attribute so

@@ -90,37 +81,27 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
   void AcquireForwardPrimitiveDescriptor(
       const mkldnn::memory::desc& dst_md, const std::vector<float>& scales,
       const std::vector<mkldnn::memory::desc>& srcs_md) {
-    // Sum op does not have backward so no passing from FWD to BWD is needed
-    const std::string key_pd = this->key_ + "@fwd_pd";
-    this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
-        this->dev_ctx_.GetBlob(key_pd));
-    if (this->fwd_pd_ == nullptr) {
-      this->fwd_pd_.reset(new dnnl::sum::primitive_desc(dst_md, scales, srcs_md,
-                                                        this->engine_));
-      this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
-    }
+    this->fwd_pd_.reset(
+        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
   }

   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const framework::Tensor& input, int i) {
     const T* input_data = input.data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
-                                            to_void_cast<T>(input_data),
-                                            "@src_mem_p" + srcs_suffix_[i]);
+                                            to_void_cast<T>(input_data));
   }

-  using platform::MKLDNNHandlerT<T, dnnl::sum>::AcquireDstMemory;
+  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;

   std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(),
-                                            "@dst_mem_p");
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
   }

   inline int GetNumInputs(void) { return num_inputs_; }

  private:
   int num_inputs_;
-  std::vector<std::string> srcs_suffix_;
 };

 template <typename T>

@@ -131,6 +112,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         paddle::platform::errors::PreconditionNotMet(
             "Operator DNNL Sum must use CPUPlace"));
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
     auto in_vars = ctx.MultiInputVar("X");

     PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(

@@ -140,8 +122,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);

-    SumMKLDNNHandler<T> handler(dev_ctx, ctx.GetPlace(), in_vars, output,
-                                ctx.OutputName("Out"));
+    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);

     // Create list of SRC MEMs
     std::vector<std::shared_ptr<mkldnn::memory>> srcs_mem;
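As the surviving comment notes, dnnl::sum has no operation descriptor, so SumMKLDNNHandler overrides AcquireForwardPrimitiveDescriptor() and builds the primitive_desc directly; the diff only strips the GetBlob/SetBlob memoization around that construction. A self-contained oneDNN 1.x illustration of the same construction, with made-up shapes:

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::ab);
  std::vector<dnnl::memory::desc> srcs_md{md, md};
  std::vector<float> scales{1.0f, 1.0f};

  // sum has no op desc, hence the direct primitive_desc constructor that
  // the handler's override wraps.
  dnnl::sum::primitive_desc pd(md, scales, srcs_md, eng);

  dnnl::memory src0(pd.src_desc(0), eng), src1(pd.src_desc(1), eng);
  dnnl::memory dst(pd.dst_desc(), eng);

  dnnl::sum(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                               {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                               {DNNL_ARG_DST, dst}});
  strm.wait();
  return 0;
}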
transpose_mkldnn_op.cc

@@ -24,6 +24,70 @@ namespace operators {
 using Tensor = framework::Tensor;
 using framework::DataLayout;

+template <typename T>
+class TransposeMKLDNNHandler {
+ public:
+  TransposeMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
+                         std::vector<int>& axis,      // NOLINT
+                         mkldnn::engine engine)
+      : dims_(dims),
+        axis_(axis),
+        logical_axis_(dims.size(), 0),
+        engine_(engine) {}
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const MKLDNNMemoryFormat& fmt, void* ptr) {
+    // Make memory descriptor using input format, unless it
+    // cannot be trusted (nchw) then make up memory fmt manually
+    for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
+      this->logical_axis_[i] = i;
+    }
+
+    auto src_md = fmt != MKLDNNMemoryFormat::nchw
+                      ? platform::MKLDNNMemDesc(
+                            dims_, platform::MKLDNNGetDataType<T>(), fmt)
+                      : Axis2MemoryDesc(dims_, logical_axis_);
+    return std::make_shared<mkldnn::memory>(src_md, engine_, ptr);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
+                                                   platform::Place place) {
+    auto dst_md = Axis2MemoryDesc(dims_, axis_);
+    auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
+    return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
+  }
+
+  std::shared_ptr<mkldnn::reorder> AcquireTranspose(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    return std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
+  }
+
+ protected:
+  mkldnn::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz,  // NOLINT
+                                       std::vector<int>& axis         // NOLINT
+                                       ) {
+    size_t ndims = axis.size();
+    std::vector<int64_t> strides(ndims);
+    unsigned int total_stride = 1;
+    for (int i = ndims - 1; i >= 0; --i) {
+      strides[axis[i]] = total_stride;
+      total_stride *= nchw_tz[axis[i]];
+    }
+    mkldnn::memory::desc mem_d(nchw_tz, platform::MKLDNNGetDataType<T>(),
+                               strides);
+    return mem_d;
+  }
+
+ private:
+  std::vector<int64_t> dims_;
+  std::vector<int> axis_;
+  std::vector<int> logical_axis_;
+  mkldnn::engine engine_;
+};
+
 template <typename T>
 class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:

@@ -48,11 +112,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims());

-    const std::string key =
-        platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out"));
-
-    platform::TransposeMKLDNNHandler<T> handler(nchw_tz, axis, dev_ctx,
-                                                mkldnn_engine, key);
+    TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine);

     auto transpose_src_memory_p = handler.AcquireSrcMemory(
         input->format(), platform::to_void_cast<T>(input_data));

@@ -103,11 +163,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());

-    const std::string key = platform::CreateKey(
-        dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X")));
-
-    platform::TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, dev_ctx,
-                                                mkldnn_engine, key);
+    TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine);

     auto transpose_src_memory_p = handler.AcquireSrcMemory(
         out_grad->format(), platform::to_void_cast<T>(out_grad_data));
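The moved handler keeps the Axis2MemoryDesc trick: the destination is described over the source dims but with permuted strides, so a single oneDNN reorder performs the whole transpose. A standalone check of the stride computation, with invented dims and axis:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> dims = {2, 3, 4, 5};  // source dims (NCHW)
  std::vector<int> axis = {0, 2, 3, 1};      // NCHW -> NHWC permutation

  // Same loop as Axis2MemoryDesc: the last axis in the permutation gets
  // stride 1, and each earlier axis gets the running product.
  std::vector<int64_t> strides(axis.size());
  int64_t total_stride = 1;
  for (int i = static_cast<int>(axis.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total_stride;
    total_stride *= dims[axis[i]];
  }

  // Prints "60 1 15 3": the channel dim gets stride 1, i.e. a dense NHWC
  // buffer described over NCHW dims, so one reorder does the transpose.
  for (auto s : strides) std::printf("%lld ", static_cast<long long>(s));
  std::printf("\n");
  return 0;
}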
mkldnn_reuse.h

@@ -1072,99 +1072,6 @@ class ActivationMKLDNNHandler
   }
 };

-template <typename T>
-class TransposeMKLDNNHandler : public MKLDNNHandler {
- public:
-  TransposeMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
-                         std::vector<int>& axis,      // NOLINT
-                         const platform::MKLDNNDeviceContext& dev_ctx,
-                         mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        dims_(dims),
-        axis_(axis),
-        logical_axis_(dims.size(), 0) {}
-
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
-      const MKLDNNMemoryFormat& fmt, void* ptr) {
-    auto local_key = key_ + "@user_src_mem_p";
-    auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    if (mem_p == nullptr) {
-      // Make memory descriptor using input format, unless it
-      // cannot be trusted (nchw) then make up memory fmt manually
-      for (size_t i = 0; i < logical_axis_.size(); ++i) {
-        logical_axis_[i] = i;
-      }
-
-      auto src_md = fmt != MKLDNNMemoryFormat::nchw
-                        ? platform::MKLDNNMemDesc(
-                              dims_, platform::MKLDNNGetDataType<T>(), fmt)
-                        : Axis2MemoryDesc(dims_, logical_axis_);
-      mem_p = std::make_shared<mkldnn::memory>(src_md, engine_, ptr);
-      dev_ctx_.SetBlob(local_key, mem_p);
-    } else {
-      mem_p->set_data_handle(ptr);
-    }
-    return mem_p;
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
-                                                   platform::Place place) {
-    auto local_key = key_ + "@user_dst_mem_p";
-    auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    if (mem_p == nullptr) {
-      auto dst_md = Axis2MemoryDesc(dims_, axis_);
-      auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
-      mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
-      dev_ctx_.SetBlob(local_key, mem_p);
-    } else {
-      auto dst_data = output->mutable_data<T>(place);
-      mem_p->set_data_handle(dst_data);
-    }
-    return mem_p;
-  }
-
-  std::shared_ptr<mkldnn::reorder> AcquireTranspose(
-      std::shared_ptr<mkldnn::memory> dst_memory_p,
-      std::shared_ptr<mkldnn::memory> src_memory_p) {
-    auto prim_key = key_ + "@transpose_p";
-    auto transpose_p =
-        std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
-    if (transpose_p == nullptr) {
-      transpose_p =
-          std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
-      dev_ctx_.SetBlob(prim_key, transpose_p);
-    }
-    return transpose_p;
-  }
-
- protected:
-  mkldnn::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz,  // NOLINT
-                                       std::vector<int>& axis         // NOLINT
-                                       ) {
-    size_t ndims = axis.size();
-    std::vector<int64_t> strides(ndims);
-    unsigned int total_stride = 1;
-    for (int i = ndims - 1; i >= 0; --i) {
-      strides[axis[i]] = total_stride;
-      total_stride *= nchw_tz[axis[i]];
-    }
-    mkldnn::memory::desc mem_d(nchw_tz, platform::MKLDNNGetDataType<T>(),
-                               strides);
-    return mem_d;
-  }
-
- private:
-  std::vector<int64_t> dims_;
-  std::vector<int> axis_;
-  std::vector<int> logical_axis_;
-};
-
 class ReorderMKLDNNHandler : public MKLDNNHandler {
  public:
   ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT