未验证 提交 c52e6647 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix yolov3 test=develop (#1875)

上级 ebcefbba
...@@ -16,60 +16,37 @@ import re ...@@ -16,60 +16,37 @@ import re
import os import os
import sys import sys
source = """ def gen_opencl_kernels():
#pragma source = """
#ifdef PADDLE_MOBILE_CL #pragma
#include <map> #ifdef PADDLE_MOBILE_CL
#include <string> #include <map>
#include <vector> #include <string>
namespace paddle_mobile { #include <vector>
extern const std::map<std::string, std::vector<unsigned char>> opencl_kernels = { namespace paddle_mobile {
%s extern const std::map<std::string, std::vector<unsigned char>> opencl_kernels = {
}; %s
extern const std::vector<std::string> need_conv_header_kernels = { };
%s extern const std::vector<std::string> need_conv_header_kernels = {
}; %s
} };
#endif }
""" #endif
"""
def string_to_hex(str): def string_to_hex(str):
hex_list = [] hex_list = []
for i in range(len(code_str)): for i in range(len(code_str)):
hex_ = hex(ord(code_str[i])) hex_ = hex(ord(code_str[i]))
hex_list.append(hex_) hex_list.append(hex_)
return hex_list return hex_list
infile = open("cl_kernel/cl_common.h", "r") infile = open("cl_kernel/cl_common.h", "r")
common_content = infile.read() common_content = infile.read()
infile.close()
common_content = re.sub(r"/\*[^*]*\*/", "", common_content, flags=re.DOTALL)
lines = common_content.split("\n")
new_lines = []
for i in range(len(lines)):
line = lines[i]
line = line.strip()
if line == "":
continue
if line.startswith("//"):
continue
line = re.sub(r"//.*$", "", line)
new_lines.append(line)
common_content = "\n".join(new_lines)
need_conv_header_kernels = []
cores = ""
filenames = os.listdir("cl_kernel")
file_count = len(filenames)
for i in range(file_count):
filename = filenames[i]
infile = open("cl_kernel/" + filename, "r")
new_lines = []
content = infile.read()
content = re.sub(r"/\*[^*]*\*/", "", content, flags=re.DOTALL)
infile.close() infile.close()
lines = content.split("\n") common_content = re.sub(r"/\*[^*]*\*/", "", common_content, flags=re.DOTALL)
lines = common_content.split("\n")
new_lines = []
for i in range(len(lines)): for i in range(len(lines)):
line = lines[i] line = lines[i]
line = line.strip() line = line.strip()
...@@ -78,26 +55,73 @@ for i in range(file_count): ...@@ -78,26 +55,73 @@ for i in range(file_count):
if line.startswith("//"): if line.startswith("//"):
continue continue
line = re.sub(r"//.*$", "", line) line = re.sub(r"//.*$", "", line)
if "cl_common.h" in line:
line = common_content
elif "conv_kernel.inc.cl" in line:
need_conv_header_kernels.append("\"%s\"" % filename)
continue
new_lines.append(line) new_lines.append(line)
content = "\n".join(new_lines) common_content = "\n".join(new_lines)
if content == "":
content = " " need_conv_header_kernels = []
hexes = []
for char in content: cores = ""
hexes.append(hex(ord(char))) filenames = os.listdir("cl_kernel")
core = " {\"%s\", {" % filename file_count = len(filenames)
for item in hexes: for i in range(file_count):
core += str(item) + ", " filename = filenames[i]
core = core[: -2] infile = open("cl_kernel/" + filename, "r")
core += "}}" new_lines = []
if i != file_count - 1: content = infile.read()
core += ",\n" content = re.sub(r"/\*[^*]*\*/", "", content, flags=re.DOTALL)
cores += core infile.close()
lines = content.split("\n")
for i in range(len(lines)):
line = lines[i]
line = line.strip()
if line == "":
continue
if line.startswith("//"):
continue
line = re.sub(r"//.*$", "", line)
if "cl_common.h" in line:
line = common_content
elif "conv_kernel.inc.cl" in line:
need_conv_header_kernels.append("\"%s\"" % filename)
continue
new_lines.append(line)
content = "\n".join(new_lines)
if content == "":
content = " "
hexes = []
for char in content:
hexes.append(hex(ord(char)))
core = " {\"%s\", {" % filename
for item in hexes:
core += str(item) + ", "
core = core[: -2]
core += "}}"
if i != file_count - 1:
core += ",\n"
cores += core
source = source % (cores, ",".join(need_conv_header_kernels))
print(source)
def gen_empty_opencl_kernels():
source = """
#pragma
#ifdef PADDLE_MOBILE_CL
#include <map>
#include <string>
#include <vector>
namespace paddle_mobile {
extern const std::map<std::string, std::vector<unsigned char>> opencl_kernels = {
};
extern const std::vector<std::string> need_conv_header_kernels = {
};
}
#endif
"""
print(source)
source = source % (cores, ",".join(need_conv_header_kernels)) if __name__ == "__main__":
print(source) if sys.argv[1] == "0":
gen_empty_opencl_kernels()
elif sys.argv[1] == "1":
gen_opencl_kernels()
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#ifdef SLICE_OP #ifdef SLICE_OP
#include "operators/slice_op.h" #include "operators/slice_op.h"
#include <algorithm>
#include <vector> #include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -49,18 +51,12 @@ void SliceOp<Dtype, T>::InferShape() const { ...@@ -49,18 +51,12 @@ void SliceOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(axes.size() == 1, "axes size should equals 1"); PADDLE_MOBILE_ENFORCE(axes.size() == 1, "axes size should equals 1");
PADDLE_MOBILE_ENFORCE(input->dims().size() == output->dims().size(), PADDLE_MOBILE_ENFORCE(input->dims().size() == output->dims().size(),
"input dim size should equals output dim size"); "input dim size should equals output dim size");
#ifdef PADDLE_MOBILE_CL
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
input->dims().size() - output->dims().size() -
(axes[0] - (this->param_.original_output_dims_size_ - (axes[0] - (this->param_.original_output_dims_size_ -
this->param_.output_->dims().size())) == this->param_.output_->dims().size())) ==
3, 3,
"op only support slice channel now"); "op only support slice channel now");
#endif
if (input->dims().size() >= 4) {
PADDLE_MOBILE_ENFORCE(input->dims().size() - axes[0] == 3,
"op only support slice channel now");
}
auto starts = this->param_.starts_; auto starts = this->param_.starts_;
auto ends = this->param_.ends_; auto ends = this->param_.ends_;
framework::DDim out_dims(input->dims()); framework::DDim out_dims(input->dims());
...@@ -70,7 +66,9 @@ void SliceOp<Dtype, T>::InferShape() const { ...@@ -70,7 +66,9 @@ void SliceOp<Dtype, T>::InferShape() const {
"axes.size should equal starts.size"); "axes.size should equal starts.size");
int dim_value, start, end; int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]]; int axis = axes[i] - (this->param_.original_output_dims_size_ -
this->param_.output_->dims().size());
dim_value = out_dims[axis];
if (dim_value > 0) { if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
...@@ -80,7 +78,7 @@ void SliceOp<Dtype, T>::InferShape() const { ...@@ -80,7 +78,7 @@ void SliceOp<Dtype, T>::InferShape() const {
end = std::min(end, dim_value); end = std::min(end, dim_value);
// start = std::min(start, end); // start = std::min(start, end);
PADDLE_MOBILE_ENFORCE(end > start, "end should greater than start"); PADDLE_MOBILE_ENFORCE(end > start, "end should greater than start");
out_dims[axes[i]] = end - start; out_dims[axis] = end - start;
} }
} }
output->Resize(out_dims); output->Resize(out_dims);
......
...@@ -3,13 +3,11 @@ NETS="" ...@@ -3,13 +3,11 @@ NETS=""
declare -a supportedNets=("googlenet" "mobilenet" "yolo" "squeezenet" "resnet" "mobilenetssd" "nlp" "mobilenetfssd" "genet" "super" "op") declare -a supportedNets=("googlenet" "mobilenet" "yolo" "squeezenet" "resnet" "mobilenetssd" "nlp" "mobilenetfssd" "genet" "super" "op")
# merge cl to so # merge cl to so
merge_cl_to_so=1 merge_cl_to_so=0
rm ../src/operators/kernel/cl/opencl_kernels.cpp rm ../src/operators/kernel/cl/opencl_kernels.cpp
if [ $merge_cl_to_so == 1 ]; then cd ../src/operators/kernel/cl
cd ../src/operators/kernel/cl python gen_code.py $merge_cl_to_so > opencl_kernels.cpp
python gen_code.py > opencl_kernels.cpp cd -
cd -
fi
build_for_mac() { build_for_mac() {
if [ ! `which brew` ]; then if [ ! `which brew` ]; then
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册