From 4aba17b5dbc803dd9b72682f03269eeddf2aa132 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Sat, 9 Jan 2021 05:20:25 +0100 Subject: [PATCH] [oneDNN] Added UT for testing elementwise_mul caching (#30203) * - Added UT for testing elementwise_mul caching * lint fixes --- .../operators/mkldnn/caching_tests.cmake | 2 +- .../operators/mkldnn/test_mkldnn_caching.cc | 20 +++++++++++++++++-- paddle/fluid/platform/mkldnn_reuse.h | 4 ++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/caching_tests.cmake b/paddle/fluid/operators/mkldnn/caching_tests.cmake index ff910a1876..4130c295b2 100644 --- a/paddle/fluid/operators/mkldnn/caching_tests.cmake +++ b/paddle/fluid/operators/mkldnn/caching_tests.cmake @@ -1 +1 @@ -cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_add_op activation_op softmax_op softmax scope device_context enforce) +cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce) diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index f88b0d5621..1df7c7ac9b 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -27,6 +27,8 @@ USE_OP(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +USE_OP(elementwise_mul); +USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP(softmax); @@ -66,8 +68,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type, bool inplace = false) { framework::Scope scope; - std::map num_inputs = { - {"softmax", 1}, {"relu", 1}, {"elementwise_add", 2}}; + std::map num_inputs = {{"softmax", 1}, + {"relu", 1}, + {"elementwise_add", 2}, + {"elementwise_mul", 2}}; std::string first_input = inplace == true ? output_name : "x"; @@ -165,5 +169,17 @@ TEST(test_elementwise_add_reuse_cache, cpu_place) { "Wrong number of cached oneDNN objects")); } +TEST(test_elementwises_sequence_reuse_cache, cpu_place) { + framework::DDim dims({32, 64}); + platform::CPUPlace p; + CacheTester ct; + RunOperator(p, "elementwise_add", dims, "elementwise_add_out", true); + RunOperator(p, "elementwise_mul", dims, "elementwise_add_out", true); + RunOperator(p, "relu", dims, "elementwise_add_out", true); + PADDLE_ENFORCE_EQ(ct.Analyze(11), true, + platform::errors::InvalidArgument( + "Wrong number of cached oneDNN objects")); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 58a8f6263f..f3dade5a16 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -516,8 +516,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey( - dev_ctx, framework::vectorize(x->dims()), - uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { + dev_ctx, framework::vectorize(x->dims()), uniq_name, + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { // bradcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); if (rankdiff > 0) { -- GitLab