diff --git a/paddle/fluid/operators/mkldnn/caching_tests.cmake b/paddle/fluid/operators/mkldnn/caching_tests.cmake index ff910a18767dc86d179fe13d53d53f0596192b95..4130c295b203eb0fddaf7e9fd8f398baa4144c99 100644 --- a/paddle/fluid/operators/mkldnn/caching_tests.cmake +++ b/paddle/fluid/operators/mkldnn/caching_tests.cmake @@ -1 +1 @@ -cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_add_op activation_op softmax_op softmax scope device_context enforce) +cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce) diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index f88b0d56218b5f7231fbebbd9c58d5e7d5b1ca3c..1df7c7ac9b1128bcc9bd73dd4d8d08ed8c6fb235 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -27,6 +27,8 @@ USE_OP(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +USE_OP(elementwise_mul); +USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP(softmax); @@ -66,8 +68,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type, bool inplace = false) { framework::Scope scope; - std::map num_inputs = { - {"softmax", 1}, {"relu", 1}, {"elementwise_add", 2}}; + std::map num_inputs = {{"softmax", 1}, + {"relu", 1}, + {"elementwise_add", 2}, + {"elementwise_mul", 2}}; std::string first_input = inplace == true ? output_name : "x"; @@ -165,5 +169,17 @@ TEST(test_elementwise_add_reuse_cache, cpu_place) { "Wrong number of cached oneDNN objects")); } +TEST(test_elementwises_sequence_reuse_cache, cpu_place) { + framework::DDim dims({32, 64}); + platform::CPUPlace p; + CacheTester ct; + RunOperator(p, "elementwise_add", dims, "elementwise_add_out", true); + RunOperator(p, "elementwise_mul", dims, "elementwise_add_out", true); + RunOperator(p, "relu", dims, "elementwise_add_out", true); + PADDLE_ENFORCE_EQ(ct.Analyze(11), true, + platform::errors::InvalidArgument( + "Wrong number of cached oneDNN objects")); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 58a8f6263ff68445134db51249253ec41c354ed0..f3dade5a169b1257dd6732a4635438dffd8899d5 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -516,8 +516,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey( - dev_ctx, framework::vectorize(x->dims()), - uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { + dev_ctx, framework::vectorize(x->dims()), uniq_name, + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { // bradcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); if (rankdiff > 0) {