提交 b2642609 编写于 作者: J Jacek Czaja

- fix

上级 40b9025d
...@@ -66,8 +66,7 @@ class CacheTester { ...@@ -66,8 +66,7 @@ class CacheTester {
template <typename T> template <typename T>
void RunOperator(const platform::Place &place, const std::string &op_type, void RunOperator(const platform::Place &place, const std::string &op_type,
const framework::DDim &dims, const std::string &output_name, const framework::DDim &dims, const std::string &first_input) {
bool inplace = false) {
framework::Scope scope; framework::Scope scope;
std::map<const std::string, int> num_inputs = {{"softmax", 1}, std::map<const std::string, int> num_inputs = {{"softmax", 1},
...@@ -76,11 +75,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type, ...@@ -76,11 +75,10 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
{"elementwise_add", 2}, {"elementwise_add", 2},
{"elementwise_mul", 2}}; {"elementwise_mul", 2}};
std::string first_input = inplace == true ? output_name : "x";
std::string first_input_var_name = (op_type == "conv2d") ? "Input" : "X"; std::string first_input_var_name = (op_type == "conv2d") ? "Input" : "X";
std::string second_input_var_name = (op_type == "conv2d") ? "Filter" : "Y"; std::string second_input_var_name = (op_type == "conv2d") ? "Filter" : "Y";
std::string output_var_name = (op_type == "conv2d") ? "Output" : "Out"; std::string output_var_name = (op_type == "conv2d") ? "Output" : "Out";
std::string output_name = "output";
std::vector<InputVars> input_names = { std::vector<InputVars> input_names = {
{first_input, scope.Var(first_input)->GetMutable<framework::LoDTensor>()}, {first_input, scope.Var(first_input)->GetMutable<framework::LoDTensor>()},
...@@ -134,24 +132,24 @@ void RunOperator(const platform::Place &place, const std::string &op_type, ...@@ -134,24 +132,24 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
pool.Get(place)->Wait(); pool.Get(place)->Wait();
} }
TEST(test_softmax_reuse_cache, cpu_place) { TEST(test_conv2d_reuse_cache, cpu_place) {
framework::DDim dims({1, 16, 32, 64}); framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p; platform::CPUPlace p;
CacheTester ct; CacheTester ct;
RunOperator<float>(p, "conv2d", dims, "conv_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "conv2d", dims, "conv_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
PADDLE_ENFORCE_EQ(ct.Analyze(4), true, PADDLE_ENFORCE_EQ(ct.Analyze(9), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects")); "Wrong number of cached oneDNN objects"));
} }
TEST(test_softmax_noreuse_cache, cpu_place) { TEST(test_conv2d_noreuse_cache, cpu_place) {
framework::DDim dims({1, 16, 32, 64}); framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p; platform::CPUPlace p;
CacheTester ct; CacheTester ct;
RunOperator<float>(p, "conv2d", dims, "conv_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "conv2d", dims, "conv_out2"); RunOperator<float>(p, "conv2d", dims, "input_signal2");
PADDLE_ENFORCE_EQ(ct.Analyze(8), true, PADDLE_ENFORCE_EQ(ct.Analyze(9), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects")); "Wrong number of cached oneDNN objects"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册