未验证 提交 f1c1d9e0 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN ] disabling more ops caching (#34830)

* - disabled caching of layer norm

- fix in compilation

- compilation fix

- transpose caching disabled

- compilation fix

- more compilation fixes

- sum caching disabled

- compilation fix

* - LRN with disabled cache

* lint fixes
上级 7b3295a4
......@@ -19,45 +19,36 @@ namespace paddle {
namespace operators {
template <typename T>
class LayerNormMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
T, dnnl::layer_normalization_forward> {
public:
LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
const dnnl::normalization_flags& flags,
const bool& is_test, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, uniq_name)) {
if (!this->isCached()) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) {
// TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
auto stats_md = dnnl::memory::desc(
{begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
platform::MKLDNNFormatForSize(dims.size() - 1,
MKLDNNMemoryFormat::nchw));
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
} else {
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, md, epsilon, flags);
}
const mkldnn::engine engine, platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
engine, cpu_place) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) {
// TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
auto stats_md = dnnl::memory::desc(
{begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
platform::MKLDNNFormatForSize(dims.size() - 1,
MKLDNNMemoryFormat::nchw));
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
md, stats_md, epsilon, flags);
} else {
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, md, epsilon, flags);
}
}
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
std::vector<float>& scaleshift_data) {
// scaleshift_data comes from temporary buffer so we need to copy it into
// created memory primitivie
auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
auto scaleshift_mem =
this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
auto data_ptr = scaleshift_mem->get_data_handle();
std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
......@@ -68,7 +59,7 @@ class LayerNormMKLDNNHandler
T* mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data, "@mean_mem_p");
mean_data);
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
......@@ -76,7 +67,7 @@ class LayerNormMKLDNNHandler
T* variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data, "@variance_mem_p");
variance_data);
}
};
......@@ -95,6 +86,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto src_tz = paddle::framework::vectorize(x->dims());
PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),
......@@ -112,8 +104,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
x->format(), dev_ctx, ctx.GetPlace(),
ctx.OutputName("Y"));
x->format(), mkldnn_engine,
ctx.GetPlace());
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(y);
......@@ -139,24 +131,22 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
args.insert({DNNL_ARG_VARIANCE, *variance_memory});
}
auto scaleshift_memory = handler.AcquireScaleShiftMemory();
std::shared_ptr<mkldnn::memory> scaleshift_memory;
if (with_scaleshift) {
if (scaleshift_memory == nullptr || !is_test) {
auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias
// data
std::vector<float> scaleshift_data;
scaleshift_data.reserve(2 * C);
scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
scale->data<float>() + C);
scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
bias->data<float>() + C);
scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
}
auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias
// data
std::vector<float> scaleshift_data;
scaleshift_data.reserve(2 * C);
scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
scale->data<float>() + C);
scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
bias->data<float>() + C);
scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
}
......
......@@ -21,86 +21,78 @@ using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
template <typename T>
class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
mkldnn::lrn_backward> {
class LRNMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
mkldnn::lrn_backward> {
public:
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = framework::vectorize(input->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
}
platform::Place cpu_place, const Tensor* input)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
mkldnn::lrn_backward>(mkldnn_engine,
cpu_place) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto dims = framework::vectorize(input->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
}
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* in_x,
const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::PreconditionNotMet(
"is_test attribute should be set to False in training phase."));
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto dims = framework::vectorize<int64_t>(in_x->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
in_x->format());
auto diff_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
beta, k);
}
const Tensor* out_grad, Tensor* in_x_grad)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
mkldnn::lrn_backward>(mkldnn_engine,
cpu_place) {
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::PreconditionNotMet(
"is_test attribute should be set to False in training phase."));
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto dims = framework::vectorize<int64_t>(in_x->dims());
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
in_x->format());
auto diff_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
out_grad->format());
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta,
k);
}
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(Tensor* workspace) {
T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->workspace_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
ptr, "@wrk_mem_p");
ptr);
}
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
......@@ -108,7 +100,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_desc(),
platform::to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
platform::to_void_cast<T>(workspace_data));
}
};
......@@ -131,8 +123,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto out = ctx.Output<Tensor>("Out");
auto mid = ctx.Output<Tensor>("MidOut");
LRNMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x,
ctx.OutputName("Out"));
LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), x);
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out);
......@@ -178,9 +169,10 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
LRNMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x, out_grad,
in_x_grad, ctx.InputName("Out"));
LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), in_x,
out_grad, in_x_grad);
auto src_memory = handler.AcquireSrcMemory(in_x);
auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
......
......@@ -45,44 +45,35 @@ using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;
template <typename T>
class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
class SumMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
public:
SumMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
SumMKLDNNHandler(mkldnn::engine engine, platform::Place cpu_place,
const std::vector<framework::Variable*>& in_vars,
framework::LoDTensor* z, const std::string& uniq_name)
framework::LoDTensor* z)
: platform::MKLDNNHandlerT<T, dnnl::sum>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(z->dims()),
uniq_name)),
: platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
num_inputs_(0) {
for (size_t i = 0; i < in_vars.size(); i++) {
srcs_suffix_.push_back(std::string("-") + std::to_string(i));
}
auto dst_tz = framework::vectorize<int64_t>(z->dims());
auto src_tz = dst_tz;
if (!this->isCached()) {
auto dst_tz = framework::vectorize<int64_t>(z->dims());
auto src_tz = dst_tz;
std::vector<mkldnn::memory::desc> srcs_md;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
MKLDNNMemoryFormat input_format = input_it.format();
srcs_md.push_back(mkldnn::memory::desc(
src_tz, platform::MKLDNNGetDataType<T>(), input_format));
++num_inputs_;
std::vector<mkldnn::memory::desc> srcs_md;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
std::vector<float> scales(num_inputs_, 1.0);
MKLDNNMemoryFormat input_format = input_it.format();
srcs_md.push_back(mkldnn::memory::desc(
src_tz, platform::MKLDNNGetDataType<T>(), input_format));
++num_inputs_;
}
std::vector<float> scales(num_inputs_, 1.0);
auto dst_md = mkldnn::memory::desc(
dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
auto dst_md = mkldnn::memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
}
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
}
// (jczaja) sum oneDNN prim is not having .desc attribute so
......@@ -90,37 +81,27 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
void AcquireForwardPrimitiveDescriptor(
const mkldnn::memory::desc& dst_md, const std::vector<float>& scales,
const std::vector<mkldnn::memory::desc>& srcs_md) {
// Sum op does not have backward so no passing from FWD to BWD is needed
const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd));
if (this->fwd_pd_ == nullptr) {
this->fwd_pd_.reset(new dnnl::sum::primitive_desc(dst_md, scales, srcs_md,
this->engine_));
this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
}
this->fwd_pd_.reset(
new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor& input, int i) {
const T* input_data = input.data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
to_void_cast<T>(input_data),
"@src_mem_p" + srcs_suffix_[i]);
to_void_cast<T>(input_data));
}
using platform::MKLDNNHandlerT<T, dnnl::sum>::AcquireDstMemory;
using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(),
"@dst_mem_p");
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
}
inline int GetNumInputs(void) { return num_inputs_; }
private:
int num_inputs_;
std::vector<std::string> srcs_suffix_;
};
template <typename T>
......@@ -131,6 +112,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Sum must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto in_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
......@@ -140,8 +122,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
SumMKLDNNHandler<T> handler(dev_ctx, ctx.GetPlace(), in_vars, output,
ctx.OutputName("Out"));
SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);
// Create list of SRC MEMs
std::vector<std::shared_ptr<mkldnn::memory>> srcs_mem;
......
......@@ -24,6 +24,70 @@ namespace operators {
using Tensor = framework::Tensor;
using framework::DataLayout;
template <typename T>
class TransposeMKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
mkldnn::engine engine)
: dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const MKLDNNMemoryFormat& fmt, void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}
auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<mkldnn::memory>(src_md, engine_, ptr);
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<mkldnn::reorder> AcquireTranspose(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
return std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
}
protected:
mkldnn::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
mkldnn::memory::desc mem_d(nchw_tz, platform::MKLDNNGetDataType<T>(),
strides);
return mem_d;
}
private:
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
mkldnn::engine engine_;
};
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -48,11 +112,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims());
const std::string key =
platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out"));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key);
TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data));
......@@ -103,11 +163,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());
const std::string key = platform::CreateKey(
dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, dev_ctx,
mkldnn_engine, key);
TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
out_grad->format(), platform::to_void_cast<T>(out_grad_data));
......
......@@ -1072,99 +1072,6 @@ class ActivationMKLDNNHandler
}
};
template <typename T>
class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const MKLDNNMemoryFormat& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i;
}
auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>(src_md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data<T>(place);
mem_p->set_data_handle(dst_data);
}
return mem_p;
}
std::shared_ptr<mkldnn::reorder> AcquireTranspose(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
auto prim_key = key_ + "@transpose_p";
auto transpose_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
if (transpose_p == nullptr) {
transpose_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, transpose_p);
}
return transpose_p;
}
protected:
mkldnn::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
mkldnn::memory::desc mem_d(nchw_tz, platform::MKLDNNGetDataType<T>(),
strides);
return mem_d;
}
private:
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
};
class ReorderMKLDNNHandler : public MKLDNNHandler {
public:
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册