未验证 提交 20120d9c 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14608 from jczaja/prv-conv2d-transpose-mkldnn

[MKL-DNN]conv2d transpose
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "paddle/fluid/operators/batch_norm_op.h" #include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,259 +28,6 @@ using mkldnn::stream; ...@@ -28,259 +28,6 @@ using mkldnn::stream;
using platform::to_void_cast; using platform::to_void_cast;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
class ConvMKLDNNHandler : public platform::MKLDNNHandler {
public:
ConvMKLDNNHandler(
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
conv_pd_ = conv_pd;
}
ConvMKLDNNHandler(
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
conv_bwd_data_pd,
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
conv_bwd_weights_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
conv_pd_(conv_pd),
conv_bwd_weights_pd_(conv_bwd_weights_pd),
conv_bwd_data_pd_(conv_bwd_data_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}
size_t GetDstMemorySize() const {
return conv_pd_->dst_primitive_desc().get_size();
}
mkldnn::memory::format GetDstFormat() const {
return static_cast<mkldnn::memory::format>(
conv_pd_->dst_primitive_desc().desc().data.format);
}
size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_bwd_weights_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p,
"@weights-src_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@weights-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr,
"@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@data-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc();
auto user_pd = user_weights_memory_p->get_primitive_desc();
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
"@data-weights_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireResidualDataMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromResidualDataMemory(
const std::shared_ptr<mkldnn::memory>& user_residual_memory_p,
void* dst_ptr,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
return this->AcquireMemory(user_residual_memory_p,
this->AcquireDstMemoryFromPrimitive(dst_ptr),
"@residual_data_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline);
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<mkldnn::convolution_forward>(
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
*(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> bias_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<mkldnn::convolution_forward>(
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
*(bias_memory_p.get()), *(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<mkldnn::convolution_backward_weights>
AcquireConvolutionBackwardWeights(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_weights_memory_p) {
auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p =
std::static_pointer_cast<mkldnn::convolution_backward_weights>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd weights primitive in device context");
if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights
conv_bwd_weights_p =
std::make_shared<mkldnn::convolution_backward_weights>(
*conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
*diff_weights_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
} else {
is_reusing_ = true;
}
return conv_bwd_weights_p;
}
std::shared_ptr<mkldnn::convolution_backward_data>
AcquireConvolutionBackwardData(
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p =
std::static_pointer_cast<mkldnn::convolution_backward_data>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_data_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd data primitive in device context");
if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<mkldnn::convolution_backward_data>(
*conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
} else {
is_reusing_ = true;
}
return conv_bwd_data_p;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(memory::dims& input_dims, // NOLINT
memory::dims& weights_dims, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
suffix;
}
private:
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd_;
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
conv_bwd_weights_pd_;
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
conv_bwd_data_pd_;
};
template <typename T> template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -351,7 +98,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -351,7 +98,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
const std::string key = ConvMKLDNNHandler::GetHash( const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups, src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output")); ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
...@@ -400,7 +147,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -400,7 +147,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd); if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd);
ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = auto user_src_memory_p =
...@@ -616,9 +363,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -616,9 +363,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Output" variable // Get an unique name from "argument" name of "Output" variable
// as well as attributes of primitive to be created // as well as attributes of primitive to be created
// This name will be used as key when saving info into device context // This name will be used as key when saving info into device context
const std::string key = const std::string key = platform::ConvMKLDNNHandler::GetHash(
ConvMKLDNNHandler::GetHash(src_tz, weights_tz, strides, paddings, src_tz, weights_tz, strides, paddings, dilations, groups,
dilations, groups, ctx.op().Input("Output")); ctx.op().Input("Output"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
...@@ -673,8 +420,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -673,8 +420,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::make_shared<mkldnn::convolution_backward_data::primitive_desc>( std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
conv_bwd_data_desc, mkldnn_engine, *conv_pd); conv_bwd_data_desc, mkldnn_engine, *conv_pd);
ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd, conv_bwd_weights_pd, platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd,
dev_ctx, mkldnn_engine, key); conv_bwd_weights_pd, dev_ctx,
mkldnn_engine, key);
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = auto user_src_memory_p =
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::DataLayout;
template <typename T>
class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE(
is_test == true,
"ConvTransposeMKLDNN works only for inference!. Set is_test = True");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != mkldnn::memory::format::format_undef,
"Wrong layout/format set for Input tensor");
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != mkldnn::memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != mkldnn::memory::format::format_undef,
"Wrong layout/format set for Bias tensor");
PADDLE_ENFORCE(bias->dims().size() == 1,
"Bias must only have 1 dimension, i.e. X");
}
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> iohw_weights_tz =
paddle::framework::vectorize2int(filter->dims());
std::vector<int> weights_tz = iohw_weights_tz;
// IOHW -> OIHW
weights_tz[0] = iohw_weights_tz[1];
weights_tz[1] = iohw_weights_tz[0];
// Custom Reorder from IOHW to OIHW
auto iohw2oihw_reorder =
[&iohw_weights_tz](const T* filter_data) -> std::shared_ptr<T> {
int o = iohw_weights_tz[1];
int c = iohw_weights_tz[0];
int h = iohw_weights_tz[2];
int w = iohw_weights_tz[3];
std::shared_ptr<T> reordered_filter_data(new T[o * c * h * w](),
std::default_delete<T[]>());
for (int i = 0; i < c; ++i) {
for (int j = 0; j < o; ++j) {
int in_offset = j * h * w + i * o * h * w;
int out_offset = j * c * h * w + i * h * w;
std::memcpy(&(reordered_filter_data.get())[out_offset],
&filter_data[in_offset], h * w * sizeof(T));
}
}
return reordered_filter_data;
};
int g = std::max(groups, 1);
if (g > 1) {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
}
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvTransposeMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
const std::string key_conv_transpose_pd = key + "@conv_transpose_pd";
std::vector<mkldnn::primitive> pipeline;
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md =
platform::MKLDNNMemDesc({weights_tz}, platform::MKLDNNGetDataType<T>(),
(g == 1) ? mkldnn::memory::format::oihw
: mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a deconv(conv transpose) primitive descriptor and save it for
// usage in backward
std::shared_ptr<mkldnn::deconvolution_forward::primitive_desc>
conv_transpose_pd;
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
conv_transpose_pd = ConvTransposeFwdPrimitiveDesc(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, fwd_prop_kind);
} else {
conv_transpose_pd = ConvTransposeFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, fwd_prop_kind);
}
// Save conv_pd/src_memory/weights_memory for backward pass
if (!is_test) dev_ctx.SetBlob(key_conv_transpose_pd, conv_transpose_pd);
platform::ConvTransposeMKLDNNHandler handler(conv_transpose_pd, dev_ctx,
mkldnn_engine, key);
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = handler.AcquireSrcMemory(
user_src_md, platform::to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, platform::to_void_cast<T>(filter_data),
is_test ? iohw2oihw_reorder : platform::user_function());
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
std::shared_ptr<mkldnn::memory> dst_memory_p;
auto output_data = output->mutable_data<T>(
ctx.GetPlace(), paddle::memory::Allocator::kDefault,
handler.GetDstMemorySize());
dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
platform::to_void_cast<T>(output_data));
// create convolution op primitive
std::shared_ptr<mkldnn::deconvolution_forward> conv_p;
if (bias) {
const T* bias_data = bias->data<T>();
auto user_bias_md =
platform::MKLDNNMemDesc({bias_tz}, platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::x);
auto user_bias_memory_p = handler.AcquireBiasMemory(
user_bias_md, platform::to_void_cast<T>(bias_data));
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
} else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p);
}
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
}
private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>
ConvTransposeFwdPrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
const mkldnn::memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine,
const bool fuse_relu, mkldnn::prop_kind fwd_prop_kind) const {
mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
auto deconv_desc = mkldnn::deconvolution_forward::desc(
fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu);
auto p_conv_transpose_pd =
new mkldnn::deconvolution_forward::primitive_desc(deconv_desc,
deconv_attr, engine);
return std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>(
p_conv_transpose_pd);
}
std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>
ConvTransposeFwdPrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
const mkldnn::memory::desc& bias, const mkldnn::memory::desc& dst,
const std::vector<int>& strides, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
mkldnn::prop_kind fwd_prop_kind) const {
mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
auto deconv_desc = mkldnn::deconvolution_forward::desc(
fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu);
auto p_conv_transpose_pd =
new mkldnn::deconvolution_forward::primitive_desc(deconv_desc,
deconv_attr, engine);
return std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>(
p_conv_transpose_pd);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(conv2d_transpose, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvTransposeMKLDNNOpKernel<float>);
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -78,29 +82,38 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -78,29 +82,38 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr; use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
} }
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_); layout_, library_);
} }
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution transpose operator. " "(Tensor) The input tensor of convolution transpose operator. "
...@@ -145,6 +158,11 @@ void Conv2DTransposeOpMaker::Make() { ...@@ -145,6 +158,11 @@ void Conv2DTransposeOpMaker::Make() {
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
...@@ -238,6 +256,9 @@ void Conv3DTransposeOpMaker::Make() { ...@@ -238,6 +256,9 @@ void Conv3DTransposeOpMaker::Make() {
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <iostream> #include <iostream>
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -107,170 +107,6 @@ inline mkldnn::memory::format GetMKLDNNFormat( ...@@ -107,170 +107,6 @@ inline mkldnn::memory::format GetMKLDNNFormat(
memory.dst_primitive_desc().desc().data.format); memory.dst_primitive_desc().desc().data.format);
} }
class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_(base_key),
is_reusing_(false) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr,
const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(const mkldnn::memory::desc& md,
void* ptr,
const std::string& suffix) {
/*Generate key*/
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (stored_reorder_p) {
pipeline.push_back(*stored_reorder_p);
} else {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
return target_memory_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
mkldnn::memory::primitive_desc& mpd, // NOLINT
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto target_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((target_memory_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (target_memory_p == nullptr) {
target_memory_p = user_memory_p;
std::shared_ptr<mkldnn::primitive> reorder_p;
if (mpd != user_mpd) {
target_memory_p = std::make_shared<mkldnn::memory>(mpd);
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else if (!is_persistent) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
pipeline.push_back(*reorder_p);
}
is_reusing_ = true;
}
return target_memory_p;
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_;
bool is_reusing_;
};
inline mkldnn::memory::format MKLDNNFormatForSize( inline mkldnn::memory::format MKLDNNFormatForSize(
size_t dims_size, mkldnn::memory::format data_format) { size_t dims_size, mkldnn::memory::format data_format) {
if (dims_size == 1) { if (dims_size == 1) {
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
using user_function = std::function<std::shared_ptr<float>(const float*)>;
class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_(base_key),
is_reusing_(false) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr,
const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
// This incarnation of AcquireMemory can call user function eg. custom reorder
// or preprocessing routine if needed
std::shared_ptr<mkldnn::memory> AcquireMemory(
const mkldnn::memory::desc& md, void* ptr, const std::string& suffix,
user_function custom_func = {}) {
/*Generate key*/
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
// Call custom reorder/preprocessing func if available
if (custom_func) {
auto reordered_data = custom_func(reinterpret_cast<const float*>(ptr));
dev_ctx_.SetBlob(local_key + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (stored_reorder_p) {
pipeline.push_back(*stored_reorder_p);
} else {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
return target_memory_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
mkldnn::memory::primitive_desc& mpd, // NOLINT
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto target_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((target_memory_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (target_memory_p == nullptr) {
target_memory_p = user_memory_p;
std::shared_ptr<mkldnn::primitive> reorder_p;
if (mpd != user_mpd) {
target_memory_p = std::make_shared<mkldnn::memory>(mpd);
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else if (!is_persistent) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
pipeline.push_back(*reorder_p);
}
is_reusing_ = true;
}
return target_memory_p;
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_;
bool is_reusing_;
};
template <class forward_t, class backward_data_t, class backward_weights_t>
class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
public:
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
conv_pd_ = conv_pd;
}
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
std::shared_ptr<typename backward_data_t::primitive_desc>
conv_bwd_data_pd,
std::shared_ptr<typename backward_weights_t::primitive_desc>
conv_bwd_weights_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
conv_pd_(conv_pd),
conv_bwd_weights_pd_(conv_bwd_weights_pd),
conv_bwd_data_pd_(conv_bwd_data_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}
size_t GetDstMemorySize() const {
return conv_pd_->dst_primitive_desc().get_size();
}
mkldnn::memory::format GetDstFormat() const {
return static_cast<mkldnn::memory::format>(
conv_pd_->dst_primitive_desc().desc().data.format);
}
size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_bwd_weights_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p,
"@weights-src_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@weights-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr,
"@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@data-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc();
auto user_pd = user_weights_memory_p->get_primitive_desc();
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
"@data-weights_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireResidualDataMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromResidualDataMemory(
const std::shared_ptr<mkldnn::memory>& user_residual_memory_p,
void* dst_ptr,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
return this->AcquireMemory(user_residual_memory_p,
this->AcquireDstMemoryFromPrimitive(dst_ptr),
"@residual_data_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline);
}
std::shared_ptr<forward_t> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *(src_memory_p),
*(weights_memory_p.get()),
*(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<forward_t> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> bias_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
*(bias_memory_p.get()), *(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<backward_weights_t> AcquireConvolutionBackwardWeights(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_weights_memory_p) {
auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd weights primitive in device context");
if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights
conv_bwd_weights_p = std::make_shared<backward_weights_t>(
*conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
*diff_weights_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
} else {
is_reusing_ = true;
}
return conv_bwd_weights_p;
}
std::shared_ptr<backward_data_t> AcquireConvolutionBackwardData(
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p =
std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_data_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd data primitive in device context");
if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<backward_data_t>(
*conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
} else {
is_reusing_ = true;
}
return conv_bwd_data_p;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
suffix;
}
private:
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
std::shared_ptr<typename backward_weights_t::primitive_desc>
conv_bwd_weights_pd_;
std::shared_ptr<typename backward_data_t::primitive_desc> conv_bwd_data_pd_;
};
using ConvMKLDNNHandler =
ConvMKLDNNTemplateHandler<mkldnn::convolution_forward,
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>;
using ConvTransposeMKLDNNHandler =
ConvMKLDNNTemplateHandler<mkldnn::deconvolution_forward,
mkldnn::deconvolution_backward_data,
mkldnn::deconvolution_backward_weights>;
} // namespace platform
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_conv2d_transpose_op import TestConv2dTransposeOp, TestWithPad, TestWithStride
class TestMKLDNN(TestConv2dTransposeOp):
def init_op_type(self):
self.is_test = True
self.use_mkldnn = True
self.data_format = "NCHW"
self.op_type = "conv2d_transpose"
self._cpu_only = True
def test_check_grad(self):
return
def test_check_grad_no_input(self):
return
def test_check_grad_no_filter(self):
return
class TestMKLDNNWithPad(TestWithPad):
def init_op_type(self):
self.is_test = True
self.use_mkldnn = True
self.data_format = "NCHW"
self.op_type = "conv2d_transpose"
self._cpu_only = True
def test_check_grad(self):
return
def test_check_grad_no_input(self):
return
def test_check_grad_no_filter(self):
return
class TestMKLDNNWithStride(TestWithStride):
def init_op_type(self):
self.is_test = True
self.use_mkldnn = True
self.data_format = "NCHW"
self.op_type = "conv2d_transpose"
self._cpu_only = True
def test_check_grad(self):
return
def test_check_grad_no_input(self):
return
def test_check_grad_no_filter(self):
return
if __name__ == '__main__':
unittest.main()
...@@ -68,8 +68,11 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): ...@@ -68,8 +68,11 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2dTransposeOp(OpTest): class TestConv2dTransposeOp(OpTest):
def setUp(self): def setUp(self):
# init as conv transpose # init as conv transpose
self.is_test = False
self.use_cudnn = False self.use_cudnn = False
self.use_mkldnn = False
self.output_size = None self.output_size = None
self.data_format = "AnyLayout"
self.init_op_type() self.init_op_type()
self.init_test_case() self.init_test_case()
...@@ -83,7 +86,9 @@ class TestConv2dTransposeOp(OpTest): ...@@ -83,7 +86,9 @@ class TestConv2dTransposeOp(OpTest):
'groups': self.groups, 'groups': self.groups,
'dilations': self.dilations, 'dilations': self.dilations,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter 'is_test': self.is_test,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
} }
if self.output_size is not None: if self.output_size is not None:
self.attrs['output_size'] = self.output_size self.attrs['output_size'] = self.output_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册