提交 285a94e3 编写于 作者: 叶剑武

Merge branch 'master' into 'master'

Fix caffe validate docker bug

See merge request !974
......@@ -94,9 +94,7 @@ MaceStatus ResizeNearestNeighborKernel<T>::Compute(
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard size_mapper(size);
Tensor::MappingGuard output_mapper(output);
const index_t out_height = size->data<int32_t>()[0];
const index_t out_width = size->data<int32_t>()[1];
const index_t channel_blocks = RoundUpDiv4(channels);
......
......@@ -78,6 +78,7 @@ class ResizeNearestNeighborOp<DeviceType::CPU, T> : public Operation {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *size = this->Input(1);
Tensor::MappingGuard size_mapper(size);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1,
......@@ -95,7 +96,6 @@ class ResizeNearestNeighborOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> out_shape{batch, channels, out_height, out_width};
MACE_RETURN_IF_ERROR(output->Resize(out_shape));
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard size_mapper(size);
Tensor::MappingGuard output_mapper(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
......
......@@ -107,9 +107,7 @@ void TestRandomResizeNearestNeighbor() {
{batch, in_height, in_width, channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
net.AddInputFromArray<D, int32_t>("Size",
{2}, size);
net.AddInputFromArray<D, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
.Input("InputNCHW")
.Input("Size")
......
......@@ -209,28 +209,32 @@ def sha256_checksum(fname):
return hash_func.hexdigest()
def get_dockerfile_file(dockerfile_path="",
dockerfile_sha256_checksum=""):
dockerfile = dockerfile_path
def get_dockerfile_info(dockerfile_path="",
dockerfile_sha256_checksum="",
docker_image_tag=""):
dockerfile_local_path = ""
if dockerfile_path.startswith("http://") or \
dockerfile_path.startswith("https://"):
dockerfile = \
"third_party/caffe/" + md5sum(dockerfile_path) + "/Dockerfile"
dockerfile_local_path = \
"third_party/caffe/" + docker_image_tag
dockerfile = dockerfile_local_path + "/Dockerfile"
if not os.path.exists(dockerfile_local_path):
os.makedirs(dockerfile_local_path)
if not os.path.exists(dockerfile) or \
sha256_checksum(dockerfile) != dockerfile_sha256_checksum:
os.makedirs(dockerfile.strip("/Dockerfile"))
MaceLogger.info("Downloading Dockerfile, please wait ...")
six.moves.urllib.request.urlretrieve(dockerfile_path, dockerfile)
MaceLogger.info("Dockerfile downloaded successfully.")
if dockerfile:
if dockerfile_local_path:
if sha256_checksum(dockerfile) != dockerfile_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Dockerfile sha256checksum not match")
else:
dockerfile = "third_party/caffe"
dockerfile_local_path = "third_party/caffe"
docker_image_tag = "lastest"
return dockerfile
return dockerfile_local_path, docker_image_tag
def get_model_files(model_file_path,
......@@ -397,6 +401,7 @@ class YAMLKeyword(object):
graph_optimize_options = 'graph_optimize_options' # internal use for now
cl_mem_type = 'cl_mem_type'
backend = 'backend'
docker_image_tag = 'docker_image_tag'
dockerfile_path = 'dockerfile_path'
dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
......
......@@ -626,16 +626,21 @@ class DeviceWrapper:
if model_config[YAMLKeyword.quantize] == 1:
validate_type = device_type + '_QUANTIZE'
dockerfile_path = get_dockerfile_file(
model_config.get(YAMLKeyword.dockerfile_path),
model_config.get(YAMLKeyword.dockerfile_sha256_checksum) # noqa
) if YAMLKeyword.dockerfile_path in model_config else "third_party/caffe" # noqa
dockerfile_path, docker_image_tag = \
get_dockerfile_info(
model_config.get(YAMLKeyword.dockerfile_path),
model_config.get(
YAMLKeyword.dockerfile_sha256_checksum),
model_config.get(YAMLKeyword.docker_image_tag)
) if YAMLKeyword.dockerfile_path in model_config \
else ("third_party/caffe", "lastest")
sh_commands.validate_model(
abi=target_abi,
device=self,
model_file_path=model_file_path,
weight_file_path=weight_file_path,
docker_image_tag=docker_image_tag,
dockerfile_path=dockerfile_path,
platform=model_config[YAMLKeyword.platform],
device_type=device_type,
......
......@@ -641,6 +641,7 @@ def validate_model(abi,
device,
model_file_path,
weight_file_path,
docker_image_tag,
dockerfile_path,
platform,
device_type,
......@@ -684,8 +685,8 @@ def validate_model(abi,
validation_threshold, ",".join(input_data_types), backend,
log_file)
elif platform == "caffe":
image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator"
image_name = "mace-caffe:" + docker_image_tag
container_name = "mace_caffe_" + docker_image_tag + "_validator"
if caffe_env == common.CaffeEnvType.LOCAL:
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册