未验证 提交 db6c00c4 编写于 作者: J Jacek Czaja 提交者: GitHub

Disable pool&conv_transpose&quantize caching (#36695)

* - WIP

- compilation fix

- fix

- fixes

- fix

- fix

- fix again

- fix

- another fix

- another compilation fix

- fix

- fix

- fix

- lint

* - pool2d partially stripped from cache

- pool2d partially stripped of caching

* - compilation fix

* - compilation fix

* - Fix to UT of caching

* - Enabling test_conv3d_mkldnn

* - conv_transpose stripped of cache

* - compilation fix

* - fix

* - fix

* - compilation fix

* - fix

* Reverted disabling caching of conv2d

* - compilation fix

* - ut reverted
上级 9a53477c
......@@ -21,7 +21,6 @@ namespace operators {
using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast;
......
......@@ -21,7 +21,6 @@ namespace operators {
using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using platform::to_void_cast;
......
......@@ -565,7 +565,7 @@ class ConvMKLDNNHandlerT
const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
this->AcquireReorder(user_mem_p, target_mem_p);
}
return target_mem_p;
}
......@@ -643,7 +643,7 @@ class ConvMKLDNNHandlerT
platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
auto residual_memory_p = this->AcquireResidualMemory(residual_param);
dst_memory_p = this->template AcquireDstMemory<T_out>(output);
this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
this->AcquireReorder(residual_memory_p, dst_memory_p);
} else {
// Changing ShareDataWith to TensorCopy results in performance drop
// on ResNet architectures
......
......@@ -64,81 +64,46 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key =
platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
// TODO(jczaja): Refactor with Acquire API
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) {
std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout);
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {scale_data});
if (with_shift) {
mkldnn::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
// memset casts scale_shift to unsigned char (uint8_t) internally
std::memset(output_data, scale_shift, output->numel());
}
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format());
src_memory = std::make_shared<mkldnn::memory>(
src_md, engine, to_void_cast<T>(input_data));
std::shared_ptr<mkldnn::memory::desc> dst_md;
if (bfloat16) {
platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
} else if (is_negative_input && !with_shift) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
} else {
platform::SetDstMemoryQuantized<uint8_t>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
}
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout);
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {scale_data});
if (with_shift) {
mkldnn::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
// memset casts scale_shift to unsigned char (uint8_t) internally
std::memset(output_data, scale_shift, output->numel());
}
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format());
src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
to_void_cast<T>(input_data));
std::shared_ptr<mkldnn::memory::desc> dst_md;
if (bfloat16) {
platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
} else if (is_negative_input && !with_shift) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
} else {
src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
auto place = ctx.GetPlace();
if (bfloat16) {
dst_memory->set_data_handle(
output->mutable_data<paddle::platform::bfloat16>(place));
} else if (with_shift || !is_negative_input) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
if (with_shift) std::memset(output_data, scale_shift, output->numel());
dst_memory->set_data_handle(output_data);
} else {
dst_memory->set_data_handle(
output->mutable_data<int8_t>(ctx.GetPlace()));
}
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
}
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{
......
......@@ -207,7 +207,7 @@ class MKLDNNHandlerNoCachingT {
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false,
bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
std::shared_ptr<mkldnn::memory> target_memory_p;
if (custom_reorder_func) {
......@@ -500,18 +500,9 @@ class MKLDNNHandlerT {
}
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix) {
const auto key_reorder_p = key_ + suffix + "reorder_p";
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
}
const std::shared_ptr<mkldnn::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......@@ -578,6 +569,8 @@ class MKLDNNHandlerT {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
// TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
......
......@@ -95,4 +95,6 @@ class TestConv3DOp_Valid_MKLDNN(TestConv3DOp_AsyPadding_MKLDNN):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册