From c52e66476fb71da2a86f3456863f4256494f20a7 Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Tue, 27 Aug 2019 15:20:52 +0800 Subject: [PATCH] fix yolov3 test=develop (#1875) --- mobile/src/operators/kernel/cl/gen_code.py | 168 ++++++++++++--------- mobile/src/operators/slice_op.cpp | 16 +- mobile/tools/build.sh | 10 +- 3 files changed, 107 insertions(+), 87 deletions(-) diff --git a/mobile/src/operators/kernel/cl/gen_code.py b/mobile/src/operators/kernel/cl/gen_code.py index 14608c95fc..6cbaf1a152 100644 --- a/mobile/src/operators/kernel/cl/gen_code.py +++ b/mobile/src/operators/kernel/cl/gen_code.py @@ -16,60 +16,37 @@ import re import os import sys -source = """ -#pragma -#ifdef PADDLE_MOBILE_CL -#include -#include -#include -namespace paddle_mobile { - extern const std::map> opencl_kernels = { -%s - }; - extern const std::vector need_conv_header_kernels = { - %s - }; -} -#endif -""" +def gen_opencl_kernels(): + source = """ + #pragma + #ifdef PADDLE_MOBILE_CL + #include + #include + #include + namespace paddle_mobile { + extern const std::map> opencl_kernels = { + %s + }; + extern const std::vector need_conv_header_kernels = { + %s + }; + } + #endif + """ -def string_to_hex(str): - hex_list = [] - for i in range(len(code_str)): - hex_ = hex(ord(code_str[i])) - hex_list.append(hex_) - return hex_list + def string_to_hex(str): + hex_list = [] + for i in range(len(code_str)): + hex_ = hex(ord(code_str[i])) + hex_list.append(hex_) + return hex_list -infile = open("cl_kernel/cl_common.h", "r") -common_content = infile.read() -infile.close() -common_content = re.sub(r"/\*[^*]*\*/", "", common_content, flags=re.DOTALL) -lines = common_content.split("\n") -new_lines = [] -for i in range(len(lines)): - line = lines[i] - line = line.strip() - if line == "": - continue - if line.startswith("//"): - continue - line = re.sub(r"//.*$", "", line) - new_lines.append(line) -common_content = "\n".join(new_lines) - -need_conv_header_kernels = [] - -cores = "" -filenames = os.listdir("cl_kernel") -file_count = len(filenames) -for i in range(file_count): - filename = filenames[i] - infile = open("cl_kernel/" + filename, "r") - new_lines = [] - content = infile.read() - content = re.sub(r"/\*[^*]*\*/", "", content, flags=re.DOTALL) + infile = open("cl_kernel/cl_common.h", "r") + common_content = infile.read() infile.close() - lines = content.split("\n") + common_content = re.sub(r"/\*[^*]*\*/", "", common_content, flags=re.DOTALL) + lines = common_content.split("\n") + new_lines = [] for i in range(len(lines)): line = lines[i] line = line.strip() @@ -78,26 +55,73 @@ for i in range(file_count): if line.startswith("//"): continue line = re.sub(r"//.*$", "", line) - if "cl_common.h" in line: - line = common_content - elif "conv_kernel.inc.cl" in line: - need_conv_header_kernels.append("\"%s\"" % filename) - continue new_lines.append(line) - content = "\n".join(new_lines) - if content == "": - content = " " - hexes = [] - for char in content: - hexes.append(hex(ord(char))) - core = " {\"%s\", {" % filename - for item in hexes: - core += str(item) + ", " - core = core[: -2] - core += "}}" - if i != file_count - 1: - core += ",\n" - cores += core + common_content = "\n".join(new_lines) + + need_conv_header_kernels = [] + + cores = "" + filenames = os.listdir("cl_kernel") + file_count = len(filenames) + for i in range(file_count): + filename = filenames[i] + infile = open("cl_kernel/" + filename, "r") + new_lines = [] + content = infile.read() + content = re.sub(r"/\*[^*]*\*/", "", content, flags=re.DOTALL) + infile.close() + lines = content.split("\n") + for i in range(len(lines)): + line = lines[i] + line = line.strip() + if line == "": + continue + if line.startswith("//"): + continue + line = re.sub(r"//.*$", "", line) + if "cl_common.h" in line: + line = common_content + elif "conv_kernel.inc.cl" in line: + need_conv_header_kernels.append("\"%s\"" % filename) + continue + new_lines.append(line) + content = "\n".join(new_lines) + if content == "": + content = " " + hexes = [] + for char in content: + hexes.append(hex(ord(char))) + core = " {\"%s\", {" % filename + for item in hexes: + core += str(item) + ", " + core = core[: -2] + core += "}}" + if i != file_count - 1: + core += ",\n" + cores += core + + source = source % (cores, ",".join(need_conv_header_kernels)) + print(source) + +def gen_empty_opencl_kernels(): + source = """ + #pragma + #ifdef PADDLE_MOBILE_CL + #include + #include + #include + namespace paddle_mobile { + extern const std::map> opencl_kernels = { + }; + extern const std::vector need_conv_header_kernels = { + }; + } + #endif + """ + print(source) -source = source % (cores, ",".join(need_conv_header_kernels)) -print(source) +if __name__ == "__main__": + if sys.argv[1] == "0": + gen_empty_opencl_kernels() + elif sys.argv[1] == "1": + gen_opencl_kernels() diff --git a/mobile/src/operators/slice_op.cpp b/mobile/src/operators/slice_op.cpp index 6351d2d028..6107b92679 100644 --- a/mobile/src/operators/slice_op.cpp +++ b/mobile/src/operators/slice_op.cpp @@ -15,7 +15,9 @@ limitations under the License. */ #ifdef SLICE_OP #include "operators/slice_op.h" +#include #include + namespace paddle_mobile { namespace operators { @@ -49,18 +51,12 @@ void SliceOp::InferShape() const { PADDLE_MOBILE_ENFORCE(axes.size() == 1, "axes size should equals 1"); PADDLE_MOBILE_ENFORCE(input->dims().size() == output->dims().size(), "input dim size should equals output dim size"); -#ifdef PADDLE_MOBILE_CL PADDLE_MOBILE_ENFORCE( - input->dims().size() - + output->dims().size() - (axes[0] - (this->param_.original_output_dims_size_ - this->param_.output_->dims().size())) == 3, "op only support slice channel now"); -#endif - if (input->dims().size() >= 4) { - PADDLE_MOBILE_ENFORCE(input->dims().size() - axes[0] == 3, - "op only support slice channel now"); - } auto starts = this->param_.starts_; auto ends = this->param_.ends_; framework::DDim out_dims(input->dims()); @@ -70,7 +66,9 @@ void SliceOp::InferShape() const { "axes.size should equal starts.size"); int dim_value, start, end; for (size_t i = 0; i < axes.size(); ++i) { - dim_value = out_dims[axes[i]]; + int axis = axes[i] - (this->param_.original_output_dims_size_ - + this->param_.output_->dims().size()); + dim_value = out_dims[axis]; if (dim_value > 0) { start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; @@ -80,7 +78,7 @@ void SliceOp::InferShape() const { end = std::min(end, dim_value); // start = std::min(start, end); PADDLE_MOBILE_ENFORCE(end > start, "end should greater than start"); - out_dims[axes[i]] = end - start; + out_dims[axis] = end - start; } } output->Resize(out_dims); diff --git a/mobile/tools/build.sh b/mobile/tools/build.sh index f0e192805b..8f3a17ef7b 100755 --- a/mobile/tools/build.sh +++ b/mobile/tools/build.sh @@ -3,13 +3,11 @@ NETS="" declare -a supportedNets=("googlenet" "mobilenet" "yolo" "squeezenet" "resnet" "mobilenetssd" "nlp" "mobilenetfssd" "genet" "super" "op") # merge cl to so -merge_cl_to_so=1 +merge_cl_to_so=0 rm ../src/operators/kernel/cl/opencl_kernels.cpp -if [ $merge_cl_to_so == 1 ]; then - cd ../src/operators/kernel/cl - python gen_code.py > opencl_kernels.cpp - cd - -fi +cd ../src/operators/kernel/cl +python gen_code.py $merge_cl_to_so > opencl_kernels.cpp +cd - build_for_mac() { if [ ! `which brew` ]; then -- GitLab