未验证 提交 4aba17b5 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Added UT for testing elementwise_mul caching (#30203)

* - Added UT for testing elementwise_mul caching

* lint fixes
上级 be5c2e60
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_add_op activation_op softmax_op softmax scope device_context enforce) cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce)
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
USE_OP(elementwise_add); USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(elementwise_mul);
USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
USE_OP(relu); USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(softmax); USE_OP(softmax);
...@@ -66,8 +68,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type, ...@@ -66,8 +68,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
bool inplace = false) { bool inplace = false) {
framework::Scope scope; framework::Scope scope;
std::map<const std::string, int> num_inputs = { std::map<const std::string, int> num_inputs = {{"softmax", 1},
{"softmax", 1}, {"relu", 1}, {"elementwise_add", 2}}; {"relu", 1},
{"elementwise_add", 2},
{"elementwise_mul", 2}};
std::string first_input = inplace == true ? output_name : "x"; std::string first_input = inplace == true ? output_name : "x";
...@@ -165,5 +169,17 @@ TEST(test_elementwise_add_reuse_cache, cpu_place) { ...@@ -165,5 +169,17 @@ TEST(test_elementwise_add_reuse_cache, cpu_place) {
"Wrong number of cached oneDNN objects")); "Wrong number of cached oneDNN objects"));
} }
TEST(test_elementwises_sequence_reuse_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "elementwise_add", dims, "elementwise_add_out", true);
RunOperator<float>(p, "elementwise_mul", dims, "elementwise_add_out", true);
RunOperator<float>(p, "relu", dims, "elementwise_add_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(11), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -516,8 +516,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> { ...@@ -516,8 +516,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
: platform::MKLDNNHandlerT<T, dnnl::binary>( : platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey( platform::CreateKey(
dev_ctx, framework::vectorize(x->dims()), dev_ctx, framework::vectorize(x->dims()), uniq_name,
uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { (algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
// bradcasting combined with in-place may require // bradcasting combined with in-place may require
auto rankdiff = x->dims().size() - y->dims().size(); auto rankdiff = x->dims().size() - y->dims().size();
if (rankdiff > 0) { if (rankdiff > 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册