diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc
index cd958705a094794ce92d194c30d8a83da906a716..a281b8f415232e152436975c5e0d670f24a98447 100644
--- a/mace/ops/ops_registry.cc
+++ b/mace/ops/ops_registry.cc
@@ -45,6 +45,7 @@ extern void RegisterMatMul(OpRegistryBase *op_registry);
 extern void RegisterPad(OpRegistryBase *op_registry);
 extern void RegisterPooling(OpRegistryBase *op_registry);
 extern void RegisterReduce(OpRegistryBase *op_registry);
+extern void RegisterPriorBox(OpRegistryBase *op_registry);
 extern void RegisterReshape(OpRegistryBase *op_registry);
 extern void RegisterResizeBicubic(OpRegistryBase *op_registry);
 extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
@@ -103,6 +104,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterPad(this);
   ops::RegisterPooling(this);
   ops::RegisterReduce(this);
+  ops::RegisterPriorBox(this);
   ops::RegisterReshape(this);
   ops::RegisterResizeBicubic(this);
   ops::RegisterResizeBilinear(this);
diff --git a/mace/ops/prior_box.cc b/mace/ops/prior_box.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8f909c58ac35d0a6b8de3f1b7a7a6e96dae90f7b
--- /dev/null
+++ b/mace/ops/prior_box.cc
@@ -0,0 +1,161 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "mace/core/operator.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class PriorBoxOp : public Operation {
+ public:
+  explicit PriorBoxOp(OpConstructContext *context)
+      : Operation(context),
+        min_size_(Operation::GetRepeatedArgs<float>("min_size")),
+        max_size_(Operation::GetRepeatedArgs<float>("max_size")),
+        aspect_ratio_(Operation::GetRepeatedArgs<float>("aspect_ratio")),
+        flip_(Operation::GetOptionalArg<bool>("flip", true)),
+        clip_(Operation::GetOptionalArg<bool>("clip", false)),
+        variance_(Operation::GetRepeatedArgs<float>("variance")),
+        offset_(Operation::GetOptionalArg<float>("offset", 0.5)) {}
+
+  MaceStatus Run(OpContext *context) override {
+    MACE_UNUSED(context);
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *data = this->Input(DATA);
+    Tensor *output = this->Output(OUTPUT);
+    const std::vector<index_t> &input_shape = input->shape();
+    const std::vector<index_t> &data_shape = data->shape();
+    const index_t input_w = input_shape[3];
+    const index_t input_h = input_shape[2];
+    const index_t image_w = data_shape[3];
+    const index_t image_h = data_shape[2];
+    float step_h = static_cast<float>(image_h) / static_cast<float>(input_h);
+    float step_w = static_cast<float>(image_w) / static_cast<float>(input_w);
+    if (Operation::GetOptionalArg<float>("step_h", 0) != 0 &&
+        Operation::GetOptionalArg<float>("step_w", 0) != 0) {
+      step_h = Operation::GetOptionalArg<float>("step_h", 0);
+      step_w = Operation::GetOptionalArg<float>("step_w", 0);
+    }
+
+    const index_t num_min_size = min_size_.size();
+    MACE_CHECK(num_min_size > 0, "min_size is required!");
+    const index_t num_max_size = max_size_.size();
+    const index_t num_aspect_ratio = aspect_ratio_.size();
+    MACE_CHECK(num_aspect_ratio > 0, "aspect_ratio is required!");
+
+    index_t num_prior = num_min_size * num_aspect_ratio +
+        num_min_size + num_max_size;
+    if (flip_)
+      num_prior += num_min_size * num_aspect_ratio;
+
+    index_t dim = 4 * input_w * input_h * num_prior;
+    std::vector<index_t> output_shape = {1, 2, dim};
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+    Tensor::MappingGuard output_guard(output);
+    T *output_data = output->mutable_data<T>();
+    float box_w, box_h;
+#pragma omp parallel for collapse(2) schedule(runtime)
+    for (index_t i = 0; i < input_h; ++i) {
+      for (index_t j = 0; j < input_w; ++j) {
+        index_t idx = i * input_w * num_prior * 4;
+        float center_y = (offset_ + i) * step_h;
+        float center_x = (offset_ + j) * step_w;
+        for (index_t k = 0; k < num_min_size; ++k) {
+          float min_s = min_size_[k];
+          box_w = box_h = min_s * 0.5;
+          output_data[idx + 0] = (center_x - box_w) / image_w;
+          output_data[idx + 1] = (center_y - box_h) / image_h;
+          output_data[idx + 2] = (center_x + box_w) / image_w;
+          output_data[idx + 3] = (center_y + box_h) / image_h;
+          idx += 4;
+          if (num_max_size > 0) {
+            float max_s_ = max_size_[k];
+            box_w = box_h = sqrt(max_s_ * min_s) * 0.5f;
+            output_data[idx + 0] = (center_x - box_w) / image_w;
+            output_data[idx + 1] = (center_y - box_h) / image_h;
+            output_data[idx + 2] = (center_x + box_w) / image_w;
+            output_data[idx + 3] = (center_y + box_h) / image_h;
+            idx += 4;
+          }
+          for (int l = 0; l < num_aspect_ratio; ++l) {
+            float ar = aspect_ratio_[l];
+            box_w = min_s * sqrt(ar) * 0.5f;
+            box_h = min_s / sqrt(ar) * 0.5f;
+            output_data[idx + 0] = (center_x - box_w) / image_w;
+            output_data[idx + 1] = (center_y - box_h) / image_h;
+            output_data[idx + 2] = (center_x + box_w) / image_w;
+            output_data[idx + 3] = (center_y + box_h) / image_h;
+            idx += 4;
+            if (flip_) {
+              output_data[idx + 0] = (center_x - box_h) / image_w;
+              output_data[idx + 1] = (center_y - box_w) / image_h;
+              output_data[idx + 2] = (center_x + box_h) / image_w;
+              output_data[idx + 3] = (center_y + box_w) / image_h;
+              idx += 4;
+            }
+          }
+        }
+      }
+    }
+
+    if (clip_) {
+#pragma omp parallel for schedule(runtime)
+      for (int i = 0; i < dim; ++i) {
+        T min = 0;
+        T max = 1;
+        output_data[i] = std::min(std::max(output_data[i], min), max);
+      }
+    }
+
+    output_data += dim;
+#pragma omp parallel for schedule(runtime)
+    for (int i = 0; i < dim / 4; ++i) {
+      int index = i * 4;
+      output_data[0 + index] = variance_[0];
+      output_data[1 + index] = variance_[1];
+      output_data[2 + index] = variance_[2];
+      output_data[3 + index] = variance_[3];
+    }
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  std::vector<float> min_size_;
+  std::vector<float> max_size_;
+  std::vector<float> aspect_ratio_;
+  bool flip_;
+  bool clip_;
+  std::vector<float> variance_;
+  const float offset_;
+
+ private:
+  MACE_OP_INPUT_TAGS(INPUT, DATA);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+void RegisterPriorBox(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "PriorBox", PriorBoxOp,
+                   DeviceType::CPU, float);
+}
+
+}  // namespace ops
+}  // namespace mace
+
diff --git a/mace/ops/prior_box_benchmark.cc b/mace/ops/prior_box_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0d85d1c9595154cf190c45853c61a44d01f91d1c
--- /dev/null
+++ b/mace/ops/prior_box_benchmark.cc
@@ -0,0 +1,77 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template <DeviceType D, typename T>
+void PriorBox(
+    int iters, float min_size, float max_size, float aspect_ratio, int flip,
+    int clip, float variance0, float variance1, float offset, int h) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("INPUT", {1, h, 1, 1});
+  net.AddRandomInput<D, float>("DATA", {1, 3, 300, 300});
+
+  OpDefBuilder("PriorBox", "PriorBoxBM")
+      .Input("INPUT")
+      .Input("DATA")
+      .Output("OUTPUT")
+      .AddFloatsArg("min_size", {min_size})
+      .AddFloatsArg("max_size", {max_size})
+      .AddFloatsArg("aspect_ratio", {aspect_ratio})
+      .AddIntArg("flip", flip)
+      .AddIntArg("clip", clip)
+      .AddFloatsArg("variance", {variance0, variance0, variance1, variance1})
+      .AddFloatArg("offset", offset)
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  const int64_t tot = static_cast<int64_t>(iters) * (300 * 300 * 3 + h);
+  mace::testing::MaccProcessed(tot);
+  testing::BytesProcessed(tot * sizeof(T));
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+}
+}  // namespace
+
+#define MACE_BM_PRIOR_BOX(MIN, MAX, AR, FLIP, CLIP, V0, V1, OFFSET, H)         \
+  static void MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_##V0##\
+      _##V1##_##OFFSET##_##H(int iters) {                                      \
+    PriorBox<DeviceType::CPU, float>(iters, MIN, MAX, AR, FLIP, CLIP, V0, V1,  \
+                                     OFFSET, H);                               \
+  }                                                                            \
+  MACE_BENCHMARK(MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_## \
+                 V0##_##V1##_##OFFSET##_##H)
+
+MACE_BM_PRIOR_BOX(285, 300, 2, 1, 0, 1, 2, 1, 128);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
+
diff --git a/mace/ops/prior_box_test.cc b/mace/ops/prior_box_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d9a65f24b7a97c7a7a439a4d92233a886c4ea613
--- /dev/null
+++ b/mace/ops/prior_box_test.cc
@@ -0,0 +1,64 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "gmock/gmock.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class PriorBoxOpTest : public OpsTestBase {}; + +TEST_F(PriorBoxOpTest, Simple) { + OpsTestNet net; + // Add input data + net.AddRandomInput("INPUT", {1, 128, 1, 1}); + net.AddRandomInput("DATA", {1, 3, 300, 300}); + OpDefBuilder("PriorBox", "PriorBoxTest") + .Input("INPUT") + .Input("DATA") + .Output("OUTPUT") + .AddFloatsArg("min_size", {285}) + .AddFloatsArg("max_size", {300}) + .AddFloatsArg("aspect_ratio", {2, 3}) + .AddIntArg("flip", 1) + .AddIntArg("clip", 0) + .AddFloatsArg("variance", {0.1, 0.1, 0.2, 0.2}) + .AddFloatArg("offset", 0.5) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(DeviceType::CPU); + + // Check + auto expected_tensor = net.CreateTensor({1, 2, 24}, + {0.025, 0.025, 0.975, 0.975, + 0.012660282759551838, 0.012660282759551838, + 0.9873397172404482, 0.9873397172404482, + -0.17175144212722018, 0.16412427893638995, + 1.1717514421272204, 0.8358757210636101, + 0.16412427893638995, -0.17175144212722018, + 0.8358757210636101, 1.1717514421272204, + -0.3227241335952166, 0.22575862213492773, + 1.3227241335952165, 0.7742413778650723, + 0.22575862213492773, -0.3227241335952166, + 0.7742413778650723, 1.3227241335952165, + 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, + 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2}); + ExpectTensorNear(*expected_tensor, *net.GetTensor("OUTPUT")); +} +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 30c7ce890290139e22d319807ce19eb30afc928f..330d3fe1366d0a7cec6b91851551d030641cbee9 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -28,6 +28,11 @@ class ReshapeOp : public Operation { MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(INPUT); + const std::vector &input_shape = input->shape(); + int axis = Operation::GetOptionalArg("reshape_axis", 0); + int num_axes = Operation::GetOptionalArg("num_axes", -1); + MACE_CHECK(axis == 0 && num_axes == -1, + "Only support axis = 0 and num_axes = -1"); const Tensor *shape = this->Input(SHAPE); const index_t num_dims = shape->dim_size() == 0 ? 
0 : shape->dim(0); Tensor::MappingGuard shape_guard(shape); @@ -43,8 +48,12 @@ class ReshapeOp : public Operation { MACE_CHECK(unknown_idx == -1, "Only one input size may be -1"); unknown_idx = i; out_shape.push_back(1); + } else if (shape_data[i] == 0) { + MACE_CHECK(shape_data[i] == 0, "Shape should be 0"); + out_shape.push_back(input_shape[i]); + product *= input_shape[i]; } else { - MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", + MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ", shape_data[i]); if (shape_data[i] == 0) { MACE_CHECK(i < input->dim_size(), diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc index c4bef3d972790d9e105eb2eb9d3b20d5a50b42ef..5abb524d6e868eae520f72c299212a5d01cd3afa 100644 --- a/mace/ops/softmax.cc +++ b/mace/ops/softmax.cc @@ -92,9 +92,17 @@ class SoftmaxOp : public Operation { } } // k } // b - } else if (input->dim_size() == 2) { // normal 2d softmax - const index_t class_size = input->dim(0); - const index_t class_count = input->dim(1); + } else if (input->dim_size() == 2 || input->dim_size() == 3) { + // normal 2d softmax and 3d softmax (dim(0) is batch) + index_t class_size = 0; + index_t class_count = 0; + if (input->dim_size() == 2) { + class_size = input->dim(0); + class_count = input->dim(1); + } else { + class_size = input->dim(0) * input->dim(1); + class_count = input->dim(2); + } #pragma omp parallel for schedule(runtime) for (index_t k = 0; k < class_size; ++k) { const float *input_ptr = input_data + k * class_count; diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index fa748ed474ac11ae2ed2476040ebad513d58e167..aeb626a681a579a120008c81082bed0e505f582e 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -123,6 +123,7 @@ MaceSupportedOps = [ 'MatMul', 'Pad', 'Pooling', + 'PriorBox', 'Proposal', 'Quantize', 'Reduce', @@ -174,8 +175,11 @@ class MaceKeyword(object): mace_space_batch_block_shape_str = 'block_shape' mace_space_depth_block_size_str = 'block_size' mace_constant_value_str = 'constant_value' + mace_dim_str = 'dim' mace_dims_str = 'dims' mace_axis_str = 'axis' + mace_end_axis_str = 'end_axis' + mace_num_axes_str = 'num_axes' mace_num_split_str = 'num_split' mace_keepdims_str = 'keepdims' mace_shape_str = 'shape' @@ -205,6 +209,14 @@ class MaceKeyword(object): mace_reduce_type_str = 'reduce_type' mace_argmin_str = 'argmin' mace_round_mode_str = 'round_mode' + mace_min_size_str = 'min_size' + mace_max_size_str = 'max_size' + mace_aspect_ratio_str = 'aspect_ratio' + mace_flip_str = 'flip' + mace_clip_str = 'clip' + mace_variance_str = 'variance' + mace_step_h_str = 'step_h' + mace_step_w_str = 'step_w' class TransformerRule(Enum): @@ -243,6 +255,7 @@ class TransformerRule(Enum): FOLD_SQRDIFF_MEAN = 33 TRANSPOSE_MATMUL_WEIGHT = 34 FOLD_EMBEDDING_LOOKUP = 35 + TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36 class ConverterInterface(object): @@ -433,6 +446,7 @@ class ConverterOption(object): TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, TransformerRule.TRANSFORM_BASIC_LSTMCELL, + TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN, TransformerRule.FOLD_RESHAPE, TransformerRule.TRANSFORM_MATMUL_TO_FC, TransformerRule.FOLD_BATCHNORM, diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index 56e8b645089758a23d7d70017e9032b3a6273432..c1fea3141603b6a3568492cffba87ca058035b61 100644 
--- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -188,6 +188,10 @@ class CaffeConverter(base_converter.ConverterInterface): 'Crop': self.convert_crop, 'Scale': self.convert_scale, 'ShuffleChannel': self.convert_channel_shuffle, + 'Permute': self.convert_permute, + 'Flatten': self.convert_flatten, + 'PriorBox': self.convert_prior_box, + 'Reshape': self.convert_reshape, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -565,8 +569,6 @@ class CaffeConverter(base_converter.ConverterInterface): axis_arg.i = param.axis elif param.HasField('concat_dim'): axis_arg.i = param.concat_dim - axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i - mace_check(axis_arg.i == 1, "only support concat at channel dimension") def convert_slice(self, caffe_op): op = self.convert_general_op(caffe_op) @@ -668,3 +670,107 @@ class CaffeConverter(base_converter.ConverterInterface): group_arg.i = 1 if param.HasField('group'): group_arg.i = param.group + + def convert_permute(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.permute_param + op.type = MaceOp.Transpose.name + + dims_arg = op.arg.add() + dims_arg.name = MaceKeyword.mace_dims_str + dims_arg.ints.extend(list(param.order)) + + def convert_flatten(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.flatten_param + op.type = MaceOp.Reshape.name + + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = 1 + if param.HasField('axis'): + axis_arg.i = param.axis + axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i + + end_axis_arg = op.arg.add() + end_axis_arg.name = MaceKeyword.mace_end_axis_str + end_axis_arg.i = -1 + if param.HasField('end_axis'): + end_axis_arg.i = param.end_axis + + def convert_prior_box(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.prior_box_param + op.type = MaceOp.PriorBox.name + + min_size_arg = op.arg.add() + min_size_arg.name = MaceKeyword.mace_min_size_str + min_size_arg.floats.extend(list(param.min_size)) + max_size_arg = op.arg.add() + max_size_arg.name = MaceKeyword.mace_max_size_str + max_size_arg.floats.extend(list(param.max_size)) + aspect_ratio_arg = op.arg.add() + aspect_ratio_arg.name = MaceKeyword.mace_aspect_ratio_str + aspect_ratio_arg.floats.extend(list(param.aspect_ratio)) + flip_arg = op.arg.add() + flip_arg.name = MaceKeyword.mace_flip_str + flip_arg.i = 1 + if param.HasField('flip'): + flip_arg.i = int(param.flip) + clip_arg = op.arg.add() + clip_arg.name = MaceKeyword.mace_clip_str + clip_arg.i = 0 + if param.HasField('clip'): + clip_arg.i = int(param.clip) + variance_arg = op.arg.add() + variance_arg.name = MaceKeyword.mace_variance_str + variance_arg.floats.extend(list(param.variance)) + offset_arg = op.arg.add() + offset_arg.name = MaceKeyword.mace_offset_str + offset_arg.f = 0.5 + if param.HasField('offset'): + offset_arg.f = param.offset + step_h_arg = op.arg.add() + step_h_arg.name = MaceKeyword.mace_step_h_str + step_h_arg.f = 0 + if param.HasField('step_h'): + mace_check(not param.HasField('step'), + "Either step or step_h/step_w should be specified; not both.") # noqa + step_h_arg.f = param.step_h + mace_check(step_h_arg.f > 0, "step_h should be larger than 0.") + step_w_arg = op.arg.add() + step_w_arg.name = MaceKeyword.mace_step_w_str + step_w_arg.f = 0 + if param.HasField('step_w'): + mace_check(not param.HasField('step'), + "Either step or step_h/step_w should be specified; 
not both.") # noqa + step_w_arg.f = param.step_w + mace_check(step_w_arg.f > 0, "step_w should be larger than 0.") + + if param.HasField('step'): + mace_check(not param.HasField('step_h') and not param.HasField('step_w'), # noqa + "Either step or step_h/step_w should be specified; not both.") # noqa + mace_check(param.step > 0, "step should be larger than 0.") + step_h_arg.f = param.step + step_w_arg.f = param.step + + def convert_reshape(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.reshape_param + op.type = MaceOp.Reshape.name + + dim_arg = op.arg.add() + dim_arg.name = MaceKeyword.mace_dim_str + dim_arg.ints.extend(list(param.shape.dim)) + + axis_arg = op.arg.add() + axis_arg.name = 'reshape_' + MaceKeyword.mace_axis_str + axis_arg.i = 0 + if param.HasField('axis'): + axis_arg.i = param.axis + + num_axes_arg = op.arg.add() + num_axes_arg.name = MaceKeyword.mace_num_axes_str + num_axes_arg.i = -1 + if param.HasField('num_axes'): + num_axes_arg.i = param.num_axes diff --git a/mace/python/tools/converter_tool/shape_inference.py b/mace/python/tools/converter_tool/shape_inference.py index aeb19022badc324855d82820a61d5b9a2a7f1cb1..fe269078fa76fd07f7834adf49e5dfd296545cf6 100644 --- a/mace/python/tools/converter_tool/shape_inference.py +++ b/mace/python/tools/converter_tool/shape_inference.py @@ -49,6 +49,9 @@ class ShapeInference(object): MaceOp.Crop.name: self.infer_shape_crop, MaceOp.BiasAdd.name: self.infer_shape_general, MaceOp.ChannelShuffle.name: self.infer_shape_channel_shuffle, + MaceOp.Transpose.name: self.infer_shape_permute, + MaceOp.PriorBox.name: self.infer_shape_prior_box, + MaceOp.Reshape.name: self.infer_shape_reshape, } self._net = net @@ -190,12 +193,14 @@ class ShapeInference(object): self.add_output_shape(op, [output_shape]) def infer_shape_concat(self, op): - output_shape = self._output_shape_cache[op.input[0]] + output_shape = list(self._output_shape_cache[op.input[0]]) axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i + if axis < 0: + axis = len(output_shape) + axis + output_shape[axis] = 0 for input_node in op.input: - input_shape = self._output_shape_cache[input_node] - output_shape[axis] += input_shape[axis] - + input_shape = list(self._output_shape_cache[input_node]) + output_shape[axis] = output_shape[axis] + input_shape[axis] self.add_output_shape(op, [output_shape]) def infer_shape_slice(self, op): @@ -225,3 +230,64 @@ class ShapeInference(object): def infer_shape_channel_shuffle(self, op): output_shape = self._output_shape_cache[op.input[0]] self.add_output_shape(op, [output_shape]) + + def infer_shape_permute(self, op): + output_shape = list(self._output_shape_cache[op.input[0]]) + dims = ConverterUtil.get_arg(op, MaceKeyword.mace_dims_str).ints + for i in xrange(len(dims)): + output_shape[i] = self._output_shape_cache[op.input[0]][dims[i]] + self.add_output_shape(op, [output_shape]) + + def infer_shape_prior_box(self, op): + output_shape = [1, 2, 1] + input_shape = list(self._output_shape_cache[op.input[0]]) + input_w = input_shape[3] + input_h = input_shape[2] + min_size = ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats # noqa + max_size = ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats # noqa + aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa + flip = ConverterUtil.get_arg(op, MaceKeyword.mace_flip_str).i # noqa + num_prior = (len(min_size) * len(aspect_ratio) + + len(min_size) + len(max_size)) + if flip: + num_prior = num_prior + 
len(min_size) * len(aspect_ratio) + output_shape[2] = num_prior * input_h * input_w * 4 + self.add_output_shape(op, [output_shape]) + + def infer_shape_reshape(self, op): + if ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str) is not None: + dim = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str).ints + output_shape = list(dim) + product = input_size = 1 + idx = -1 + for i in range(len(self._output_shape_cache[op.input[0]])): + input_size *= self._output_shape_cache[op.input[0]][i] + for i in range(len(dim)): + if dim[i] == 0: + output_shape[i] = self._output_shape_cache[op.input[0]][i] + product *= self._output_shape_cache[op.input[0]][i] + elif dim[i] == -1: + idx = i + output_shape[i] = 1 + else: + output_shape[i] = dim[i] + product *= dim[i] + if idx != -1: + output_shape[idx] = input_size / product + self.add_output_shape(op, [output_shape]) + else: + output_shape = list(self._output_shape_cache[op.input[0]]) + axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i + end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i # noqa + if end_axis < 0: + end_axis = len(output_shape) + end_axis + dim = 1 + for i in range(0, axis): + output_shape[i] = self._output_shape_cache[op.input[0]][i] + for i in range(axis, end_axis + 1): + dim *= self._output_shape_cache[op.input[0]][i] + output_shape[i] = 1 + for i in range(end_axis + 1, len(output_shape)): + output_shape[i] = self._output_shape_cache[op.input[0]][i] + output_shape[axis] = dim + self.add_output_shape(op, [output_shape]) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 1e288b05f4a9875fb4f319a284051a4dd08dabba..5e564fa41a36529294c39837d41acf0a0ef3e653 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -95,6 +95,8 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, TransformerRule.CHECK_QUANTIZE_INFO: self.check_quantize_info, + TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: + self.transform_caffe_reshape_and_flatten, } self._option = option @@ -979,8 +981,9 @@ class Transformer(base_converter.ConverterInterface): elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: - if ConverterUtil.data_format(op) == DataFormat.NCHW \ - and self._target_data_format == DataFormat.NHWC: # noqa + if (ConverterUtil.data_format(op) == DataFormat.NCHW + and self._target_data_format == DataFormat.NHWC + and len(op.output_shape[0].dims) == 4): print("Transpose concat/split args: %s(%s)" % (op.name, op.type)) if arg.i == 1: @@ -1231,6 +1234,7 @@ class Transformer(base_converter.ConverterInterface): # transform input(4D) -> reshape(2D) -> matmul to fc # work for TensorFlow if op.type == MaceOp.Reshape.name and \ + len(op.input) == 2 and \ op.input[1] in self._consts and \ len(op.output_shape[0].dims) == 2 and \ filter_format == FilterFormat.HWIO and \ @@ -1714,3 +1718,35 @@ class Transformer(base_converter.ConverterInterface): output_info.name = check_node.name output_info.dims.extend(check_node.output_shape[0].dims) output_info.data_type = mace_pb2.DT_FLOAT + + def transform_caffe_reshape_and_flatten(self): + net = self._model + for op in net.op: + if op.type == MaceOp.Reshape.name and \ + len(op.input) == 1: + print("Transform Caffe Reshape") + if op.arg[3].name == 'dim': + shape_tensor = net.tensors.add() + shape_tensor.name = op.name + '_shape' + 
shape_tensor.dims.append(len(op.output_shape[0].dims)) + shape_tensor.data_type = mace_pb2.DT_INT32 + shape_tensor.int32_data.extend(op.arg[3].ints) + op.input.append(shape_tensor.name) + else: + axis = op.arg[3].i + dims = [1] * len(op.output_shape[0].dims) + end_axis = op.arg[4].i + end_axis = end_axis if end_axis >= 0 else end_axis + len(dims) # noqa + for i in range(0, axis): + dims[i] = 0 + for i in range(axis + 1, end_axis + 1): + dims[i] = 1 + for i in range(end_axis + 1, len(dims)): + dims[i] = 0 + dims[axis] = -1 + shape_tensor = net.tensors.add() + shape_tensor.name = op.name + '_shape' + shape_tensor.dims.append(len(dims)) + shape_tensor.data_type = mace_pb2.DT_INT32 + shape_tensor.int32_data.extend(dims) + op.input.append(shape_tensor.name) diff --git a/mace/utils/detection_output.cc b/mace/utils/detection_output.cc new file mode 100644 index 0000000000000000000000000000000000000000..a268f3504796b71e9cfe884e97fb83a665d18698 --- /dev/null +++ b/mace/utils/detection_output.cc @@ -0,0 +1,159 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "mace/utils/logging.h" + +namespace mace { + +struct BBox { + float xmin; + float ymin; + float xmax; + float ymax; + int label; + float confidence; +}; + +namespace { +inline float overlap(const BBox &a, const BBox &b) { + if (a.xmin > b.xmax || a.xmax < b.xmin || + a.ymin > b.ymax || a.ymax < b.ymin) { + return 0.f; + } + float overlap_w = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin); + float overlap_h = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin); + return overlap_w * overlap_h; +} + +void NmsSortedBboxes(const std::vector &bboxes, + const float nms_threshold, + const int top_k, + std::vector *sorted_boxes) { + const int n = std::min(top_k, static_cast(bboxes.size())); + std::vector picked; + + std::vector areas(n); +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < n; ++i) { + const BBox &r = bboxes[i]; + float width = std::max(0.f, r.xmax - r.xmin); + float height = std::max(0.f, r.ymax - r.ymin); + areas[i] = width * height; + } + + for (int i = 0; i < n; ++i) { + const BBox &a = bboxes[i]; + int keep = 1; + for (size_t j = 0; j < picked.size(); ++j) { + const BBox &b = bboxes[picked[j]]; + + float inter_area = overlap(a, b); + float union_area = areas[i] + areas[picked[j]] - inter_area; + MACE_CHECK(union_area > 0, "union_area should be greater than 0"); + if (inter_area / union_area > nms_threshold) { + keep = 0; + break; + } + } + + if (keep) { + picked.push_back(i); + sorted_boxes->push_back(bboxes[i]); + } + } +} + +inline bool cmp(const BBox &a, const BBox &b) { + return a.confidence > b.confidence; +} +} // namespace + +int DetectionOutput(const float *loc_ptr, + const float *conf_ptr, + const float *pbox_ptr, + const int num_prior, + const int num_classes, + const float nms_threshold, + const int top_k, + const int keep_top_k, + const float confidence_threshold, + 
std::vector *bbox_rects) { + MACE_CHECK(keep_top_k > 0, "keep_top_k should be greater than 0"); + std::vector bboxes(4 * num_prior); +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < num_prior; ++i) { + int index = i * 4; + const float *lc = loc_ptr + index; + const float *pb = pbox_ptr + index; + const float *var = pb + num_prior * 4; + + float pb_w = pb[2] - pb[0]; + float pb_h = pb[3] - pb[1]; + float pb_cx = (pb[0] + pb[2]) * 0.5f; + float pb_cy = (pb[1] + pb[3]) * 0.5f; + + float bbox_cx = var[0] * lc[0] * pb_w + pb_cx; + float bbox_cy = var[1] * lc[1] * pb_h + pb_cy; + float bbox_w = std::exp(var[2] * lc[2]) * pb_w; + float bbox_h = std::exp(var[3] * lc[3]) * pb_h; + + bboxes[0 + index] = bbox_cx - bbox_w * 0.5f; + bboxes[1 + index] = bbox_cy - bbox_h * 0.5f; + bboxes[2 + index] = bbox_cx + bbox_w * 0.5f; + bboxes[3 + index] = bbox_cy + bbox_h * 0.5f; + } + // start from 1 to ignore background class + + for (int i = 1; i < num_classes; ++i) { + // filter by confidence threshold + std::vector class_bbox_rects; + for (int j = 0; j < num_prior; ++j) { + float confidence = conf_ptr[j * num_classes + i]; + if (confidence > confidence_threshold) { + BBox c = {bboxes[0 + j * 4], bboxes[1 + j * 4], bboxes[2 + j * 4], + bboxes[3 + j * 4], i, confidence}; + class_bbox_rects.push_back(c); + } + } + std::sort(class_bbox_rects.begin(), class_bbox_rects.end(), cmp); + + // apply nms + std::vector sorted_boxes; + NmsSortedBboxes(class_bbox_rects, + nms_threshold, + std::min(top_k, + static_cast(class_bbox_rects.size())), + &sorted_boxes); + // gather + bbox_rects->insert(bbox_rects->end(), sorted_boxes.begin(), + sorted_boxes.end()); + } + + std::sort(bbox_rects->begin(), bbox_rects->end(), cmp); + + // output + int num_detected = keep_top_k < static_cast(bbox_rects->size()) ? + keep_top_k : static_cast(bbox_rects->size()); + bbox_rects->resize(num_detected); + + return num_detected; +} +} // namespace mace diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto index 54ccf20ca2378f7e15881930333f0014c1923b63..b2d56b9898fbcfd0bbd31d7d1356aea12ce87445 100644 --- a/third_party/caffe/caffe.proto +++ b/third_party/caffe/caffe.proto @@ -40,6 +40,120 @@ message Datum { optional bool encoded = 7 [default = false]; } +// The label (display) name and label id. +message LabelMapItem { + // Both name and label are required. + optional string name = 1; + optional int32 label = 2; + // display_name is optional. + optional string display_name = 3; +} + +message LabelMap { + repeated LabelMapItem item = 1; +} + +// Sample a bbox in the normalized space [0, 1] with provided constraints. +message Sampler { + // Minimum scale of the sampled bbox. + optional float min_scale = 1 [default = 1.]; + // Maximum scale of the sampled bbox. + optional float max_scale = 2 [default = 1.]; + + // Minimum aspect ratio of the sampled bbox. + optional float min_aspect_ratio = 3 [default = 1.]; + // Maximum aspect ratio of the sampled bbox. + optional float max_aspect_ratio = 4 [default = 1.]; +} + +// Constraints for selecting sampled bbox. +message SampleConstraint { + // Minimum Jaccard overlap between sampled bbox and all bboxes in + // AnnotationGroup. + optional float min_jaccard_overlap = 1; + // Maximum Jaccard overlap between sampled bbox and all bboxes in + // AnnotationGroup. + optional float max_jaccard_overlap = 2; + + // Minimum coverage of sampled bbox by all bboxes in AnnotationGroup. 
+ optional float min_sample_coverage = 3; + // Maximum coverage of sampled bbox by all bboxes in AnnotationGroup. + optional float max_sample_coverage = 4; + + // Minimum coverage of all bboxes in AnnotationGroup by sampled bbox. + optional float min_object_coverage = 5; + // Maximum coverage of all bboxes in AnnotationGroup by sampled bbox. + optional float max_object_coverage = 6; +} + +// Sample a batch of bboxes with provided constraints. +message BatchSampler { + // Use original image as the source for sampling. + optional bool use_original_image = 1 [default = true]; + + // Constraints for sampling bbox. + optional Sampler sampler = 2; + + // Constraints for determining if a sampled bbox is positive or negative. + optional SampleConstraint sample_constraint = 3; + + // If provided, break when found certain number of samples satisfing the + // sample_constraint. + optional uint32 max_sample = 4; + + // Maximum number of trials for sampling to avoid infinite loop. + optional uint32 max_trials = 5 [default = 100]; +} + +// Condition for emitting annotations. +message EmitConstraint { + enum EmitType { + CENTER = 0; + MIN_OVERLAP = 1; + } + optional EmitType emit_type = 1 [default = CENTER]; + // If emit_type is MIN_OVERLAP, provide the emit_overlap. + optional float emit_overlap = 2; +} + +// The normalized bounding box [0, 1] w.r.t. the input image size. +message NormalizedBBox { + optional float xmin = 1; + optional float ymin = 2; + optional float xmax = 3; + optional float ymax = 4; + optional int32 label = 5; + optional bool difficult = 6; + optional float score = 7; + optional float size = 8; +} + +// Annotation for each object instance. +message Annotation { + optional int32 instance_id = 1 [default = 0]; + optional NormalizedBBox bbox = 2; +} + +// Group of annotations for a particular label. +message AnnotationGroup { + optional int32 group_label = 1; + repeated Annotation annotation = 2; +} + +// An extension of Datum which contains "rich" annotations. +message AnnotatedDatum { + enum AnnotationType { + BBOX = 0; + } + optional Datum datum = 1; + // If there are "rich" annotations, specify the type of annotation. + // Currently it only supports bounding box. + // If there are no "rich" annotations, use label in datum instead. + optional AnnotationType type = 2; + // Each group contains annotation for a particular class. + repeated AnnotationGroup annotation_group = 3; +} + message FillerParameter { // The filler type. optional string type = 1 [default = 'constant']; @@ -98,7 +212,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 43 (last added: weights) +// SolverParameter next available ID: 44 (last added: plateau_winsize) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -128,12 +242,24 @@ message SolverParameter { // The states for the train/test nets. Must be unspecified or // specified once per net. // - // By default, train_state will have phase = TRAIN, + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, // and all test_state's will have phase = TEST. // Other defaults are set according to the NetState defaults. optional NetState train_state = 26; repeated NetState test_state = 27; + // Evaluation type. + optional string eval_type = 41 [default = "classification"]; + // ap_version: different ways of computing Average Precision. 
+ // Check https://sanchom.wordpress.com/tag/average-precision/ for details. + // 11point: the 11-point interpolated average precision. Used in VOC2007. + // MaxIntegral: maximally interpolated AP. Used in VOC2012/ILSVRC. + // Integral: the natural integral of the precision-recall curve. + optional string ap_version = 42 [default = "Integral"]; + // If true, display per class result. + optional bool show_per_class_result = 44 [default = false]; + // The number of iterations for each test net. repeated int32 test_iter = 3; @@ -165,6 +291,8 @@ message SolverParameter { // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) // - sigmoid: the effective learning rate follows a sigmod decay // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // - plateau: decreases lr + // if the minimum loss isn't updated for 'plateau_winsize' iters // // where base_lr, max_iter, gamma, step, stepvalue and power are defined // in the solver parameter protocol buffer, and iter is the current iteration. @@ -180,17 +308,15 @@ message SolverParameter { optional int32 stepsize = 13; // the stepsize for learning rate policy "multistep" repeated int32 stepvalue = 34; + // the stepsize for learning rate policy "plateau" + repeated int32 plateau_winsize = 43; // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, // whenever their actual L2 norm is larger. optional float clip_gradients = 35 [default = -1]; optional int32 snapshot = 14 [default = 0]; // The snapshot interval - // The prefix for the snapshot. - // If not set then is replaced by prototxt file path without extention. - // If is set to directory then is augmented by prototxt file name - // without extention. - optional string snapshot_prefix = 15; + optional string snapshot_prefix = 15; // The prefix for the snapshot. // whether to snapshot diff in the results or not. Snapshotting diff will help // debugging but the final protocol buffer size will be much larger. optional bool snapshot_diff = 16 [default = false]; @@ -242,19 +368,6 @@ message SolverParameter { } // DEPRECATED: use type instead of solver_type optional SolverType solver_type = 30 [default = SGD]; - - // Overlap compute and communication for data parallel training - optional bool layer_wise_reduce = 41 [default = true]; - - // Path to caffemodel file(s) with pretrained weights to initialize finetuning. - // Tha same as command line --weights parameter for caffe train command. - // If command line --weights parameter if specified, it has higher priority - // and owerwrites this one(s). - // If --snapshot command line parameter is specified, this one(s) are ignored. - // If several model files are expected, they can be listed in a one - // weights parameter separated by ',' (like in a command string) or - // in repeated weights parameters separately. - repeated string weights = 42; } // A message that stores the solver snapshots @@ -263,6 +376,8 @@ message SolverState { optional string learned_net = 2; // The file that stores the learned net. repeated BlobProto history = 3; // The history for sgd solvers optional int32 current_step = 4 [default = 0]; // The current step for learning rate + optional float minimum_loss = 5 [default = 1E38]; // Historical minimum loss + optional int32 iter_last_event = 6 [default = 0]; // The iteration when last lr-update or min_loss-update happend } enum Phase { @@ -375,6 +490,7 @@ message LayerParameter { // engine parameter for selecting the implementation. 
// The default for the engine is set by the ENGINE switch at compile-time. optional AccuracyParameter accuracy_param = 102; + optional AnnotatedDataParameter annotated_data_param = 200; optional ArgMaxParameter argmax_param = 103; optional BatchNormParameter batch_norm_param = 139; optional BiasParameter bias_param = 141; @@ -383,6 +499,8 @@ message LayerParameter { optional ConvolutionParameter convolution_param = 106; optional CropParameter crop_param = 144; optional DataParameter data_param = 107; + optional DetectionEvaluateParameter detection_evaluate_param = 205; + optional DetectionOutputParameter detection_output_param = 204; optional DropoutParameter dropout_param = 108; optional DummyDataParameter dummy_data_param = 109; optional EltwiseParameter eltwise_param = 110; @@ -400,13 +518,15 @@ message LayerParameter { optional LogParameter log_param = 134; optional LRNParameter lrn_param = 118; optional MemoryDataParameter memory_data_param = 119; + optional MultiBoxLossParameter multibox_loss_param = 201; optional MVNParameter mvn_param = 120; + optional NormalizeParameter norm_param = 206; optional ParameterParameter parameter_param = 145; + optional PermuteParameter permute_param = 202; optional PoolingParameter pooling_param = 121; optional PowerParameter power_param = 122; - optional ProposalParameter proposal_param = 8266713; optional PReLUParameter prelu_param = 131; - optional PSROIAlignParameter psroi_align_param = 1490; + optional PriorBoxParameter prior_box_param = 203; optional PythonParameter python_param = 130; optional RecurrentParameter recurrent_param = 146; optional ReductionParameter reduction_param = 136; @@ -420,6 +540,7 @@ message LayerParameter { optional TanHParameter tanh_param = 127; optional ThresholdParameter threshold_param = 128; optional TileParameter tile_param = 138; + optional VideoDataParameter video_data_param = 207; optional WindowDataParameter window_data_param = 129; optional ShuffleChannelParameter shuffle_channel_param = 164; } @@ -435,9 +556,12 @@ message TransformationParameter { optional bool mirror = 2 [default = false]; // Specify if we would like to randomly crop an image. optional uint32 crop_size = 3 [default = 0]; + optional uint32 crop_h = 11 [default = 0]; + optional uint32 crop_w = 12 [default = 0]; + // mean_file and mean_value cannot be specified at the same time optional string mean_file = 4; - // if specified can be repeated once (would subtract it from all the channels) + // if specified can be repeated once (would substract it from all the channels) // or can be repeated the same number of times as channels // (would subtract them from the corresponding channel) repeated float mean_value = 5; @@ -445,6 +569,141 @@ message TransformationParameter { optional bool force_color = 6 [default = false]; // Force the decoded image to have 1 color channels. optional bool force_gray = 7 [default = false]; + // Resize policy + optional ResizeParameter resize_param = 8; + // Noise policy + optional NoiseParameter noise_param = 9; + // Distortion policy + optional DistortionParameter distort_param = 13; + // Expand policy + optional ExpansionParameter expand_param = 14; + // Constraint for emitting the annotation after transformation. 
+ optional EmitConstraint emit_constraint = 10; +} + +// Message that stores parameters used by data transformer for resize policy +message ResizeParameter { + //Probability of using this resize policy + optional float prob = 1 [default = 1]; + + enum Resize_mode { + WARP = 1; + FIT_SMALL_SIZE = 2; + FIT_LARGE_SIZE_AND_PAD = 3; + } + optional Resize_mode resize_mode = 2 [default = WARP]; + optional uint32 height = 3 [default = 0]; + optional uint32 width = 4 [default = 0]; + // A parameter used to update bbox in FIT_SMALL_SIZE mode. + optional uint32 height_scale = 8 [default = 0]; + optional uint32 width_scale = 9 [default = 0]; + + enum Pad_mode { + CONSTANT = 1; + MIRRORED = 2; + REPEAT_NEAREST = 3; + } + // Padding mode for BE_SMALL_SIZE_AND_PAD mode and object centering + optional Pad_mode pad_mode = 5 [default = CONSTANT]; + // if specified can be repeated once (would fill all the channels) + // or can be repeated the same number of times as channels + // (would use it them to the corresponding channel) + repeated float pad_value = 6; + + enum Interp_mode { //Same as in OpenCV + LINEAR = 1; + AREA = 2; + NEAREST = 3; + CUBIC = 4; + LANCZOS4 = 5; + } + //interpolation for for resizing + repeated Interp_mode interp_mode = 7; +} + +message SaltPepperParameter { + //Percentage of pixels + optional float fraction = 1 [default = 0]; + repeated float value = 2; +} + +// Message that stores parameters used by data transformer for transformation +// policy +message NoiseParameter { + //Probability of using this resize policy + optional float prob = 1 [default = 0]; + // Histogram equalized + optional bool hist_eq = 2 [default = false]; + // Color inversion + optional bool inverse = 3 [default = false]; + // Grayscale + optional bool decolorize = 4 [default = false]; + // Gaussian blur + optional bool gauss_blur = 5 [default = false]; + + // JPEG compression quality (-1 = no compression) + optional float jpeg = 6 [default = -1]; + + // Posterization + optional bool posterize = 7 [default = false]; + + // Erosion + optional bool erode = 8 [default = false]; + + // Salt-and-pepper noise + optional bool saltpepper = 9 [default = false]; + + optional SaltPepperParameter saltpepper_param = 10; + + // Local histogram equalization + optional bool clahe = 11 [default = false]; + + // Color space conversion + optional bool convert_to_hsv = 12 [default = false]; + + // Color space conversion + optional bool convert_to_lab = 13 [default = false]; +} + +// Message that stores parameters used by data transformer for distortion policy +message DistortionParameter { + // The probability of adjusting brightness. + optional float brightness_prob = 1 [default = 0.0]; + // Amount to add to the pixel values within [-delta, delta]. + // The possible value is within [0, 255]. Recommend 32. + optional float brightness_delta = 2 [default = 0.0]; + + // The probability of adjusting contrast. + optional float contrast_prob = 3 [default = 0.0]; + // Lower bound for random contrast factor. Recommend 0.5. + optional float contrast_lower = 4 [default = 0.0]; + // Upper bound for random contrast factor. Recommend 1.5. + optional float contrast_upper = 5 [default = 0.0]; + + // The probability of adjusting hue. + optional float hue_prob = 6 [default = 0.0]; + // Amount to add to the hue channel within [-delta, delta]. + // The possible value is within [0, 180]. Recommend 36. + optional float hue_delta = 7 [default = 0.0]; + + // The probability of adjusting saturation. 
+ optional float saturation_prob = 8 [default = 0.0]; + // Lower bound for the random saturation factor. Recommend 0.5. + optional float saturation_lower = 9 [default = 0.0]; + // Upper bound for the random saturation factor. Recommend 1.5. + optional float saturation_upper = 10 [default = 0.0]; + + // The probability of randomly order the image channels. + optional float random_order_prob = 11 [default = 0.0]; +} + +// Message that stores parameters used by data transformer for expansion policy +message ExpansionParameter { + //Probability of using this expansion policy + optional float prob = 1 [default = 1]; + + // The ratio to expand the image. + optional float max_expand_ratio = 2 [default = 1.]; } // Message that stores parameters shared by loss layers @@ -496,6 +755,16 @@ message AccuracyParameter { optional int32 ignore_label = 3; } +message AnnotatedDataParameter { + // Define the sampler. + repeated BatchSampler batch_sampler = 1; + // Store label name and label id in LabelMap format. + optional string label_map_file = 2; + // If provided, it will replace the AnnotationType stored in each + // AnnotatedDatum. + optional AnnotatedDatum.AnnotationType anno_type = 3; +} + message ArgMaxParameter { // If true produce pairs (argmax, maxval) optional bool out_max_val = 1 [default = false]; @@ -519,21 +788,11 @@ message ConcatParameter { } message BatchNormParameter { - // If false, normalization is performed over the current mini-batch - // and global statistics are accumulated (but not yet used) by a moving - // average. - // If true, those accumulated mean and variance values are used for the - // normalization. - // By default, it is set to false when the network is in the training - // phase and true when the network is in the testing phase. + // If false, accumulate global mean/variance values via a moving average. If + // true, use those accumulated values instead of computing mean/variance + // across the batch. optional bool use_global_stats = 1; - // What fraction of the moving average remains each iteration? - // Smaller values make the moving average decay faster, giving more - // weight to the recent values. - // Each iteration updates the moving average @f$S_{t-1}@f$ with the - // current mean @f$ Y_t @f$ by - // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ - // is the moving_average_fraction parameter. + // How much does the moving average decay each iteration? optional float moving_average_fraction = 2 [default = .999]; // Small value to add to the variance estimate so that we don't divide by // zero. @@ -684,11 +943,100 @@ message DataParameter { optional bool mirror = 6 [default = false]; // Force the encoded image to have 3 color channels optional bool force_encoded_color = 9 [default = false]; - // Prefetch queue (Increase if data feeding bandwidth varies, within the - // limit of device memory for GPU training) + // Prefetch queue (Number of batches to prefetch to host memory, increase if + // data access bandwidth varies). optional uint32 prefetch = 10 [default = 4]; } +// Message that store parameters used by DetectionEvaluateLayer +message DetectionEvaluateParameter { + // Number of classes that are actually predicted. Required! + optional uint32 num_classes = 1; + // Label id for background class. Needed for sanity check so that + // background class is neither in the ground truth nor the detections. + optional uint32 background_label_id = 2 [default = 0]; + // Threshold for deciding true/false positive. 
+ optional float overlap_threshold = 3 [default = 0.5]; + // If true, also consider difficult ground truth for evaluation. + optional bool evaluate_difficult_gt = 4 [default = true]; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + // If provided, we will scale the prediction and ground truth NormalizedBBox + // for evaluation. + optional string name_size_file = 5; + // The resize parameter used in converting NormalizedBBox to original image. + optional ResizeParameter resize_param = 6; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + optional ResizeParameter resize_param = 7; +} + +// Message that store parameters used by DetectionOutputLayer +message DetectionOutputParameter { + // Number of classes to be predicted. Required! + optional uint32 num_classes = 1; + // If true, bounding box are shared among different classes. + optional bool share_location = 2 [default = true]; + // Background label id. If there is no background class, + // set it as -1. + optional int32 background_label_id = 3 [default = 0]; + // Parameters used for non maximum suppression. + optional NonMaximumSuppressionParameter nms_param = 4; + // Parameters used for saving detection results. + optional SaveOutputParameter save_output_param = 5; + // Type of coding method for bbox. + optional PriorBoxParameter.CodeType code_type = 6 [default = CORNER]; + // If true, variance is encoded in target; otherwise we need to adjust the + // predicted offset accordingly. + optional bool variance_encoded_in_target = 8 [default = false]; + // Number of total bboxes to be kept per image after nms step. + // -1 means keeping all bboxes after nms step. + optional int32 keep_top_k = 7 [default = -1]; + // Only consider detections whose confidences are larger than a threshold. + // If not provided, consider all boxes. + optional float confidence_threshold = 9; + // If true, visualize the detection results. + optional bool visualize = 10 [default = false]; + // The threshold used to visualize the detection results. + optional float visualize_threshold = 11; + // If provided, save outputs to video file. 
+ optional string save_file = 12; +} + message DropoutParameter { optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio } @@ -832,7 +1180,6 @@ message ImageDataParameter { message InfogainLossParameter { // Specify the infogain matrix source. optional string source = 1; - optional int32 axis = 2 [default = 1]; // axis of prob } message InnerProductParameter { @@ -896,6 +1243,78 @@ message MemoryDataParameter { optional uint32 width = 4; } +// Message that store parameters used by MultiBoxLossLayer +message MultiBoxLossParameter { + // Localization loss type. + enum LocLossType { + L2 = 0; + SMOOTH_L1 = 1; + } + optional LocLossType loc_loss_type = 1 [default = SMOOTH_L1]; + // Confidence loss type. + enum ConfLossType { + SOFTMAX = 0; + LOGISTIC = 1; + } + optional ConfLossType conf_loss_type = 2 [default = SOFTMAX]; + // Weight for localization loss. + optional float loc_weight = 3 [default = 1.0]; + // Number of classes to be predicted. Required! + optional uint32 num_classes = 4; + // If true, bounding box are shared among different classes. + optional bool share_location = 5 [default = true]; + // Matching method during training. + enum MatchType { + BIPARTITE = 0; + PER_PREDICTION = 1; + } + optional MatchType match_type = 6 [default = PER_PREDICTION]; + // If match_type is PER_PREDICTION, use overlap_threshold to + // determine the extra matching bboxes. + optional float overlap_threshold = 7 [default = 0.5]; + // Use prior for matching. + optional bool use_prior_for_matching = 8 [default = true]; + // Background label id. + optional uint32 background_label_id = 9 [default = 0]; + // If true, also consider difficult ground truth. + optional bool use_difficult_gt = 10 [default = true]; + // If true, perform negative mining. + // DEPRECATED: use mining_type instead. + optional bool do_neg_mining = 11; + // The negative/positive ratio. + optional float neg_pos_ratio = 12 [default = 3.0]; + // The negative overlap upperbound for the unmatched predictions. + optional float neg_overlap = 13 [default = 0.5]; + // Type of coding method for bbox. + optional PriorBoxParameter.CodeType code_type = 14 [default = CORNER]; + // If true, encode the variance of prior box in the loc loss target instead of + // in bbox. + optional bool encode_variance_in_target = 16 [default = false]; + // If true, map all object classes to agnostic class. It is useful for learning + // objectness detector. + optional bool map_object_to_agnostic = 17 [default = false]; + // If true, ignore cross boundary bbox during matching. + // Cross boundary bbox is a bbox who is outside of the image region. + optional bool ignore_cross_boundary_bbox = 18 [default = false]; + // If true, only backpropagate on corners which are inside of the image + // region when encode_type is CORNER or CORNER_SIZE. + optional bool bp_inside = 19 [default = false]; + // Mining type during training. + // NONE : use all negatives. + // MAX_NEGATIVE : select negatives based on the score. + // HARD_EXAMPLE : select hard examples based on "Training Region-based Object Detectors with Online Hard Example Mining", Shrivastava et.al. + enum MiningType { + NONE = 0; + MAX_NEGATIVE = 1; + HARD_EXAMPLE = 2; + } + optional MiningType mining_type = 20 [default = MAX_NEGATIVE]; + // Parameters used for non maximum suppression durig hard example mining. 
+ optional NonMaximumSuppressionParameter nms_param = 21; + optional int32 sample_size = 22 [default = 64]; + optional bool use_prior_for_nms = 23 [default = false]; +} + message MVNParameter { // This parameter can be set to false to normalize mean only optional bool normalize_variance = 1 [default = true]; @@ -907,10 +1326,28 @@ message MVNParameter { optional float eps = 3 [default = 1e-9]; } +// Message that stores parameters used by NormalizeLayer +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + message ParameterParameter { optional BlobShape shape = 1; } +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + message PoolingParameter { enum PoolMethod { MAX = 0; @@ -947,17 +1384,46 @@ message PowerParameter { optional float shift = 3 [default = 0.0]; } -// Message that stores parameters used by ProposalLayer -message ProposalParameter { - optional uint32 feat_stride = 1 [default = 16]; - repeated uint32 scales = 2; - repeated float ratios = 3; -} - -message PSROIAlignParameter { - required float spatial_scale = 1; - required int32 output_dim = 2; // output channel number - required int32 group_size = 3; // number of groups to encode position-sensitive score maps +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; } message PythonParameter { @@ -968,7 +1434,9 @@ message PythonParameter { // string, dictionary in Python dict format, JSON, etc. You may parse this // string in `setup` method and use it in `forward` and `backward`. 
optional string param_str = 3 [default = '']; - // DEPRECATED + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. optional bool share_in_parallel = 4 [default = false]; } @@ -1195,6 +1663,18 @@ message ThresholdParameter { optional float threshold = 1 [default = 0]; // Strictly positive values } +message VideoDataParameter{ + enum VideoType { + WEBCAM = 0; + VIDEO = 1; + } + optional VideoType video_type = 1 [default = WEBCAM]; + optional int32 device_id = 2 [default = 0]; + optional string video_file = 3; + // Number of frames to be skipped before processing a frame. + optional uint32 skip_frames = 4 [default = 0]; +} + message WindowDataParameter { // Specify the data source. optional string source = 1; @@ -1437,7 +1917,7 @@ message PReLUParameter { // Initial value of a_i. Default is a_i=0.25 for all i. optional FillerParameter filler = 1; - // Whether or not slope parameters are shared across channels. + // Whether or not slope paramters are shared across channels. optional bool channel_shared = 2 [default = false]; }
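
Reviewer notes (illustrative sketches, not part of the diff):

The output-shape rule implemented by PriorBoxOp and mirrored by infer_shape_prior_box is easy to sanity-check in isolation. Below is a minimal Python sketch of that rule; the helper name is mine, and the assert reproduces the configuration used in prior_box_test.cc (one min and one max size, aspect ratios {2, 3} with flip, on a 1x1 feature map).

```python
def prior_box_output_shape(feat_h, feat_w, min_sizes, max_sizes,
                           aspect_ratios, flip=True):
    # num_prior = |min| * |ar| + |min| + |max|; flip doubles the ar term.
    num_prior = len(min_sizes) * len(aspect_ratios) \
        + len(min_sizes) + len(max_sizes)
    if flip:
        num_prior += len(min_sizes) * len(aspect_ratios)
    # Two "channels" are emitted (boxes and variances), 4 values per prior.
    return [1, 2, 4 * feat_h * feat_w * num_prior]


# Matches the {1, 2, 24} expected tensor in prior_box_test.cc.
assert prior_box_output_shape(1, 1, [285], [300], [2, 3]) == [1, 2, 24]
```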
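transform_caffe_reshape_and_flatten rewrites a Caffe Flatten layer into a Reshape with a constant shape tensor in which 0 means "copy the input dimension" and a single -1 absorbs the flattened range [axis, end_axis]. A hedged sketch of that mapping (function name is illustrative):

```python
def flatten_to_reshape_dims(rank, axis=1, end_axis=-1):
    # 0 keeps the corresponding input dim, -1 is inferred by ReshapeOp,
    # and dims inside (axis, end_axis] collapse to 1.
    end_axis = end_axis if end_axis >= 0 else end_axis + rank
    dims = [0] * rank
    for i in range(axis + 1, end_axis + 1):
        dims[i] = 1
    dims[axis] = -1
    return dims


# Flatten(axis=1) on an NCHW blob -> shape tensor [0, -1, 1, 1],
# i.e. (N, C, H, W) reshaped to (N, C*H*W, 1, 1).
assert flatten_to_reshape_dims(4) == [0, -1, 1, 1]
```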
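The ReshapeOp change and infer_shape_reshape then resolve such a shape argument against the actual input shape. A small sketch of the resolution rule, assuming at most one -1 entry and 0 entries indexed against the input shape (helper name illustrative):

```python
def resolve_reshape(input_shape, shape_arg):
    out, known_product, unknown = [], 1, -1
    for i, d in enumerate(shape_arg):
        if d == -1:                      # to be inferred
            unknown = i
            out.append(1)
        elif d == 0:                     # copy input dimension i
            out.append(input_shape[i])
            known_product *= input_shape[i]
        else:
            out.append(d)
            known_product *= d
    if unknown >= 0:
        total = 1
        for d in input_shape:
            total *= d
        out[unknown] = total // known_product
    return out


# Flatten-style reshape of an NCHW blob: keep N, infer the collapsed dim.
assert resolve_reshape([1, 24, 19, 19], [0, -1, 1, 1]) == [1, 24 * 19 * 19, 1, 1]
```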
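Finally, mace/utils/detection_output.cc decodes the loc predictions against the PriorBox output using Caffe's CENTER_SIZE convention, with the variances (stored in the second channel of the PriorBox output) applied to the predicted offsets. A standalone sketch of the per-box decode, again with an illustrative helper name:

```python
import math


def decode_center_size(loc, prior, variance):
    # loc, prior, variance are [xmin, ymin, xmax, ymax]-ordered 4-vectors.
    pw, ph = prior[2] - prior[0], prior[3] - prior[1]
    pcx, pcy = (prior[0] + prior[2]) * 0.5, (prior[1] + prior[3]) * 0.5
    cx = variance[0] * loc[0] * pw + pcx
    cy = variance[1] * loc[1] * ph + pcy
    w = math.exp(variance[2] * loc[2]) * pw
    h = math.exp(variance[3] * loc[3]) * ph
    return [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]


# Zero offsets decode back to the prior box itself.
decoded = decode_center_size([0.0] * 4, [0.1, 0.1, 0.3, 0.4], [0.1, 0.1, 0.2, 0.2])
assert all(math.isclose(a, b) for a, b in zip(decoded, [0.1, 0.1, 0.3, 0.4]))
```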