Unverified commit db6c00c4, authored by Jacek Czaja, committed by GitHub

Disable pool&conv_transpose&quantize caching (#36695)

* - WIP

- compilation fix

- fix

- fixes

- fix

- fix

- fix again

- fix

- another fix

- another compilation fix

- fix

- fix

- fix

- lint

* - pool2d partially stripped from cache

- pool2d partially stripped of caching

* - compilation fix

* - compilation fix

* - Fix to UT of caching

* - Enabling test_conv3d_mkldnn

* - conv_transpose stripped of cache

* - compilation fix

* - fix

* - fix

* - compilation fix

* - fix

* Reverted disabling caching of conv2d

* - compilation fix

* - ut reverted
Parent commit: 9a53477c
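For readers skimming the diff below: the change replaces Paddle's key-based, device-context blob cache (MKLDNNHandlerT plus platform::CreateKey) with handlers that rebuild their oneDNN descriptors on every call (MKLDNNHandlerNoCachingT). The std-only sketch below illustrates the difference in miniature; BlobCache, Primitive and AcquireNoCaching are hypothetical names invented for the example, not the actual Paddle classes.

// Minimal std-only sketch of the two handler styles touched by this commit.
// All names here are hypothetical stand-ins; only the caching idea is taken
// from the diff.
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Primitive {
  explicit Primitive(const std::string& desc) {
    std::cout << "building primitive for " << desc << "\n";
  }
};

// Device-context blob cache: primitives are stored under string keys and
// reused across kernel invocations (the old MKLDNNHandlerT behaviour).
struct BlobCache {
  std::map<std::string, std::shared_ptr<Primitive>> blobs;
  std::shared_ptr<Primitive> Acquire(const std::string& key,
                                     const std::string& desc) {
    auto it = blobs.find(key);
    if (it != blobs.end()) return it->second;  // cache hit, skip construction
    auto p = std::make_shared<Primitive>(desc);
    blobs[key] = p;
    return p;
  }
};

// No-caching style: the primitive is rebuilt on every call, so no key has to
// be derived from shapes, attributes and op names (the
// MKLDNNHandlerNoCachingT behaviour this commit switches to).
std::shared_ptr<Primitive> AcquireNoCaching(const std::string& desc) {
  return std::make_shared<Primitive>(desc);
}

int main() {
  BlobCache cache;
  cache.Acquire("pool2d_3x3_nchw", "pool 3x3");  // built once
  cache.Acquire("pool2d_3x3_nchw", "pool 3x3");  // reused from cache
  AcquireNoCaching("pool 3x3");                  // rebuilt every time
  AcquireNoCaching("pool 3x3");
}

The trade-off visible throughout the diff is that dropping the cache also drops the need to derive a unique key from tensor shapes, attributes and operator names.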
@@ -21,7 +21,6 @@ namespace operators {
 using paddle::framework::LoDTensor;
 using paddle::framework::Tensor;
 using paddle::platform::CPUDeviceContext;
-using paddle::platform::CreateKey;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
......
@@ -21,7 +21,6 @@ namespace operators {
 using paddle::framework::LoDTensor;
 using paddle::framework::Tensor;
 using paddle::platform::CPUDeviceContext;
-using paddle::platform::CreateKey;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
......
@@ -565,7 +565,7 @@ class ConvMKLDNNHandlerT
       const auto target_mem_p = this->AcquireMemory(target_key_suffix);
       user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
       if (user_mem_p != target_mem_p) {
-        this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
+        this->AcquireReorder(user_mem_p, target_mem_p);
       }
       return target_mem_p;
     }
@@ -643,7 +643,7 @@ class ConvMKLDNNHandlerT
           platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
       dst_memory_p = this->template AcquireDstMemory<T_out>(output);
-      this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
+      this->AcquireReorder(residual_memory_p, dst_memory_p);
     } else {
       // Changing ShareDataWith to TensorCopy results in performance drop
       // on ResNet architectures
......
@@ -40,151 +40,144 @@ inline mkldnn::memory::dims GetWeightsTz(const Tensor* filter,
template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
    : public platform::MKLDNNHandlerNoCachingT<T,
                                               mkldnn::deconvolution_forward> {
 public:
  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
                              const mkldnn::engine mkldnn_engine,
                              const Tensor* input, const Tensor* filter,
                              const Tensor* bias, Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::deconvolution_forward>(
            mkldnn_engine, ctx.GetPlace()),
        is_test_(ctx.Attr<bool>("is_test")) {
    PADDLE_ENFORCE_EQ(is_test_, true,
                      platform::errors::InvalidArgument(
                          "ConvTransposeMKLDNN works only for inference. "
                          "The attribute \'is_test\' value should be set to "
                          "True, but got is_test=False."));

    PADDLE_ENFORCE_EQ(
        input->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "Got wrong layout = %d for Input tensor.", input->layout()));
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Input tensor. The input "
                          "format is undefined."));

    PADDLE_ENFORCE_EQ(
        filter->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The filter tensor's laytout should be %d, but got %d.",
            DataLayout::kMKLDNN, filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong formats for Filter tensor."));

    PADDLE_ENFORCE_EQ(
        input->dims().size(), 4,
        platform::errors::InvalidArgument("Input must be with 4 dimensions, "
                                          "i.e. NCHW. but got dimension =%d",
                                          input->dims().size()));
    PADDLE_ENFORCE_EQ(
        filter->dims().size(), 4,
        platform::errors::InvalidArgument("Filter must be with 4 dimensions, "
                                          "i.e. OIHW, but got dimension =%d",
                                          filter->dims().size()));

    if (bias) {
      PADDLE_ENFORCE_EQ(
          bias->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The bias tensor's laytout should be %d, but got %d.",
              DataLayout::kMKLDNN, bias->layout()));
      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Got wrong format for Bias tensor."));

      PADDLE_ENFORCE_EQ(
          bias->dims().size(), 1,
          platform::errors::InvalidArgument("Bias must only have 1 dimension, "
                                            "i.e. X, but got dimension = %d .",
                                            bias->dims().size()));
    }

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    mkldnn::memory::dims strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    mkldnn::memory::dims paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    mkldnn::memory::dims dilations(begin(dilations_temp), end(dilations_temp));

    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(
        strides.size(), 2,
        platform::errors::Unimplemented(
            "Now we only support 2d oneDNN convolution transpose op"));

    const auto& input_dims = input->dims();
    const auto data_dims =
        framework::slice_ddim(input_dims, 2, input_dims.size());
    const auto& filter_dims = filter->dims();
    const auto filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());

    const auto ksize = framework::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             data_dims, strides, ksize);

    std::transform(dilations.begin(), dilations.end(), dilations.begin(),
                   [](int64_t i) { return i - 1; });

    const auto src_tz = framework::vectorize(input->dims());
    const auto weights_tz = GetWeightsTz(filter, groups);
    const auto dst_tz = framework::vectorize(output->dims());
    const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    const auto chosen_memory_format = MKLDNNMemoryFormat::any;
    const std::string fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    const float fuse_beta = ctx.Attr<float>("fuse_beta");

    auto data_type = mkldnn::memory::data_type::f32;
    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
        std::is_same<T_out, platform::bfloat16>::value)
      data_type = mkldnn::memory::data_type::bf16;

    const auto src_md =
        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
    const auto weights_md =
        platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
    const auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

    const mkldnn::primitive_attr conv_trans_attr =
        CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
    auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
                                  : mkldnn::prop_kind::forward_training;
    if (bias) {
      std::vector<int64_t> bias_tz = framework::vectorize(bias->dims());
      const auto bias_md =
          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr, fwd_prop_kind, dnnl::algorithm::deconvolution_direct,
          src_md, weights_md, bias_md, dst_md, strides, dilations,
          mkldnn_paddings[0], mkldnn_paddings[1]);
    } else {
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr, fwd_prop_kind, dnnl::algorithm::deconvolution_direct,
          src_md, weights_md, dst_md, strides, dilations, mkldnn_paddings[0],
          mkldnn_paddings[1]);
    }
  }
@@ -217,86 +210,140 @@ class ConvTransposeMKLDNNHandlerT
  std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
      const framework::Tensor* input) {
    const T* input_data = input->data<T>();
    auto user_src_md = platform::MKLDNNMemDesc(
        framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
        input->format());
    return platform::MKLDNNHandlerNoCachingT<T, mkldnn::deconvolution_forward>::
        AcquireMemoryWithReorder(user_src_md, this->fwd_pd_->src_desc(),
                                 platform::to_void_cast<T>(input_data));
  }

  std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key,
      const framework::Tensor* filter, const int& groups) {
    const K* filter_data = filter->data<K>();
    auto weights_tz = GetWeightsTz(filter, groups);
    int g = std::max(groups, 1);

    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<K>(),
        (g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw);

    auto iohw_weights_tz = framework::vectorize(filter->dims());
    // Custom Reorder from IOHW to OIHW
    auto iohw2oihw_reorder =
        [&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
      int o = iohw_weights_tz[1];
      int c = iohw_weights_tz[0];
      int h = iohw_weights_tz[2];
      int w = iohw_weights_tz[3];
      std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
                                               std::default_delete<K[]>());
      for (int i = 0; i < c; ++i) {
        for (int j = 0; j < o; ++j) {
          int in_offset = j * h * w + i * o * h * w;
          int out_offset = j * c * h * w + i * h * w;
          std::memcpy(&(reordered_filter_data.get())[out_offset],
                      &filter_data[in_offset], h * w * sizeof(K));
        }
      }
      return reordered_filter_data;
    };

    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
        platform::to_void_cast<K>(filter_data), key, "@weights_mem_p", is_test_,
        iohw2oihw_reorder);
  }

  template <typename F = T>
  std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const mkldnn::memory::desc& user_md,
      const mkldnn::memory::desc& target_md, void* ptr, const std::string& key,
      const std::string& suffix, bool is_persistent = false,
      std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
      const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
    const auto target_key = key + suffix + "_target";
    const auto key_reorder_p = key + suffix + "reorder_p";
    const auto user_key = key + suffix + "_user";

    auto target_memory_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

    if (target_memory_p == nullptr) {
      if (custom_reorder_func) {
        auto reordered_data =
            custom_reorder_func(reinterpret_cast<const F*>(ptr));
        dev_ctx.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
        ptr = reinterpret_cast<void*>(reordered_data.get());
      }
      auto user_memory_p =
          std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
      if (user_md != target_md) {
        target_memory_p =
            std::make_shared<mkldnn::memory>(target_md, this->engine_);
        dnnl::reorder::primitive_desc reorder_pdesc;
        if (platform::is_int8<T>()) {
          dnnl::primitive_attr attr;
          attr.set_output_scales(mask, scale_data);
          reorder_pdesc = dnnl::reorder::primitive_desc(*user_memory_p,
                                                        *target_memory_p, attr);
        } else {
          reorder_pdesc =
              dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
        }
        auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
        dev_ctx.SetBlob(key_reorder_p, reorder_p);

        auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
        platform::RecordEvent record_reorder("int_reorder",
                                             platform::EventRole::kUniqueOp);
        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
                                     {MKLDNN_ARG_TO, *target_memory_p}});
        astream.wait();
      } else {
        target_memory_p = user_memory_p;
      }
      dev_ctx.SetBlob(user_key, user_memory_p);
      dev_ctx.SetBlob(target_key, target_memory_p);
    } else if (!is_persistent) {
      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

      auto user_memory_p =
          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(user_key));
      user_memory_p->set_data_handle(ptr);

      // TODO(jczaja): Here we detect if reorder is cached it means it is needed
      // need to change this to get rid of keys
      auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
          dev_ctx.GetBlob(key_reorder_p));
      if (reorder_p != nullptr) {
        platform::RecordEvent record_reorder("int_reorder",
                                             platform::EventRole::kUniqueOp);
        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
                                     {MKLDNN_ARG_TO, *target_memory_p}});
        astream.wait();
      }
    }
    return target_memory_p;
  }

  std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key,
      const framework::Tensor* bias) {
    const K* bias_data = bias->data<K>();
    auto user_bias_md = platform::MKLDNNMemDesc(
        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
        MKLDNNMemoryFormat::x);
    return this->AcquireMemoryWithReorder(
        dev_ctx, user_bias_md, this->fwd_pd_->bias_desc(),
        platform::to_void_cast<K>(bias_data), key, "@bias_mem_p", is_test_);
  }

 private:
  const bool is_test_;
};

template <typename T, typename K>
@@ -325,22 +372,21 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    const bool is_test = ctx.Attr<bool>("is_test");
-
     const auto* input = ctx.Input<Tensor>("Input");
     const auto* filter = ctx.Input<Tensor>("Filter");
     const auto* bias =
         ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
     auto* output = ctx.Output<Tensor>("Output");
 
-    const std::string unique_name = ctx.InputName("Input") +
-                                    ctx.InputName("Filter") +
-                                    (bias ? ctx.InputName("Bias") : "");
-
-    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(
-        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
-        output, unique_name);
+    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(ctx, mkldnn_engine, input,
+                                                     filter, bias, output);
 
     auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+    // Caching Key for weights is needed
+    std::string key = platform::CreateKey(dev_ctx, ctx.InputName("Input"),
+                                          ctx.InputName("Filter"),
+                                          (bias ? ctx.InputName("Bias") : ""));
+    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
     auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-        filter, ctx.Attr<int>("groups"), is_test);
+        dev_ctx, key, filter, ctx.Attr<int>("groups"));
 
     std::shared_ptr<dnnl::memory> dst_memory_p =
         handler.template AcquireDstMemory<T_out>(output);
@@ -352,7 +398,8 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
                                        {MKLDNN_ARG_DST, *dst_memory_p}};
 
     if (bias) {
-      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+      auto bias_memory_p =
+          handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
       args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
     }
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......
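The weights path above keeps a per-op cache (keyed on the Input/Filter/Bias names) and still passes the IOHW filter through the custom iohw2oihw_reorder lambda before handing it to oneDNN. Below is a standalone, std-only sketch of that axis swap with made-up sizes; only the offset arithmetic is taken from the lambda in the diff.

// Standalone illustration of the IOHW -> OIHW filter reshuffle performed by
// the iohw2oihw_reorder lambda above: one h*w plane is copied per memcpy.
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int i_dim = 2, o_dim = 3, h = 2, w = 2;  // filter stored as [i][o][h][w]
  std::vector<float> iohw(i_dim * o_dim * h * w);
  for (size_t n = 0; n < iohw.size(); ++n) iohw[n] = static_cast<float>(n);

  std::vector<float> oihw(iohw.size());
  for (int i = 0; i < i_dim; ++i) {
    for (int j = 0; j < o_dim; ++j) {
      int in_offset = j * h * w + i * o_dim * h * w;   // plane (i, j) in IOHW
      int out_offset = j * i_dim * h * w + i * h * w;  // plane (j, i) in OIHW
      std::memcpy(&oihw[out_offset], &iohw[in_offset], h * w * sizeof(float));
    }
  }

  // First plane of OIHW now holds filter (o=0, i=0), i.e. elements 0..3.
  for (int n = 0; n < h * w; ++n) std::cout << oihw[n] << " ";
  std::cout << "\n";
}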
@@ -30,234 +30,220 @@ using platform::to_void_cast;
template <typename T>
class PoolingMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, mkldnn::pooling_forward,
                                               mkldnn::pooling_backward> {
 public:
  PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                       const mkldnn::engine mkldnn_engine, const Tensor* input,
                       Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::pooling_forward,
                                          mkldnn::pooling_backward>(
            mkldnn_engine, ctx.GetPlace()) {
    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input tensor."));
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input tensor."));

    const std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    const bool global_pooling = ctx.Attr<bool>("global_pooling");
    const std::string padding_algorithm =
        ctx.Attr<std::string>("padding_algorithm");

    // Only 2D pooling is supported now
    PADDLE_ENFORCE_EQ(
        ksize.size(), 2,
        platform::errors::InvalidArgument(
            "The ksize must be 2D, i.e. 2D pooling, but received %dD.",
            ksize.size()));
    PADDLE_ENFORCE_EQ(
        pooling_type == "max" || pooling_type == "avg", true,
        platform::errors::InvalidArgument(
            "The pooling_type must be 'max' or 'avg', but received %s.",
            pooling_type));
    PADDLE_ENFORCE_EQ(
        input->dims().size(), 4,
        platform::errors::InvalidArgument(
            "Input dim must be with 4, i.e. NCHW, but received %d.",
            input->dims().size()));

    const auto input_dims = input->dims();
    framework::DDim data_dims =
        framework::slice_ddim(input_dims, 2, input_dims.size());

    if (global_pooling) {
      operators::UpdateKsize(&ksize, data_dims);
    }

    operators::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm,
                             data_dims, strides, ksize);

    const auto src_tz = paddle::framework::vectorize(input->dims());
    const auto dst_tz = paddle::framework::vectorize(output->dims());

    const auto is_test = ctx.Attr<bool>("is_test");

    const auto dt = framework::ToMKLDNNDataType(input->type());
    const auto exclude_padding = ctx.Attr<bool>("exclusive");

    const auto src_md = mkldnn::memory::desc(src_tz, dt, input->format());
    /* create memory descriptor for pooling without specified format
     * ('any') which lets a primitive (pooling in this case) choose
     * the memory format preferred for best performance
     */
    const auto dst_md =
        platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);

    auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    const bool ceil_mode = ctx.Attr<bool>("ceil_mode");

    if (ceil_mode) {
      CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
                        mkldnn_paddings[1]);
    }

    ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);

    this->AcquireForwardPrimitiveDescriptor(
        is_test ? mkldnn::prop_kind::forward_inference
                : mkldnn::prop_kind::forward_training,
        pooling_type == "max"
            ? mkldnn::algorithm::pooling_max
            : (exclude_padding
                   ? mkldnn::algorithm::pooling_avg_exclude_padding
                   : mkldnn::algorithm::pooling_avg_include_padding),
        src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]);
  }

  PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                       const mkldnn::engine mkldnn_engine, const Tensor* in_x,
                       const Tensor* out_grad, Tensor* in_x_grad)
      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::pooling_forward,
                                          mkldnn::pooling_backward>(
            mkldnn_engine, ctx.GetPlace()) {
    PADDLE_ENFORCE_EQ(
        in_x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
    PADDLE_ENFORCE_NE(
        in_x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for Input tensor"));

    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input output_grad tensor"));
    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input output_grad tensor"));

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
        platform::errors::InvalidArgument(
            "is_test attribute should be set to False in training phase."));

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    auto in_x_dims = in_x->dims();
    framework::DDim data_dims =
        framework::slice_ddim(in_x_dims, 2, in_x_dims.size());

    if (global_pooling) {
      operators::UpdateKsize(&ksize, data_dims);
    }

    operators::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm,
                             data_dims, strides, ksize);

    auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
    auto diff_src_tz = paddle::framework::vectorize<int64_t>(in_x_grad->dims());
    auto diff_dst_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());

    const auto dt = framework::ToMKLDNNDataType(in_x->type());
    auto src_md = mkldnn::memory::desc(src_tz, dt, in_x->format());
    auto dst_md =
        mkldnn::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any);
    auto diff_dst_md = mkldnn::memory::desc(
        diff_dst_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
    auto diff_src_md = mkldnn::memory::desc(
        diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);

    auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
    const bool ceil_mode = ctx.Attr<bool>("ceil_mode");

    if (ceil_mode) {
      CorrectOutputSize(src_tz, diff_dst_tz, ksize, paddings, strides,
                        mkldnn_paddings[1]);
    }
    ComputeAdaptivePoolParameters(ctx, diff_src_tz, &ksize, &strides);

    const auto exclude_padding = ctx.Attr<bool>("exclusive");

    this->AcquireForwardPrimitiveDescriptor(
        mkldnn::prop_kind::forward_training,
        pooling_type == "max"
            ? mkldnn::algorithm::pooling_max
            : (exclude_padding
                   ? mkldnn::algorithm::pooling_avg_exclude_padding
                   : mkldnn::algorithm::pooling_avg_include_padding),
        src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]);

    this->AcquireBackwardPrimitiveDescriptor(
        pooling_type == "max"
            ? mkldnn::algorithm::pooling_max
            : (exclude_padding
                   ? mkldnn::algorithm::pooling_avg_exclude_padding
                   : mkldnn::algorithm::pooling_avg_include_padding),
        diff_src_md, diff_dst_md, strides, ksize, mkldnn_paddings[0],
        mkldnn_paddings[1]);
  }

  std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& unique_name) {
    mkldnn::memory::desc workspace_md = this->fwd_pd_->workspace_desc();
    // Pooling Workspace has to be passed to Grad op that
    // may be executed by diffrent thread, hence
    // for that one we use key that does not contain TID
    std::string workspace_key =
        platform::CreateKey(dev_ctx, workspace_md.dims(),
                            workspace_md.data_type(), unique_name, "@wrk");
    auto mem_p = std::static_pointer_cast<mkldnn::memory>(
        dev_ctx.GetBlob(workspace_key));
    if (mem_p == nullptr) {
      static std::mutex acquire_barrier;
      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
          acquire_barrier);
      mem_p = std::static_pointer_cast<mkldnn::memory>(
          dev_ctx.GetBlob(workspace_key));
      if (mem_p == nullptr) {
        mem_p = std::make_shared<mkldnn::memory>(workspace_md, this->engine_);
        dev_ctx.SetBlob(workspace_key, mem_p);
      }
    }
    return mem_p;
@@ -319,8 +305,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    PoolingMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), input, output,
-                                    ctx.OutputName("Out"));
+    PoolingMKLDNNHandler<T> handler(ctx, dev_ctx.GetEngine(), input, output);
 
     auto src_memory = handler.AcquireSrcMemory(input);
     auto dst_memory = handler.AcquireDstMemory(output);
@@ -331,7 +316,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     if ((ctx.Attr<bool>("is_test") == false) &&
         (ctx.Attr<std::string>("pooling_type") == "max")) {
       // Training
-      auto workspace_memory = handler.AcquireWorkspaceMemory();
+      auto workspace_memory =
+          handler.AcquireWorkspaceMemory(dev_ctx, ctx.OutputName("Out"));
       pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                 {MKLDNN_ARG_DST, *dst_memory},
                                 {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
@@ -361,8 +347,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
 
-    PoolingMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x,
-                                    out_grad, in_x_grad, ctx.InputName("Out"));
+    PoolingMKLDNNHandler<T> handler(ctx, dev_ctx.GetEngine(), in_x, out_grad,
+                                    in_x_grad);
 
     auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
     auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
@@ -372,7 +358,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (ctx.Attr<std::string>("pooling_type") == "max") {
       // Max - pooling needs Workspace
-      auto workspace_memory = handler.AcquireWorkspaceMemory();
+      auto workspace_memory =
+          handler.AcquireWorkspaceMemory(dev_ctx, ctx.InputName("Out"));
       pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
                                     {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
                                     {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
......
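Max pooling in training mode is the one spot in this file where a blob is still cached: the forward kernel stores the workspace under a thread-ID-free key so the grad kernel, which may run on a different thread, can look it up. A minimal std-only sketch of that double-checked get-or-create pattern follows; the names are hypothetical and, in Paddle, the blob map lives inside MKLDNNDeviceContext.

// Std-only sketch of the get-or-create pattern used by AcquireWorkspaceMemory
// above: GetBlob(), and only on a miss take a creation lock, re-check, then
// SetBlob().
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

using Workspace = std::vector<unsigned char>;

namespace {
std::map<std::string, std::shared_ptr<Workspace>> blobs;
std::mutex blob_map_mutex;  // protects the map itself

std::shared_ptr<Workspace> GetBlob(const std::string& key) {
  std::lock_guard<std::mutex> guard(blob_map_mutex);
  auto it = blobs.find(key);
  return it == blobs.end() ? nullptr : it->second;
}

void SetBlob(const std::string& key, std::shared_ptr<Workspace> ws) {
  std::lock_guard<std::mutex> guard(blob_map_mutex);
  blobs[key] = std::move(ws);
}
}  // namespace

std::shared_ptr<Workspace> AcquireWorkspace(const std::string& key,
                                            size_t bytes) {
  auto ws = GetBlob(key);
  if (ws == nullptr) {
    // Serialize creation so two threads asking for the same key do not both
    // allocate; re-check after taking the barrier, as the pooling handler does.
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    ws = GetBlob(key);
    if (ws == nullptr) {
      ws = std::make_shared<Workspace>(bytes);
      SetBlob(key, ws);
    }
  }
  return ws;
}

int main() {
  auto fwd = AcquireWorkspace("pool2d_Out@wrk", 1024);  // created by forward
  auto bwd = AcquireWorkspace("pool2d_Out@wrk", 1024);  // reused by backward
  std::cout << (fwd == bwd ? "same workspace reused\n" : "unexpected\n");
}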
@@ -64,81 +64,46 @@ class QuantOpKernel : public framework::OpKernel<T> {
    bool is_negative_input = ctx.Attr<bool>("is_negative_input");
    bool bfloat16 = ctx.Attr<bool>("bfloat16");

    // TODO(jczaja): Refactor with Acquire API
    std::shared_ptr<mkldnn::memory> src_memory;
    std::shared_ptr<mkldnn::memory> dst_memory;
    std::shared_ptr<reorder> reorder_p;

    std::string out_layout = ctx.Attr<std::string>("output_format");
    MKLDNNMemoryFormat out_format =
        platform::data_format_to_memory_format(out_layout);
    mkldnn::primitive_attr attri;
    int mask = 0;
    attri.set_output_scales(mask, {scale_data});

    if (with_shift) {
      mkldnn::post_ops post_operations;
      post_operations.append_sum();
      attri.set_post_ops(post_operations);
      uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
      // memset casts scale_shift to unsigned char (uint8_t) internally
      std::memset(output_data, scale_shift, output->numel());
    }

    auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
                                          input->format());
    src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
                                                  to_void_cast<T>(input_data));

    std::shared_ptr<mkldnn::memory::desc> dst_md;
    if (bfloat16) {
      platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
          ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
    } else if (is_negative_input && !with_shift) {
      platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
                                              dst_md, dst_memory, out_format);
    } else {
      platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
                                               dst_md, dst_memory, out_format);
    }
    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
        new reorder::primitive_desc(*src_memory, *dst_memory, attri));
    reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    {
......
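The quantize kernel now builds its reorder primitive on every invocation instead of fetching it from the blob cache by key. Assuming the oneDNN 2.x API used elsewhere in this diff, a self-contained example of the same quantize-via-reorder idea (f32 to s8 with an output scale) could look like the sketch below; the scale value and shapes are made up for illustration.

// Hedged sketch: quantization expressed as a oneDNN reorder with output
// scales, the mechanism used by the kernel above (oneDNN 2.x API assumed).
#include <dnnl.hpp>
#include <vector>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_data = {-1.0f, 0.0f, 0.5f, 1.0f};
  std::vector<int8_t> dst_data(src_data.size());

  dnnl::memory::desc src_md({1, 1, 2, 2}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  dnnl::memory::desc dst_md({1, 1, 2, 2}, dnnl::memory::data_type::s8,
                            dnnl::memory::format_tag::nchw);
  dnnl::memory src_mem(src_md, eng, src_data.data());
  dnnl::memory dst_mem(dst_md, eng, dst_data.data());

  // Output scale plays the role of scale_data in the kernel: q = round(s * x).
  dnnl::primitive_attr attr;
  attr.set_output_scales(0, {64.0f});

  auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attr);
  dnnl::reorder(reorder_pd).execute(strm, src_mem, dst_mem);
  strm.wait();
  return 0;
}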
@@ -207,7 +207,7 @@ class MKLDNNHandlerNoCachingT {
   std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
       const mkldnn::memory::desc& user_md,
       const mkldnn::memory::desc& target_md, void* ptr,
-      const std::string& suffix, bool is_persistent = false,
+      bool is_persistent = false,
       std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
     std::shared_ptr<mkldnn::memory> target_memory_p;
     if (custom_reorder_func) {
......
@@ -500,18 +500,9 @@ class MKLDNNHandlerT {
   }
 
   void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
-                      const std::shared_ptr<mkldnn::memory>& target_memory_p,
-                      const std::string& suffix) {
-    const auto key_reorder_p = key_ + suffix + "reorder_p";
-
-    auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
-        dev_ctx_.GetBlob(key_reorder_p));
-
-    if (reorder_p == nullptr) {
-      reorder_p =
-          std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
-      dev_ctx_.SetBlob(key_reorder_p, reorder_p);
-    }
+                      const std::shared_ptr<mkldnn::memory>& target_memory_p) {
+    auto reorder_p =
+        std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
@@ -578,6 +569,8 @@ class MKLDNNHandlerT {
         std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
     user_memory_p->set_data_handle(ptr);
 
+    // TODO(jczaja): Here we detect if reorder is cached it means it is needed
+    // need to change this to get rid of keys
     auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
         dev_ctx_.GetBlob(key_reorder_p));
     if (reorder_p != nullptr) {
......
@@ -95,4 +95,6 @@ class TestConv3DOp_Valid_MKLDNN(TestConv3DOp_AsyPadding_MKLDNN):
 
 
 if __name__ == '__main__':
+    from paddle import enable_static
+    enable_static()
     unittest.main()