未验证 提交 6fc8574a 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix reshape2 bug && enhance auto test (#1827)

上级 411b24e3
......@@ -22,6 +22,9 @@ namespace operators {
template <typename Dtype, typename T>
void Reshape2Op<Dtype, T>::InferShape() const {
if (this->param_.InputShape() != nullptr) {
return;
}
auto &shape = this->param_.Shape();
auto input_x_dims = this->param_.InputX()->dims();
#ifdef PADDLE_MOBILE_CL
......
......@@ -214,10 +214,6 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test_yolo_combined paddle-mobile)
# gen test
ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-net paddle-mobile)
# gen test
ADD_EXECUTABLE(test-op-in-net net/test_op_in_net.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-op-in-net paddle-mobile)
......@@ -527,4 +523,8 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h)
target_link_libraries(test-net-benchmark paddle-mobile)
# gen test
ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-net paddle-mobile)
endif ()
......@@ -93,6 +93,8 @@ void test(int argc, char *argv[]) {
var_names.push_back(var_name);
}
arg_index += var_count;
bool check_shape = std::stoi(argv[arg_index]) == 1;
arg_index++;
auto time1 = time();
if (paddle_mobile.Load("./checked_model/model", "./checked_model/params",
......@@ -194,6 +196,11 @@ void test(int argc, char *argv[]) {
auto data = tensor_data;
std::string sample = "";
if (check_shape) {
for (int i = 0; i < cl_image->dims().size(); i++) {
sample += " " + std::to_string(cl_image->dims()[i]);
}
}
if (!is_sample_step) {
sample_step = len / sample_num;
}
......@@ -219,6 +226,11 @@ void test(int argc, char *argv[]) {
if (out->type() == type_id<int>()) {
auto data = out->data<int>();
std::string sample = "";
if (check_shape) {
for (int i = 0; i < out->dims().size(); i++) {
sample += " " + std::to_string(out->dims()[i]);
}
}
if (!is_sample_step) {
sample_step = len / sample_num;
}
......@@ -233,6 +245,11 @@ void test(int argc, char *argv[]) {
} else if (out->type() == type_id<float>()) {
auto data = out->data<float>();
std::string sample = "";
if (check_shape) {
for (int i = 0; i < out->dims().size(); i++) {
sample += " " + std::to_string(out->dims()[i]);
}
}
if (!is_sample_step) {
sample_step = len / sample_num;
}
......
......@@ -19,6 +19,9 @@ sample_step = 1
sample_num = 20
need_encrypt = False
checked_encrypt_model_path = "checked_encrypt_model"
output_var_filter = []
output_key_filter = {}
check_shape = False
np.set_printoptions(linewidth=150)
......@@ -282,6 +285,8 @@ def save_all_op_output(feed_kv=None):
for fetch in fetches:
fetch_names.append(fetch.name)
feed_names = feeds
for fetch_name in fetch_names:
output_var_filter.append(fetch_name)
for i in range(len(ops)):
op = ops[i]
var_name = None
......@@ -297,6 +302,53 @@ def save_all_op_output(feed_kv=None):
var_name = name
if "tmp" in name:
break
if len(output_var_filter) > 0:
if var_name not in output_var_filter:
continue
# real_var_name = None
# if op.type == "fetch":
# for name in op.input_arg_names:
# real_var_name = name
# if "tmp" in name:
# break
# else:
# real_var_name = var_name
if fast_check:
if var_name not in fetch_names and var_name not in feed_names:
continue
try:
data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
sample = tensor_sample(data)
output_var_cache[var_name] = (sample)
op_cache[i] = (var_name, op)
file_name = var_name.replace("/", "_")
out_file = open(output_path + "/" + file_name, "w")
if var_name in feed_names:
for item in data:
out_file.write("{}\n".format(item))
else:
for item in sample:
out_file.write("{}\n".format(item))
out_file.close()
except:
pass
for i in range(len(ops)):
op = ops[i]
if op.type not in output_key_filter:
continue
var_name = None
var_name_index = -1
for index in range(len(op.output_names)):
if op.output_names[index] in output_key_filter[op.type]:
var_name_index = index
break
if var_name_index != -1:
var_name = op.output_arg_names[var_name_index]
else:
continue
if len(output_var_filter) > 0:
if var_name not in output_var_filter:
continue
# real_var_name = None
# if op.type == "fetch":
# for name in op.input_arg_names:
......@@ -386,12 +438,19 @@ def check_mobile_results(args, fuse, mem_opt):
continue
values1 = output_var_cache[op_output_var_name]
values2 = mobile_var_cache[op_output_var_name]
if len(values1) != len(values2):
shape = get_var_shape(op_output_var_name) if check_shape else []
if len(values1) + len(shape) != len(values2):
error_index = index
for i in range(len(shape)):
v1 = shape[i]
v2 = values2[i]
if v1 != v2:
error_index = index
break
if error_index == None:
for i in range(len(values1)):
v1 = values1[i]
v2 = values2[i]
v2 = values2[len(shape) + i]
if abs(v1 - v2) > diff_threshold:
error_index = index
break
......@@ -496,6 +555,7 @@ def main():
args += " " + str(sample_num)
for var_name in output_var_cache.keys():
args += " " + var_name
args += " " + str(1 if check_shape else 0)
if not fast_check:
check_mobile_results(args, False, False)
check_mobile_results(args, False, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册