Unverified · Commit db6c00c4 authored by Jacek Czaja, committed by GitHub

Disable pool&conv_transpose&quantize caching (#36695)

* - WIP

- compilation fix

- fix

- fixes

- fix

- fix

- fix again

- fix

- another fix

- another compilation fix

- fix

- fix

- fix

- lint

* - pool2d partially stripped from cache

- pool2d partially stripped of caching

* - compilation fix

* - compilation fix

* - Fix to UT of caching

* - Enabling test_conv3d_mkldnn

* - conv_transpose stripped of cache

* - compilation fix

* - fix

* - fix

* - compilation fix

* - fix

* Reverted disabling caching of conv2d

* - compilation fix

* - ut reverted
Parent 9a53477c
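Every change in this PR follows the same pattern: the string-key blob cache that Paddle layered on top of oneDNN primitives is removed, and the primitive is constructed directly on each call. A minimal before/after sketch of that pattern (BlobMap and the function names below are simplified stand-ins for illustration, not Paddle's actual types):

#include <memory>
#include <string>
#include <unordered_map>

#include "dnnl.hpp"

// Stand-in for MKLDNNDeviceContext's blob storage.
using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;

// Before: the reorder primitive was stashed in the device context under a
// per-op string key and looked up on every invocation.
std::shared_ptr<dnnl::reorder> AcquireReorderCached(BlobMap& blobs,
                                                    const std::string& key,
                                                    const dnnl::memory& src,
                                                    const dnnl::memory& dst) {
  auto it = blobs.find(key);
  if (it != blobs.end()) {
    return std::static_pointer_cast<dnnl::reorder>(it->second);
  }
  auto reorder_p = std::make_shared<dnnl::reorder>(src, dst);
  blobs[key] = reorder_p;
  return reorder_p;
}

// After: just create the primitive; oneDNN's internal primitive cache makes
// repeated creation for identical descriptors cheap.
std::shared_ptr<dnnl::reorder> AcquireReorder(const dnnl::memory& src,
                                              const dnnl::memory& dst) {
  return std::make_shared<dnnl::reorder>(src, dst);
}

Since oneDNN v1.2 the library memoizes primitive creation internally, so dropping the extra key-based layer trades a hash lookup in Paddle for one inside oneDNN while eliminating all of the key bookkeeping.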
@@ -21,7 +21,6 @@ namespace operators {
 using paddle::framework::LoDTensor;
 using paddle::framework::Tensor;
 using paddle::platform::CPUDeviceContext;
-using paddle::platform::CreateKey;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
...
@@ -21,7 +21,6 @@ namespace operators {
 using paddle::framework::LoDTensor;
 using paddle::framework::Tensor;
 using paddle::platform::CPUDeviceContext;
-using paddle::platform::CreateKey;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
...
@@ -565,7 +565,7 @@ class ConvMKLDNNHandlerT
     const auto target_mem_p = this->AcquireMemory(target_key_suffix);
     user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
     if (user_mem_p != target_mem_p) {
-      this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
+      this->AcquireReorder(user_mem_p, target_mem_p);
     }
     return target_mem_p;
   }
@@ -643,7 +643,7 @@ class ConvMKLDNNHandlerT
         platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
       dst_memory_p = this->template AcquireDstMemory<T_out>(output);
-      this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
+      this->AcquireReorder(residual_memory_p, dst_memory_p);
     } else {
       // Changing ShareDataWith to TensorCopy results in performance drop
       // on ResNet architectures
...
@@ -64,81 +64,46 @@ class QuantOpKernel : public framework::OpKernel<T> {
     bool is_negative_input = ctx.Attr<bool>("is_negative_input");
     bool bfloat16 = ctx.Attr<bool>("bfloat16");
 
-    std::string key =
-        platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
-                            is_negative_input, ctx.OutputName("Output"));
-    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
-    const std::string key_prim = key + "@r";
-    const std::string key_src_mem = key + "@s";
-    const std::string key_dst_mem = key + "@d";
-
+    // TODO(jczaja): Refactor with Acquire API
     std::shared_ptr<mkldnn::memory> src_memory;
     std::shared_ptr<mkldnn::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
-    reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
 
-    if (reorder_p == nullptr) {
-      std::string out_layout = ctx.Attr<std::string>("output_format");
-      MKLDNNMemoryFormat out_format =
-          platform::data_format_to_memory_format(out_layout);
-      mkldnn::primitive_attr attri;
-      int mask = 0;
-      attri.set_output_scales(mask, {scale_data});
-
-      if (with_shift) {
-        mkldnn::post_ops post_operations;
-        post_operations.append_sum();
-        attri.set_post_ops(post_operations);
-        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-        // memset casts scale_shift to unsigned char (uint8_t) internally
-        std::memset(output_data, scale_shift, output->numel());
-      }
-
-      auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
-                                            input->format());
-      src_memory = std::make_shared<mkldnn::memory>(
-          src_md, engine, to_void_cast<T>(input_data));
-
-      std::shared_ptr<mkldnn::memory::desc> dst_md;
-      if (bfloat16) {
-        platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
-            ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
-      } else if (is_negative_input && !with_shift) {
-        platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
-                                                dst_md, dst_memory, out_format);
-      } else {
-        platform::SetDstMemoryQuantized<uint8_t>(
-            ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
-      }
-      auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
-          new reorder::primitive_desc(*src_memory, *dst_memory, attri));
-      reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
-      dev_ctx.SetBlob(key_prim, reorder_p);
-      dev_ctx.SetBlob(key_src_mem, src_memory);
-      dev_ctx.SetBlob(key_dst_mem, dst_memory);
-    } else {
-      src_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_src_mem));
-      src_memory->set_data_handle(to_void_cast<T>(input_data));
-      dst_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_dst_mem));
-      auto place = ctx.GetPlace();
-      if (bfloat16) {
-        dst_memory->set_data_handle(
-            output->mutable_data<paddle::platform::bfloat16>(place));
-      } else if (with_shift || !is_negative_input) {
-        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-        if (with_shift) std::memset(output_data, scale_shift, output->numel());
-        dst_memory->set_data_handle(output_data);
-      } else {
-        dst_memory->set_data_handle(
-            output->mutable_data<int8_t>(ctx.GetPlace()));
-      }
-    }
+    std::string out_layout = ctx.Attr<std::string>("output_format");
+    MKLDNNMemoryFormat out_format =
+        platform::data_format_to_memory_format(out_layout);
+    mkldnn::primitive_attr attri;
+    int mask = 0;
+    attri.set_output_scales(mask, {scale_data});
+
+    if (with_shift) {
+      mkldnn::post_ops post_operations;
+      post_operations.append_sum();
+      attri.set_post_ops(post_operations);
+      uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+      // memset casts scale_shift to unsigned char (uint8_t) internally
+      std::memset(output_data, scale_shift, output->numel());
+    }
+
+    auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
+                                          input->format());
+    src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
+                                                  to_void_cast<T>(input_data));
+
+    std::shared_ptr<mkldnn::memory::desc> dst_md;
+    if (bfloat16) {
+      platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
+          ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
+    } else if (is_negative_input && !with_shift) {
+      platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
+                                              dst_md, dst_memory, out_format);
+    } else {
+      platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
+                                               dst_md, dst_memory, out_format);
+    }
+
+    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
+        new reorder::primitive_desc(*src_memory, *dst_memory, attri));
+    reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
...
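The rewritten kernel above boils down to a single oneDNN reorder carrying an output-scale attribute, i.e. q = saturate(round(scale * x)), now built on every call instead of being fetched by key. A self-contained sketch against the oneDNN v2 C++ API this file uses (the mkldnn:: names are aliases of dnnl::; everything below is illustrative, not Paddle code):

#include <cstdint>
#include <vector>

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_data = {-1.5f, 0.0f, 0.5f, 1.0f};
  std::vector<int8_t> dst_data(src_data.size());

  dnnl::memory::desc src_md({4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::x);
  dnnl::memory::desc dst_md({4}, dnnl::memory::data_type::s8,
                            dnnl::memory::format_tag::x);
  dnnl::memory src_mem(src_md, eng, src_data.data());
  dnnl::memory dst_mem(dst_md, eng, dst_data.data());

  // mask = 0: one common scale for the whole tensor, as in the kernel.
  dnnl::primitive_attr attr;
  attr.set_output_scales(0, {64.0f});

  // Created fresh each time; no key-based caching involved.
  dnnl::reorder quantize(src_mem, dst_mem, attr);
  quantize.execute(strm, src_mem, dst_mem);
  strm.wait();  // dst_data now holds {-96, 0, 32, 64}
  return 0;
}

The with_shift path additionally pre-fills the destination buffer with the shift value and appends a sum post-op, so the same reorder effectively computes scale * x + shift for unsigned output.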
@@ -207,7 +207,7 @@ class MKLDNNHandlerNoCachingT {
   std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
       const mkldnn::memory::desc& user_md,
       const mkldnn::memory::desc& target_md, void* ptr,
-      const std::string& suffix, bool is_persistent = false,
+      bool is_persistent = false,
       std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
     std::shared_ptr<mkldnn::memory> target_memory_p;
     if (custom_reorder_func) {
...
@@ -500,18 +500,9 @@ class MKLDNNHandlerT {
   }
 
   void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
-                      const std::shared_ptr<mkldnn::memory>& target_memory_p,
-                      const std::string& suffix) {
-    const auto key_reorder_p = key_ + suffix + "reorder_p";
-
-    auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
-        dev_ctx_.GetBlob(key_reorder_p));
-
-    if (reorder_p == nullptr) {
-      reorder_p =
-          std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
-      dev_ctx_.SetBlob(key_reorder_p, reorder_p);
-    }
+                      const std::shared_ptr<mkldnn::memory>& target_memory_p) {
+    auto reorder_p =
+        std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
...
@@ -578,6 +569,8 @@ class MKLDNNHandlerT {
         std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
     user_memory_p->set_data_handle(ptr);
 
+    // TODO(jczaja): Here we detect if reorder is cached it means it is needed
+    // need to change this to get rid of keys
     auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
         dev_ctx_.GetBlob(key_reorder_p));
     if (reorder_p != nullptr) {
...
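Relying on oneDNN's own primitive cache is what makes deleting the handler-level cache cheap. Should profiling ever show primitive-creation churn after this change, the library cache is tunable at runtime; a minimal sketch (the function exists in oneDNN 2.x; the capacity value here is arbitrary):

#include "dnnl.hpp"

int main() {
  // Grow oneDNN's global primitive cache; it can also be sized via the
  // DNNL_PRIMITIVE_CACHE_CAPACITY environment variable.
  dnnl::set_primitive_cache_capacity(2048);
  return 0;
}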
@@ -95,4 +95,6 @@ class TestConv3DOp_Valid_MKLDNN(TestConv3DOp_AsyPadding_MKLDNN):
 
 if __name__ == '__main__':
+    from paddle import enable_static
+    enable_static()
     unittest.main()
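The two added lines are needed because Paddle 2.x runs in dynamic-graph (imperative) mode by default, while OpTest-based unit tests like this one build and execute a static graph; calling enable_static() before unittest.main() puts the framework into the mode the test expects.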