diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 56971b9443fe29fac4093e74d69be9ceea398f8a..22cf5c457b098975a75e9d85ae0706efd0cc68c1 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -66,8 +66,7 @@ class CacheTester { template void RunOperator(const platform::Place &place, const std::string &op_type, - const framework::DDim &dims, const std::string &output_name, - bool inplace = false) { + const framework::DDim &dims, const std::string &first_input) { framework::Scope scope; std::map num_inputs = {{"softmax", 1}, @@ -76,11 +75,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type, {"elementwise_add", 2}, {"elementwise_mul", 2}}; - std::string first_input = inplace == true ? output_name : "x"; - std::string first_input_var_name = (op_type == "conv2d") ? "Input" : "X"; std::string second_input_var_name = (op_type == "conv2d") ? "Filter" : "Y"; std::string output_var_name = (op_type == "conv2d") ? "Output" : "Out"; + std::string output_name = "output"; std::vector input_names = { {first_input, scope.Var(first_input)->GetMutable()}, @@ -134,24 +132,24 @@ void RunOperator(const platform::Place &place, const std::string &op_type, pool.Get(place)->Wait(); } -TEST(test_softmax_reuse_cache, cpu_place) { +TEST(test_conv2d_reuse_cache, cpu_place) { framework::DDim dims({1, 16, 32, 64}); platform::CPUPlace p; CacheTester ct; - RunOperator(p, "conv2d", dims, "conv_out"); - RunOperator(p, "conv2d", dims, "conv_out"); - PADDLE_ENFORCE_EQ(ct.Analyze(4), true, + RunOperator(p, "conv2d", dims, "input_signal"); + RunOperator(p, "conv2d", dims, "input_signal"); + PADDLE_ENFORCE_EQ(ct.Analyze(9), true, platform::errors::InvalidArgument( "Wrong number of cached oneDNN objects")); } -TEST(test_softmax_noreuse_cache, cpu_place) { +TEST(test_conv2d_noreuse_cache, cpu_place) { framework::DDim dims({1, 16, 32, 64}); platform::CPUPlace p; CacheTester ct; - RunOperator(p, "conv2d", dims, "conv_out"); - RunOperator(p, "conv2d", dims, "conv_out2"); - PADDLE_ENFORCE_EQ(ct.Analyze(8), true, + RunOperator(p, "conv2d", dims, "input_signal"); + RunOperator(p, "conv2d", dims, "input_signal2"); + PADDLE_ENFORCE_EQ(ct.Analyze(9), true, platform::errors::InvalidArgument( "Wrong number of cached oneDNN objects")); }