Commit 4c9a444e authored by 刘托

Merge branch 'master' into 'master'

Add Dockerfile arg and fix some bugs

See merge request !962
......@@ -66,7 +66,7 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform CPU Buffer " << input->name()
<< " to GPU Buffer " << internal_tensor->name()
<< " with data type " << dt;
if (data_format == DataFormat::NHWC && input->shape().size() == 4) {
if (data_format == DataFormat::NCHW && input->shape().size() == 4) {
// 1. (NCHW -> NHWC)
std::vector<int> dst_dims = {0, 2, 3, 1};
std::vector<index_t> output_shape =
......
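
Note on the hunk above: the transform comment says the CPU buffer is converted NCHW -> NHWC for the GPU, so the guard should fire for 4-D inputs whose data_format is NCHW; the corrected condition does exactly that. A minimal, self-contained Python sketch of the shape permutation with dst_dims = [0, 2, 3, 1] (illustrative names, not MACE's TransposeShape API):

def transpose_shape(shape, dst_dims):
    # dst_dims[i] is the source axis that feeds output axis i.
    return [shape[axis] for axis in dst_dims]

nchw = [1, 3, 224, 224]                      # N, C, H, W
nhwc = transpose_shape(nchw, [0, 2, 3, 1])   # -> N, H, W, C
assert nhwc == [1, 224, 224, 3]
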
......@@ -72,10 +72,9 @@ class PriorBoxOp : public Operation {
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
float box_w, box_h;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < input_h; ++i) {
index_t idx = i * input_w * num_prior * 4;
for (index_t j = 0; j < input_w; ++j) {
index_t idx = i * input_w * num_prior * 4;
float center_y = (offset_ + i) * step_h;
float center_x = (offset_ + j) * step_w;
for (index_t k = 0; k < num_min_size; ++k) {
......
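
Note on the hunk above: both placements of the idx computation appear in this hunk; under #pragma omp parallel for collapse(2) only the copy inside the inner loop is legal, since collapse requires the loops to be perfectly nested with no statement between the loop headers. For reference, a small Python check of the index arithmetic, assuming each (i, j) grid cell writes num_prior boxes of 4 values:

input_h, input_w, num_prior = 2, 3, 5        # illustrative sizes

for i in range(input_h):
    for j in range(input_w):
        idx = i * input_w * num_prior * 4    # start of row i's output block
        cell = idx + j * num_prior * 4       # start of cell (i, j)
        assert cell == (i * input_w + j) * num_prior * 4
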
......@@ -77,6 +77,14 @@ class ReshapeOp : public Operation {
}
Tensor *output = this->Output(OUTPUT);
// NCHW -> NHWC
if (D == DeviceType::GPU && out_shape.size() == 4) {
std::vector<int> dst_dims = {0, 2, 3, 1};
std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>(
out_shape, dst_dims);
out_shape = out_shape_gpu;
}
output->ReuseTensorBuffer(*input);
output->Reshape(out_shape);
......
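
Note on the hunk above: on the GPU path tensors are laid out NHWC (see the first hunk), so a 4-D reshape target given in NCHW order is permuted with the same [0, 2, 3, 1] mapping before Reshape is applied. The permutation only reorders dimensions, so the element count is unchanged and reusing the input buffer stays valid. A quick sanity check of that invariant (plain Python, illustrative shape):

from functools import reduce
from operator import mul

out_shape = [1, 8, 7, 7]                               # NCHW target shape
out_shape_gpu = [out_shape[a] for a in (0, 2, 3, 1)]   # NHWC target shape
assert reduce(mul, out_shape_gpu, 1) == reduce(mul, out_shape, 1)
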
......@@ -209,6 +209,30 @@ def sha256_checksum(fname):
return hash_func.hexdigest()
def get_dockerfile_file(dockerfile_path="",
                        dockerfile_sha256_checksum=""):
    dockerfile = dockerfile_path
    if dockerfile_path.startswith("http://") or \
            dockerfile_path.startswith("https://"):
        dockerfile = \
            "third_party/caffe/" + md5sum(dockerfile_path) + "/Dockerfile"
        if not os.path.exists(dockerfile) or \
                sha256_checksum(dockerfile) != dockerfile_sha256_checksum:
            os.makedirs(dockerfile.strip("/Dockerfile"))
            MaceLogger.info("Downloading Dockerfile, please wait ...")
            six.moves.urllib.request.urlretrieve(dockerfile_path, dockerfile)
            MaceLogger.info("Dockerfile downloaded successfully.")
    if dockerfile:
        if sha256_checksum(dockerfile) != dockerfile_sha256_checksum:
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             "Dockerfile sha256checksum not match")
    else:
        dockerfile = "third_party/caffe"
    return dockerfile
def get_model_files(model_file_path,
model_sha256_checksum,
model_output_dir,
......@@ -373,6 +397,8 @@ class YAMLKeyword(object):
graph_optimize_options = 'graph_optimize_options' # internal use for now
cl_mem_type = 'cl_mem_type'
backend = 'backend'
dockerfile_path = 'dockerfile_path'
dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
################################
......
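
Note on the hunks above: the two new YAMLKeyword entries let a model config name a custom Caffe Dockerfile (dockerfile_path, optionally a URL) together with its dockerfile_sha256_checksum; get_dockerfile_file, added earlier in this file, downloads the file when the path is an http(s) URL, verifies the checksum, and otherwise falls back to the bundled third_party/caffe. A condensed, standalone Python 3 sketch of that download-and-verify pattern (hypothetical helper names, not the MACE functions):

import hashlib
import os
import urllib.request

def sha256_of(path, block_size=65536):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(block_size), b""):
            h.update(chunk)
    return h.hexdigest()

def fetch_dockerfile(url, expected_sha256, cache_dir="/tmp/dockerfile_cache"):
    # Download the Dockerfile once, then re-verify its checksum on every call.
    local_path = os.path.join(cache_dir, "Dockerfile")
    if not os.path.exists(local_path) or sha256_of(local_path) != expected_sha256:
        os.makedirs(cache_dir, exist_ok=True)
        urllib.request.urlretrieve(url, local_path)
    if sha256_of(local_path) != expected_sha256:
        raise ValueError("Dockerfile sha256 checksum mismatch")
    return os.path.dirname(local_path)  # docker build takes the directory
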
......@@ -624,11 +624,18 @@ class DeviceWrapper:
            validate_type = device_type
            if model_config[YAMLKeyword.quantize] == 1:
                validate_type = device_type + '_QUANTIZE'
            dockerfile_path = get_dockerfile_file(
                model_config.get(YAMLKeyword.dockerfile_path),
                model_config.get(YAMLKeyword.dockerfile_sha256_checksum) # noqa
            ) if YAMLKeyword.dockerfile_path in model_config else "third_party/caffe" # noqa
            sh_commands.validate_model(
                abi=target_abi,
                device=self,
                model_file_path=model_file_path,
                weight_file_path=weight_file_path,
                dockerfile_path=dockerfile_path,
                platform=model_config[YAMLKeyword.platform],
                device_type=device_type,
                input_nodes=subgraphs[0][
......
......@@ -627,6 +627,7 @@ def validate_model(abi,
device,
model_file_path,
weight_file_path,
dockerfile_path,
platform,
device_type,
input_nodes,
......@@ -690,7 +691,7 @@ def validate_model(abi,
if not docker_image_id:
six.print_("Build caffe docker")
sh.docker("build", "-t", image_name,
"third_party/caffe")
dockerfile_path)
container_id = sh.docker("ps", "-qa", "-f",
"name=%s" % container_name)
......
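
Note on the hunks above: validate_model gains a dockerfile_path parameter, and the hard-coded "third_party/caffe" build context is replaced with it, so docker builds the Caffe validation image from whichever Dockerfile the config selected. The command being issued is docker build -t <image_name> <context>; when the context is a directory it must contain the Dockerfile. A minimal standard-library sketch of that invocation (the repo itself goes through the sh module as sh.docker("build", "-t", image_name, dockerfile_path)):

import subprocess

def build_caffe_image(image_name, context_dir):
    # Equivalent to: docker build -t <image_name> <context_dir>
    # <context_dir> is the docker build context and must contain a Dockerfile.
    subprocess.run(["docker", "build", "-t", image_name, context_dir],
                   check=True)
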
......@@ -357,6 +357,11 @@ def parse_args():
type=str,
default="tensorflow",
help="onnx backend framwork")
parser.add_argument(
"--log_file",
type=str,
default="",
help="log file")
return parser.parse_known_args()
......@@ -375,4 +380,5 @@ if __name__ == '__main__':
FLAGS.output_node,
FLAGS.validation_threshold,
FLAGS.input_data_type,
FLAGS.backend)
FLAGS.backend,
FLAGS.log_file)
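
Note on the hunks above: validate.py gains a --log_file option (default empty string) and forwards it as an additional argument of the call in __main__, alongside the existing --backend option. A standalone argparse sketch of the pattern (only the two options shown; the surrounding validation logic is omitted):

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="tensorflow",
                        help="onnx backend framework")
    parser.add_argument("--log_file", type=str, default="",
                        help="log file")
    return parser.parse_known_args()

if __name__ == "__main__":
    flags, _ = parse_args()
    print(flags.backend, flags.log_file or "<no log file>")
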