diff --git a/mobile/src/operators/reshape2_op.cpp b/mobile/src/operators/reshape2_op.cpp index fd95cad44a3475b59b7ed9c280ac02d3c061cd94..376416b14b6c1e2db328cc803f6454d446bfd183 100644 --- a/mobile/src/operators/reshape2_op.cpp +++ b/mobile/src/operators/reshape2_op.cpp @@ -22,6 +22,9 @@ namespace operators { template void Reshape2Op::InferShape() const { + if (this->param_.InputShape() != nullptr) { + return; + } auto &shape = this->param_.Shape(); auto input_x_dims = this->param_.InputX()->dims(); #ifdef PADDLE_MOBILE_CL diff --git a/mobile/test/CMakeLists.txt b/mobile/test/CMakeLists.txt index 26c183963d9174b572461781d174196bab425dda..36293ab8846741fd7e5c4de66fe6537eca277270 100644 --- a/mobile/test/CMakeLists.txt +++ b/mobile/test/CMakeLists.txt @@ -214,10 +214,6 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test_yolo_combined paddle-mobile) - # gen test - ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h) - target_link_libraries(test-net paddle-mobile) - # gen test ADD_EXECUTABLE(test-op-in-net net/test_op_in_net.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-op-in-net paddle-mobile) @@ -527,4 +523,8 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h) target_link_libraries(test-net-benchmark paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-net paddle-mobile) endif () diff --git a/mobile/test/net/test_net.cpp b/mobile/test/net/test_net.cpp index 36301b544fc34d0631a97f962ad76f6abe4fea93..7e904b48af86e4bbd3a7f59426a35446a2c7114e 100644 --- a/mobile/test/net/test_net.cpp +++ b/mobile/test/net/test_net.cpp @@ -93,6 +93,8 @@ void test(int argc, char *argv[]) { var_names.push_back(var_name); } arg_index += var_count; + bool check_shape = std::stoi(argv[arg_index]) == 1; + arg_index++; auto time1 = time(); if (paddle_mobile.Load("./checked_model/model", "./checked_model/params", @@ -194,6 +196,11 @@ void test(int argc, char *argv[]) { auto data = tensor_data; std::string sample = ""; + if (check_shape) { + for (int i = 0; i < cl_image->dims().size(); i++) { + sample += " " + std::to_string(cl_image->dims()[i]); + } + } if (!is_sample_step) { sample_step = len / sample_num; } @@ -219,6 +226,11 @@ void test(int argc, char *argv[]) { if (out->type() == type_id()) { auto data = out->data(); std::string sample = ""; + if (check_shape) { + for (int i = 0; i < out->dims().size(); i++) { + sample += " " + std::to_string(out->dims()[i]); + } + } if (!is_sample_step) { sample_step = len / sample_num; } @@ -233,6 +245,11 @@ void test(int argc, char *argv[]) { } else if (out->type() == type_id()) { auto data = out->data(); std::string sample = ""; + if (check_shape) { + for (int i = 0; i < out->dims().size(); i++) { + sample += " " + std::to_string(out->dims()[i]); + } + } if (!is_sample_step) { sample_step = len / sample_num; } diff --git a/mobile/tools/python/fluidtools/run.py b/mobile/tools/python/fluidtools/run.py index fcbdc8d1e7f4ba7cb71eda83e5da4558db2508b3..fc65f19a1dfc0e3fce2c55f487ba901cd9132242 100644 --- a/mobile/tools/python/fluidtools/run.py +++ b/mobile/tools/python/fluidtools/run.py @@ -19,6 +19,9 @@ sample_step = 1 sample_num = 20 need_encrypt = False checked_encrypt_model_path = "checked_encrypt_model" +output_var_filter = [] +output_key_filter = {} +check_shape = False np.set_printoptions(linewidth=150) @@ -282,6 +285,8 @@ def save_all_op_output(feed_kv=None): for fetch in fetches: fetch_names.append(fetch.name) feed_names = feeds + for fetch_name in fetch_names: + output_var_filter.append(fetch_name) for i in range(len(ops)): op = ops[i] var_name = None @@ -297,6 +302,53 @@ def save_all_op_output(feed_kv=None): var_name = name if "tmp" in name: break + if len(output_var_filter) > 0: + if var_name not in output_var_filter: + continue + # real_var_name = None + # if op.type == "fetch": + # for name in op.input_arg_names: + # real_var_name = name + # if "tmp" in name: + # break + # else: + # real_var_name = var_name + if fast_check: + if var_name not in fetch_names and var_name not in feed_names: + continue + try: + data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() + sample = tensor_sample(data) + output_var_cache[var_name] = (sample) + op_cache[i] = (var_name, op) + file_name = var_name.replace("/", "_") + out_file = open(output_path + "/" + file_name, "w") + if var_name in feed_names: + for item in data: + out_file.write("{}\n".format(item)) + else: + for item in sample: + out_file.write("{}\n".format(item)) + out_file.close() + except: + pass + for i in range(len(ops)): + op = ops[i] + if op.type not in output_key_filter: + continue + var_name = None + var_name_index = -1 + for index in range(len(op.output_names)): + if op.output_names[index] in output_key_filter[op.type]: + var_name_index = index + break + if var_name_index != -1: + var_name = op.output_arg_names[var_name_index] + else: + continue + if len(output_var_filter) > 0: + if var_name not in output_var_filter: + continue # real_var_name = None # if op.type == "fetch": # for name in op.input_arg_names: @@ -386,12 +438,19 @@ def check_mobile_results(args, fuse, mem_opt): continue values1 = output_var_cache[op_output_var_name] values2 = mobile_var_cache[op_output_var_name] - if len(values1) != len(values2): + shape = get_var_shape(op_output_var_name) if check_shape else [] + if len(values1) + len(shape) != len(values2): error_index = index + for i in range(len(shape)): + v1 = shape[i] + v2 = values2[i] + if v1 != v2: + error_index = index + break if error_index == None: for i in range(len(values1)): v1 = values1[i] - v2 = values2[i] + v2 = values2[len(shape) + i] if abs(v1 - v2) > diff_threshold: error_index = index break @@ -496,6 +555,7 @@ def main(): args += " " + str(sample_num) for var_name in output_var_cache.keys(): args += " " + var_name + args += " " + str(1 if check_shape else 0) if not fast_check: check_mobile_results(args, False, False) check_mobile_results(args, False, True)