提交 c180e0b4 编写于 作者: L lichao18

Add SSD box predictor module

上级 05f3dc93
......@@ -45,6 +45,7 @@ extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry);
extern void RegisterReduce(OpRegistryBase *op_registry);
extern void RegisterPriorBox(OpRegistryBase *op_registry);
extern void RegisterReshape(OpRegistryBase *op_registry);
extern void RegisterResizeBicubic(OpRegistryBase *op_registry);
extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
......@@ -103,6 +104,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterPad(this);
ops::RegisterPooling(this);
ops::RegisterReduce(this);
ops::RegisterPriorBox(this);
ops::RegisterReshape(this);
ops::RegisterResizeBicubic(this);
ops::RegisterResizeBilinear(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template<DeviceType D, typename T>
class PriorBoxOp : public Operation {
public:
explicit PriorBoxOp(OpConstructContext *context)
: Operation(context),
min_size_(Operation::GetRepeatedArgs<float>("min_size")),
max_size_(Operation::GetRepeatedArgs<float>("max_size")),
aspect_ratio_(Operation::GetRepeatedArgs<float>("aspect_ratio")),
flip_(Operation::GetOptionalArg<bool>("flip", true)),
clip_(Operation::GetOptionalArg<bool>("clip", false)),
variance_(Operation::GetRepeatedArgs<float>("variance")),
offset_(Operation::GetOptionalArg<float>("offset", 0.5)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const Tensor *data = this->Input(DATA);
Tensor *output = this->Output(OUTPUT);
const std::vector<index_t> &input_shape = input->shape();
const std::vector<index_t> &data_shape = data->shape();
const index_t input_w = input_shape[3];
const index_t input_h = input_shape[2];
const index_t image_w = data_shape[3];
const index_t image_h = data_shape[2];
float step_h = static_cast<float>(image_h) / static_cast<float>(input_h);
float step_w = static_cast<float>(image_w) / static_cast<float>(input_w);
if (Operation::GetOptionalArg<float>("step_h", 0) != 0 &&
Operation::GetOptionalArg<float>("step_w", 0) != 0) {
step_h = Operation::GetOptionalArg<float>("step_h", 0);
step_w = Operation::GetOptionalArg<float>("step_w", 0);
}
const index_t num_min_size = min_size_.size();
MACE_CHECK(num_min_size > 0, "min_size is required!");
const index_t num_max_size = max_size_.size();
const index_t num_aspect_ratio = aspect_ratio_.size();
MACE_CHECK(num_aspect_ratio > 0, "aspect_ratio is required!");
index_t num_prior = num_min_size * num_aspect_ratio +
num_min_size + num_max_size;
if (flip_)
num_prior += num_min_size * num_aspect_ratio;
index_t dim = 4 * input_w * input_h * num_prior;
std::vector<index_t> output_shape = {1, 2, dim};
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
float box_w, box_h;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < input_h; ++i) {
for (index_t j = 0; j < input_w; ++j) {
index_t idx = i * input_w * num_prior * 4;
float center_y = (offset_ + i) * step_h;
float center_x = (offset_ + j) * step_w;
for (index_t k = 0; k < num_min_size; ++k) {
float min_s = min_size_[k];
box_w = box_h = min_s * 0.5;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
if (num_max_size > 0) {
float max_s_ = max_size_[k];
box_w = box_h = sqrt(max_s_ * min_s) * 0.5f;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
}
for (int l = 0; l < num_aspect_ratio; ++l) {
float ar = aspect_ratio_[l];
box_w = min_s * sqrt(ar) * 0.5f;
box_h = min_s / sqrt(ar) * 0.5f;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
if (flip_) {
output_data[idx + 0] = (center_x - box_h) / image_w;
output_data[idx + 1] = (center_y - box_w) / image_h;
output_data[idx + 2] = (center_x + box_h) / image_w;
output_data[idx + 3] = (center_y + box_w) / image_h;
idx += 4;
}
}
}
}
}
if (clip_) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < dim; ++i) {
T min = 0;
T max = 1;
output_data[i] = std::min(std::max(output_data[i], min), max);
}
}
output_data += dim;
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < dim / 4; ++i) {
int index = i * 4;
output_data[0 + index] = variance_[0];
output_data[1 + index] = variance_[1];
output_data[2 + index] = variance_[2];
output_data[3 + index] = variance_[3];
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<float> min_size_;
std::vector<float> max_size_;
std::vector<float> aspect_ratio_;
bool flip_;
bool clip_;
std::vector<float> variance_;
const float offset_;
private:
MACE_OP_INPUT_TAGS(INPUT, DATA);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
void RegisterPriorBox(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "PriorBox", PriorBoxOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void PriorBox(
int iters, float min_size, float max_size, float aspect_ratio, int flip,
int clip, float variance0, float variance1, float offset, int h) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("INPUT", {1, h, 1, 1});
net.AddRandomInput<D, float>("DATA", {1, 3, 300, 300});
OpDefBuilder("PriorBox", "PriorBoxBM")
.Input("INPUT")
.Input("DATA")
.Output("OUTPUT")
.AddFloatsArg("min_size", {min_size})
.AddFloatsArg("max_size", {max_size})
.AddFloatsArg("aspect_ratio", {aspect_ratio})
.AddIntArg("flip", flip)
.AddIntArg("clip", clip)
.AddFloatsArg("variance", {variance0, variance0, variance1, variance1})
.AddFloatArg("offset", offset)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
const int64_t tot = static_cast<int64_t>(iters) * (300 * 300 * 3 + h);
mace::testing::MaccProcessed(tot);
testing::BytesProcessed(tot * sizeof(T));
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
} // namespace
#define MACE_BM_PRIOR_BOX(MIN, MAX, AR, FLIP, CLIP, V0, V1, OFFSET, H) \
static void MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_##V0##\
_##V1##_##OFFSET##_##H(int iters) { \
PriorBox<DeviceType::CPU, float>(iters, MIN, MAX, AR, FLIP, CLIP, V0, V1, \
OFFSET, H); \
} \
MACE_BENCHMARK(MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_## \
V0##_##V1##_##OFFSET##_##H)
MACE_BM_PRIOR_BOX(285, 300, 2, 1, 0, 1, 2, 1, 128);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class PriorBoxOpTest : public OpsTestBase {};
TEST_F(PriorBoxOpTest, Simple) {
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("INPUT", {1, 128, 1, 1});
net.AddRandomInput<DeviceType::CPU, float>("DATA", {1, 3, 300, 300});
OpDefBuilder("PriorBox", "PriorBoxTest")
.Input("INPUT")
.Input("DATA")
.Output("OUTPUT")
.AddFloatsArg("min_size", {285})
.AddFloatsArg("max_size", {300})
.AddFloatsArg("aspect_ratio", {2, 3})
.AddIntArg("flip", 1)
.AddIntArg("clip", 0)
.AddFloatsArg("variance", {0.1, 0.1, 0.2, 0.2})
.AddFloatArg("offset", 0.5)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
// Check
auto expected_tensor = net.CreateTensor<float>({1, 2, 24},
{0.025, 0.025, 0.975, 0.975,
0.012660282759551838, 0.012660282759551838,
0.9873397172404482, 0.9873397172404482,
-0.17175144212722018, 0.16412427893638995,
1.1717514421272204, 0.8358757210636101,
0.16412427893638995, -0.17175144212722018,
0.8358757210636101, 1.1717514421272204,
-0.3227241335952166, 0.22575862213492773,
1.3227241335952165, 0.7742413778650723,
0.22575862213492773, -0.3227241335952166,
0.7742413778650723, 1.3227241335952165,
0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2,
0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2});
ExpectTensorNear<float>(*expected_tensor, *net.GetTensor("OUTPUT"));
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -28,6 +28,11 @@ class ReshapeOp : public Operation {
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
int axis = Operation::GetOptionalArg<int>("reshape_axis", 0);
int num_axes = Operation::GetOptionalArg<int>("num_axes", -1);
MACE_CHECK(axis == 0 && num_axes == -1,
"Only support axis = 0 and num_axes = -1");
const Tensor *shape = this->Input(SHAPE);
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
Tensor::MappingGuard shape_guard(shape);
......@@ -43,8 +48,12 @@ class ReshapeOp : public Operation {
MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
unknown_idx = i;
out_shape.push_back(1);
} else if (shape_data[i] == 0) {
MACE_CHECK(shape_data[i] == 0, "Shape should be 0");
out_shape.push_back(input_shape[i]);
product *= input_shape[i];
} else {
MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ",
shape_data[i]);
if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(),
......
......@@ -92,9 +92,17 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
}
} // k
} // b
} else if (input->dim_size() == 2) { // normal 2d softmax
const index_t class_size = input->dim(0);
const index_t class_count = input->dim(1);
} else if (input->dim_size() == 2 || input->dim_size() == 3) {
// normal 2d softmax and 3d softmax (dim(0) is batch)
index_t class_size = 0;
index_t class_count = 0;
if (input->dim_size() == 2) {
class_size = input->dim(0);
class_count = input->dim(1);
} else {
class_size = input->dim(0) * input->dim(1);
class_count = input->dim(2);
}
#pragma omp parallel for schedule(runtime)
for (index_t k = 0; k < class_size; ++k) {
const float *input_ptr = input_data + k * class_count;
......
......@@ -123,6 +123,7 @@ MaceSupportedOps = [
'MatMul',
'Pad',
'Pooling',
'PriorBox',
'Proposal',
'Quantize',
'Reduce',
......@@ -174,8 +175,11 @@ class MaceKeyword(object):
mace_space_batch_block_shape_str = 'block_shape'
mace_space_depth_block_size_str = 'block_size'
mace_constant_value_str = 'constant_value'
mace_dim_str = 'dim'
mace_dims_str = 'dims'
mace_axis_str = 'axis'
mace_end_axis_str = 'end_axis'
mace_num_axes_str = 'num_axes'
mace_num_split_str = 'num_split'
mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape'
......@@ -205,6 +209,14 @@ class MaceKeyword(object):
mace_reduce_type_str = 'reduce_type'
mace_argmin_str = 'argmin'
mace_round_mode_str = 'round_mode'
mace_min_size_str = 'min_size'
mace_max_size_str = 'max_size'
mace_aspect_ratio_str = 'aspect_ratio'
mace_flip_str = 'flip'
mace_clip_str = 'clip'
mace_variance_str = 'variance'
mace_step_h_str = 'step_h'
mace_step_w_str = 'step_w'
class TransformerRule(Enum):
......@@ -243,6 +255,7 @@ class TransformerRule(Enum):
FOLD_SQRDIFF_MEAN = 33
TRANSPOSE_MATMUL_WEIGHT = 34
FOLD_EMBEDDING_LOOKUP = 35
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
class ConverterInterface(object):
......@@ -433,6 +446,7 @@ class ConverterOption(object):
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
......
......@@ -188,6 +188,10 @@ class CaffeConverter(base_converter.ConverterInterface):
'Crop': self.convert_crop,
'Scale': self.convert_scale,
'ShuffleChannel': self.convert_channel_shuffle,
'Permute': self.convert_permute,
'Flatten': self.convert_flatten,
'PriorBox': self.convert_prior_box,
'Reshape': self.convert_reshape,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -565,8 +569,6 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg.i = param.axis
elif param.HasField('concat_dim'):
axis_arg.i = param.concat_dim
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
mace_check(axis_arg.i == 1, "only support concat at channel dimension")
def convert_slice(self, caffe_op):
op = self.convert_general_op(caffe_op)
......@@ -668,3 +670,107 @@ class CaffeConverter(base_converter.ConverterInterface):
group_arg.i = 1
if param.HasField('group'):
group_arg.i = param.group
def convert_permute(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.permute_param
op.type = MaceOp.Transpose.name
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend(list(param.order))
def convert_flatten(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.flatten_param
op.type = MaceOp.Reshape.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
if param.HasField('axis'):
axis_arg.i = param.axis
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
end_axis_arg = op.arg.add()
end_axis_arg.name = MaceKeyword.mace_end_axis_str
end_axis_arg.i = -1
if param.HasField('end_axis'):
end_axis_arg.i = param.end_axis
def convert_prior_box(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.prior_box_param
op.type = MaceOp.PriorBox.name
min_size_arg = op.arg.add()
min_size_arg.name = MaceKeyword.mace_min_size_str
min_size_arg.floats.extend(list(param.min_size))
max_size_arg = op.arg.add()
max_size_arg.name = MaceKeyword.mace_max_size_str
max_size_arg.floats.extend(list(param.max_size))
aspect_ratio_arg = op.arg.add()
aspect_ratio_arg.name = MaceKeyword.mace_aspect_ratio_str
aspect_ratio_arg.floats.extend(list(param.aspect_ratio))
flip_arg = op.arg.add()
flip_arg.name = MaceKeyword.mace_flip_str
flip_arg.i = 1
if param.HasField('flip'):
flip_arg.i = int(param.flip)
clip_arg = op.arg.add()
clip_arg.name = MaceKeyword.mace_clip_str
clip_arg.i = 0
if param.HasField('clip'):
clip_arg.i = int(param.clip)
variance_arg = op.arg.add()
variance_arg.name = MaceKeyword.mace_variance_str
variance_arg.floats.extend(list(param.variance))
offset_arg = op.arg.add()
offset_arg.name = MaceKeyword.mace_offset_str
offset_arg.f = 0.5
if param.HasField('offset'):
offset_arg.f = param.offset
step_h_arg = op.arg.add()
step_h_arg.name = MaceKeyword.mace_step_h_str
step_h_arg.f = 0
if param.HasField('step_h'):
mace_check(not param.HasField('step'),
"Either step or step_h/step_w should be specified; not both.") # noqa
step_h_arg.f = param.step_h
mace_check(step_h_arg.f > 0, "step_h should be larger than 0.")
step_w_arg = op.arg.add()
step_w_arg.name = MaceKeyword.mace_step_w_str
step_w_arg.f = 0
if param.HasField('step_w'):
mace_check(not param.HasField('step'),
"Either step or step_h/step_w should be specified; not both.") # noqa
step_w_arg.f = param.step_w
mace_check(step_w_arg.f > 0, "step_w should be larger than 0.")
if param.HasField('step'):
mace_check(not param.HasField('step_h') and not param.HasField('step_w'), # noqa
"Either step or step_h/step_w should be specified; not both.") # noqa
mace_check(param.step > 0, "step should be larger than 0.")
step_h_arg.f = param.step
step_w_arg.f = param.step
def convert_reshape(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.reshape_param
op.type = MaceOp.Reshape.name
dim_arg = op.arg.add()
dim_arg.name = MaceKeyword.mace_dim_str
dim_arg.ints.extend(list(param.shape.dim))
axis_arg = op.arg.add()
axis_arg.name = 'reshape_' + MaceKeyword.mace_axis_str
axis_arg.i = 0
if param.HasField('axis'):
axis_arg.i = param.axis
num_axes_arg = op.arg.add()
num_axes_arg.name = MaceKeyword.mace_num_axes_str
num_axes_arg.i = -1
if param.HasField('num_axes'):
num_axes_arg.i = param.num_axes
......@@ -49,6 +49,9 @@ class ShapeInference(object):
MaceOp.Crop.name: self.infer_shape_crop,
MaceOp.BiasAdd.name: self.infer_shape_general,
MaceOp.ChannelShuffle.name: self.infer_shape_channel_shuffle,
MaceOp.Transpose.name: self.infer_shape_permute,
MaceOp.PriorBox.name: self.infer_shape_prior_box,
MaceOp.Reshape.name: self.infer_shape_reshape,
}
self._net = net
......@@ -190,12 +193,14 @@ class ShapeInference(object):
self.add_output_shape(op, [output_shape])
def infer_shape_concat(self, op):
output_shape = self._output_shape_cache[op.input[0]]
output_shape = list(self._output_shape_cache[op.input[0]])
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
if axis < 0:
axis = len(output_shape) + axis
output_shape[axis] = 0
for input_node in op.input:
input_shape = self._output_shape_cache[input_node]
output_shape[axis] += input_shape[axis]
input_shape = list(self._output_shape_cache[input_node])
output_shape[axis] = output_shape[axis] + input_shape[axis]
self.add_output_shape(op, [output_shape])
def infer_shape_slice(self, op):
......@@ -225,3 +230,64 @@ class ShapeInference(object):
def infer_shape_channel_shuffle(self, op):
output_shape = self._output_shape_cache[op.input[0]]
self.add_output_shape(op, [output_shape])
def infer_shape_permute(self, op):
output_shape = list(self._output_shape_cache[op.input[0]])
dims = ConverterUtil.get_arg(op, MaceKeyword.mace_dims_str).ints
for i in xrange(len(dims)):
output_shape[i] = self._output_shape_cache[op.input[0]][dims[i]]
self.add_output_shape(op, [output_shape])
def infer_shape_prior_box(self, op):
output_shape = [1, 2, 1]
input_shape = list(self._output_shape_cache[op.input[0]])
input_w = input_shape[3]
input_h = input_shape[2]
min_size = ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats # noqa
max_size = ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats # noqa
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
flip = ConverterUtil.get_arg(op, MaceKeyword.mace_flip_str).i # noqa
num_prior = (len(min_size) * len(aspect_ratio) +
len(min_size) + len(max_size))
if flip:
num_prior = num_prior + len(min_size) * len(aspect_ratio)
output_shape[2] = num_prior * input_h * input_w * 4
self.add_output_shape(op, [output_shape])
def infer_shape_reshape(self, op):
if ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str) is not None:
dim = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str).ints
output_shape = list(dim)
product = input_size = 1
idx = -1
for i in range(len(self._output_shape_cache[op.input[0]])):
input_size *= self._output_shape_cache[op.input[0]][i]
for i in range(len(dim)):
if dim[i] == 0:
output_shape[i] = self._output_shape_cache[op.input[0]][i]
product *= self._output_shape_cache[op.input[0]][i]
elif dim[i] == -1:
idx = i
output_shape[i] = 1
else:
output_shape[i] = dim[i]
product *= dim[i]
if idx != -1:
output_shape[idx] = input_size / product
self.add_output_shape(op, [output_shape])
else:
output_shape = list(self._output_shape_cache[op.input[0]])
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i # noqa
if end_axis < 0:
end_axis = len(output_shape) + end_axis
dim = 1
for i in range(0, axis):
output_shape[i] = self._output_shape_cache[op.input[0]][i]
for i in range(axis, end_axis + 1):
dim *= self._output_shape_cache[op.input[0]][i]
output_shape[i] = 1
for i in range(end_axis + 1, len(output_shape)):
output_shape[i] = self._output_shape_cache[op.input[0]][i]
output_shape[axis] = dim
self.add_output_shape(op, [output_shape])
......@@ -95,6 +95,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
TransformerRule.CHECK_QUANTIZE_INFO:
self.check_quantize_info,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN:
self.transform_caffe_reshape_and_flatten,
}
self._option = option
......@@ -979,8 +981,9 @@ class Transformer(base_converter.ConverterInterface):
elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if ConverterUtil.data_format(op) == DataFormat.NCHW \
and self._target_data_format == DataFormat.NHWC: # noqa
if (ConverterUtil.data_format(op) == DataFormat.NCHW
and self._target_data_format == DataFormat.NHWC
and len(op.output_shape[0].dims) == 4):
print("Transpose concat/split args: %s(%s)"
% (op.name, op.type))
if arg.i == 1:
......@@ -1231,6 +1234,7 @@ class Transformer(base_converter.ConverterInterface):
# transform input(4D) -> reshape(2D) -> matmul to fc
# work for TensorFlow
if op.type == MaceOp.Reshape.name and \
len(op.input) == 2 and \
op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \
filter_format == FilterFormat.HWIO and \
......@@ -1714,3 +1718,35 @@ class Transformer(base_converter.ConverterInterface):
output_info.name = check_node.name
output_info.dims.extend(check_node.output_shape[0].dims)
output_info.data_type = mace_pb2.DT_FLOAT
def transform_caffe_reshape_and_flatten(self):
net = self._model
for op in net.op:
if op.type == MaceOp.Reshape.name and \
len(op.input) == 1:
print("Transform Caffe Reshape")
if op.arg[3].name == 'dim':
shape_tensor = net.tensors.add()
shape_tensor.name = op.name + '_shape'
shape_tensor.dims.append(len(op.output_shape[0].dims))
shape_tensor.data_type = mace_pb2.DT_INT32
shape_tensor.int32_data.extend(op.arg[3].ints)
op.input.append(shape_tensor.name)
else:
axis = op.arg[3].i
dims = [1] * len(op.output_shape[0].dims)
end_axis = op.arg[4].i
end_axis = end_axis if end_axis >= 0 else end_axis + len(dims) # noqa
for i in range(0, axis):
dims[i] = 0
for i in range(axis + 1, end_axis + 1):
dims[i] = 1
for i in range(end_axis + 1, len(dims)):
dims[i] = 0
dims[axis] = -1
shape_tensor = net.tensors.add()
shape_tensor.name = op.name + '_shape'
shape_tensor.dims.append(len(dims))
shape_tensor.data_type = mace_pb2.DT_INT32
shape_tensor.int32_data.extend(dims)
op.input.append(shape_tensor.name)
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "mace/utils/logging.h"
namespace mace {
struct BBox {
float xmin;
float ymin;
float xmax;
float ymax;
int label;
float confidence;
};
namespace {
inline float overlap(const BBox &a, const BBox &b) {
if (a.xmin > b.xmax || a.xmax < b.xmin ||
a.ymin > b.ymax || a.ymax < b.ymin) {
return 0.f;
}
float overlap_w = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin);
float overlap_h = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin);
return overlap_w * overlap_h;
}
void NmsSortedBboxes(const std::vector<BBox> &bboxes,
const float nms_threshold,
const int top_k,
std::vector<BBox> *sorted_boxes) {
const int n = std::min(top_k, static_cast<int>(bboxes.size()));
std::vector<int> picked;
std::vector<float> areas(n);
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < n; ++i) {
const BBox &r = bboxes[i];
float width = std::max(0.f, r.xmax - r.xmin);
float height = std::max(0.f, r.ymax - r.ymin);
areas[i] = width * height;
}
for (int i = 0; i < n; ++i) {
const BBox &a = bboxes[i];
int keep = 1;
for (size_t j = 0; j < picked.size(); ++j) {
const BBox &b = bboxes[picked[j]];
float inter_area = overlap(a, b);
float union_area = areas[i] + areas[picked[j]] - inter_area;
MACE_CHECK(union_area > 0, "union_area should be greater than 0");
if (inter_area / union_area > nms_threshold) {
keep = 0;
break;
}
}
if (keep) {
picked.push_back(i);
sorted_boxes->push_back(bboxes[i]);
}
}
}
inline bool cmp(const BBox &a, const BBox &b) {
return a.confidence > b.confidence;
}
} // namespace
int DetectionOutput(const float *loc_ptr,
const float *conf_ptr,
const float *pbox_ptr,
const int num_prior,
const int num_classes,
const float nms_threshold,
const int top_k,
const int keep_top_k,
const float confidence_threshold,
std::vector<BBox> *bbox_rects) {
MACE_CHECK(keep_top_k > 0, "keep_top_k should be greater than 0");
std::vector<float> bboxes(4 * num_prior);
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < num_prior; ++i) {
int index = i * 4;
const float *lc = loc_ptr + index;
const float *pb = pbox_ptr + index;
const float *var = pb + num_prior * 4;
float pb_w = pb[2] - pb[0];
float pb_h = pb[3] - pb[1];
float pb_cx = (pb[0] + pb[2]) * 0.5f;
float pb_cy = (pb[1] + pb[3]) * 0.5f;
float bbox_cx = var[0] * lc[0] * pb_w + pb_cx;
float bbox_cy = var[1] * lc[1] * pb_h + pb_cy;
float bbox_w = std::exp(var[2] * lc[2]) * pb_w;
float bbox_h = std::exp(var[3] * lc[3]) * pb_h;
bboxes[0 + index] = bbox_cx - bbox_w * 0.5f;
bboxes[1 + index] = bbox_cy - bbox_h * 0.5f;
bboxes[2 + index] = bbox_cx + bbox_w * 0.5f;
bboxes[3 + index] = bbox_cy + bbox_h * 0.5f;
}
// start from 1 to ignore background class
for (int i = 1; i < num_classes; ++i) {
// filter by confidence threshold
std::vector<BBox> class_bbox_rects;
for (int j = 0; j < num_prior; ++j) {
float confidence = conf_ptr[j * num_classes + i];
if (confidence > confidence_threshold) {
BBox c = {bboxes[0 + j * 4], bboxes[1 + j * 4], bboxes[2 + j * 4],
bboxes[3 + j * 4], i, confidence};
class_bbox_rects.push_back(c);
}
}
std::sort(class_bbox_rects.begin(), class_bbox_rects.end(), cmp);
// apply nms
std::vector<BBox> sorted_boxes;
NmsSortedBboxes(class_bbox_rects,
nms_threshold,
std::min(top_k,
static_cast<int>(class_bbox_rects.size())),
&sorted_boxes);
// gather
bbox_rects->insert(bbox_rects->end(), sorted_boxes.begin(),
sorted_boxes.end());
}
std::sort(bbox_rects->begin(), bbox_rects->end(), cmp);
// output
int num_detected = keep_top_k < static_cast<int>(bbox_rects->size()) ?
keep_top_k : static_cast<int>(bbox_rects->size());
bbox_rects->resize(num_detected);
return num_detected;
}
} // namespace mace
......@@ -40,6 +40,120 @@ message Datum {
optional bool encoded = 7 [default = false];
}
// The label (display) name and label id.
message LabelMapItem {
// Both name and label are required.
optional string name = 1;
optional int32 label = 2;
// display_name is optional.
optional string display_name = 3;
}
message LabelMap {
repeated LabelMapItem item = 1;
}
// Sample a bbox in the normalized space [0, 1] with provided constraints.
message Sampler {
// Minimum scale of the sampled bbox.
optional float min_scale = 1 [default = 1.];
// Maximum scale of the sampled bbox.
optional float max_scale = 2 [default = 1.];
// Minimum aspect ratio of the sampled bbox.
optional float min_aspect_ratio = 3 [default = 1.];
// Maximum aspect ratio of the sampled bbox.
optional float max_aspect_ratio = 4 [default = 1.];
}
// Constraints for selecting sampled bbox.
message SampleConstraint {
// Minimum Jaccard overlap between sampled bbox and all bboxes in
// AnnotationGroup.
optional float min_jaccard_overlap = 1;
// Maximum Jaccard overlap between sampled bbox and all bboxes in
// AnnotationGroup.
optional float max_jaccard_overlap = 2;
// Minimum coverage of sampled bbox by all bboxes in AnnotationGroup.
optional float min_sample_coverage = 3;
// Maximum coverage of sampled bbox by all bboxes in AnnotationGroup.
optional float max_sample_coverage = 4;
// Minimum coverage of all bboxes in AnnotationGroup by sampled bbox.
optional float min_object_coverage = 5;
// Maximum coverage of all bboxes in AnnotationGroup by sampled bbox.
optional float max_object_coverage = 6;
}
// Sample a batch of bboxes with provided constraints.
message BatchSampler {
// Use original image as the source for sampling.
optional bool use_original_image = 1 [default = true];
// Constraints for sampling bbox.
optional Sampler sampler = 2;
// Constraints for determining if a sampled bbox is positive or negative.
optional SampleConstraint sample_constraint = 3;
// If provided, break when found certain number of samples satisfing the
// sample_constraint.
optional uint32 max_sample = 4;
// Maximum number of trials for sampling to avoid infinite loop.
optional uint32 max_trials = 5 [default = 100];
}
// Condition for emitting annotations.
message EmitConstraint {
enum EmitType {
CENTER = 0;
MIN_OVERLAP = 1;
}
optional EmitType emit_type = 1 [default = CENTER];
// If emit_type is MIN_OVERLAP, provide the emit_overlap.
optional float emit_overlap = 2;
}
// The normalized bounding box [0, 1] w.r.t. the input image size.
message NormalizedBBox {
optional float xmin = 1;
optional float ymin = 2;
optional float xmax = 3;
optional float ymax = 4;
optional int32 label = 5;
optional bool difficult = 6;
optional float score = 7;
optional float size = 8;
}
// Annotation for each object instance.
message Annotation {
optional int32 instance_id = 1 [default = 0];
optional NormalizedBBox bbox = 2;
}
// Group of annotations for a particular label.
message AnnotationGroup {
optional int32 group_label = 1;
repeated Annotation annotation = 2;
}
// An extension of Datum which contains "rich" annotations.
message AnnotatedDatum {
enum AnnotationType {
BBOX = 0;
}
optional Datum datum = 1;
// If there are "rich" annotations, specify the type of annotation.
// Currently it only supports bounding box.
// If there are no "rich" annotations, use label in datum instead.
optional AnnotationType type = 2;
// Each group contains annotation for a particular class.
repeated AnnotationGroup annotation_group = 3;
}
message FillerParameter {
// The filler type.
optional string type = 1 [default = 'constant'];
......@@ -98,7 +212,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 43 (last added: weights)
// SolverParameter next available ID: 44 (last added: plateau_winsize)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
......@@ -128,12 +242,24 @@ message SolverParameter {
// The states for the train/test nets. Must be unspecified or
// specified once per net.
//
// By default, train_state will have phase = TRAIN,
// By default, all states will have solver = true;
// train_state will have phase = TRAIN,
// and all test_state's will have phase = TEST.
// Other defaults are set according to the NetState defaults.
optional NetState train_state = 26;
repeated NetState test_state = 27;
// Evaluation type.
optional string eval_type = 41 [default = "classification"];
// ap_version: different ways of computing Average Precision.
// Check https://sanchom.wordpress.com/tag/average-precision/ for details.
// 11point: the 11-point interpolated average precision. Used in VOC2007.
// MaxIntegral: maximally interpolated AP. Used in VOC2012/ILSVRC.
// Integral: the natural integral of the precision-recall curve.
optional string ap_version = 42 [default = "Integral"];
// If true, display per class result.
optional bool show_per_class_result = 44 [default = false];
// The number of iterations for each test net.
repeated int32 test_iter = 3;
......@@ -165,6 +291,8 @@ message SolverParameter {
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
// - plateau: decreases lr
// if the minimum loss isn't updated for 'plateau_winsize' iters
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
......@@ -180,17 +308,15 @@ message SolverParameter {
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep"
repeated int32 stepvalue = 34;
// the stepsize for learning rate policy "plateau"
repeated int32 plateau_winsize = 43;
// Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
// whenever their actual L2 norm is larger.
optional float clip_gradients = 35 [default = -1];
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
// The prefix for the snapshot.
// If not set then is replaced by prototxt file path without extention.
// If is set to directory then is augmented by prototxt file name
// without extention.
optional string snapshot_prefix = 15;
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false];
......@@ -242,19 +368,6 @@ message SolverParameter {
}
// DEPRECATED: use type instead of solver_type
optional SolverType solver_type = 30 [default = SGD];
// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];
// Path to caffemodel file(s) with pretrained weights to initialize finetuning.
// Tha same as command line --weights parameter for caffe train command.
// If command line --weights parameter if specified, it has higher priority
// and owerwrites this one(s).
// If --snapshot command line parameter is specified, this one(s) are ignored.
// If several model files are expected, they can be listed in a one
// weights parameter separated by ',' (like in a command string) or
// in repeated weights parameters separately.
repeated string weights = 42;
}
// A message that stores the solver snapshots
......@@ -263,6 +376,8 @@ message SolverState {
optional string learned_net = 2; // The file that stores the learned net.
repeated BlobProto history = 3; // The history for sgd solvers
optional int32 current_step = 4 [default = 0]; // The current step for learning rate
optional float minimum_loss = 5 [default = 1E38]; // Historical minimum loss
optional int32 iter_last_event = 6 [default = 0]; // The iteration when last lr-update or min_loss-update happend
}
enum Phase {
......@@ -375,6 +490,7 @@ message LayerParameter {
// engine parameter for selecting the implementation.
// The default for the engine is set by the ENGINE switch at compile-time.
optional AccuracyParameter accuracy_param = 102;
optional AnnotatedDataParameter annotated_data_param = 200;
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BiasParameter bias_param = 141;
......@@ -383,6 +499,8 @@ message LayerParameter {
optional ConvolutionParameter convolution_param = 106;
optional CropParameter crop_param = 144;
optional DataParameter data_param = 107;
optional DetectionEvaluateParameter detection_evaluate_param = 205;
optional DetectionOutputParameter detection_output_param = 204;
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
......@@ -400,13 +518,15 @@ message LayerParameter {
optional LogParameter log_param = 134;
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
optional MultiBoxLossParameter multibox_loss_param = 201;
optional MVNParameter mvn_param = 120;
optional NormalizeParameter norm_param = 206;
optional ParameterParameter parameter_param = 145;
optional PermuteParameter permute_param = 202;
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
optional ProposalParameter proposal_param = 8266713;
optional PReLUParameter prelu_param = 131;
optional PSROIAlignParameter psroi_align_param = 1490;
optional PriorBoxParameter prior_box_param = 203;
optional PythonParameter python_param = 130;
optional RecurrentParameter recurrent_param = 146;
optional ReductionParameter reduction_param = 136;
......@@ -420,6 +540,7 @@ message LayerParameter {
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
optional VideoDataParameter video_data_param = 207;
optional WindowDataParameter window_data_param = 129;
optional ShuffleChannelParameter shuffle_channel_param = 164;
}
......@@ -435,9 +556,12 @@ message TransformationParameter {
optional bool mirror = 2 [default = false];
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 3 [default = 0];
optional uint32 crop_h = 11 [default = 0];
optional uint32 crop_w = 12 [default = 0];
// mean_file and mean_value cannot be specified at the same time
optional string mean_file = 4;
// if specified can be repeated once (would subtract it from all the channels)
// if specified can be repeated once (would substract it from all the channels)
// or can be repeated the same number of times as channels
// (would subtract them from the corresponding channel)
repeated float mean_value = 5;
......@@ -445,6 +569,141 @@ message TransformationParameter {
optional bool force_color = 6 [default = false];
// Force the decoded image to have 1 color channels.
optional bool force_gray = 7 [default = false];
// Resize policy
optional ResizeParameter resize_param = 8;
// Noise policy
optional NoiseParameter noise_param = 9;
// Distortion policy
optional DistortionParameter distort_param = 13;
// Expand policy
optional ExpansionParameter expand_param = 14;
// Constraint for emitting the annotation after transformation.
optional EmitConstraint emit_constraint = 10;
}
// Message that stores parameters used by data transformer for resize policy
message ResizeParameter {
//Probability of using this resize policy
optional float prob = 1 [default = 1];
enum Resize_mode {
WARP = 1;
FIT_SMALL_SIZE = 2;
FIT_LARGE_SIZE_AND_PAD = 3;
}
optional Resize_mode resize_mode = 2 [default = WARP];
optional uint32 height = 3 [default = 0];
optional uint32 width = 4 [default = 0];
// A parameter used to update bbox in FIT_SMALL_SIZE mode.
optional uint32 height_scale = 8 [default = 0];
optional uint32 width_scale = 9 [default = 0];
enum Pad_mode {
CONSTANT = 1;
MIRRORED = 2;
REPEAT_NEAREST = 3;
}
// Padding mode for BE_SMALL_SIZE_AND_PAD mode and object centering
optional Pad_mode pad_mode = 5 [default = CONSTANT];
// if specified can be repeated once (would fill all the channels)
// or can be repeated the same number of times as channels
// (would use it them to the corresponding channel)
repeated float pad_value = 6;
enum Interp_mode { //Same as in OpenCV
LINEAR = 1;
AREA = 2;
NEAREST = 3;
CUBIC = 4;
LANCZOS4 = 5;
}
//interpolation for for resizing
repeated Interp_mode interp_mode = 7;
}
message SaltPepperParameter {
//Percentage of pixels
optional float fraction = 1 [default = 0];
repeated float value = 2;
}
// Message that stores parameters used by data transformer for transformation
// policy
message NoiseParameter {
//Probability of using this resize policy
optional float prob = 1 [default = 0];
// Histogram equalized
optional bool hist_eq = 2 [default = false];
// Color inversion
optional bool inverse = 3 [default = false];
// Grayscale
optional bool decolorize = 4 [default = false];
// Gaussian blur
optional bool gauss_blur = 5 [default = false];
// JPEG compression quality (-1 = no compression)
optional float jpeg = 6 [default = -1];
// Posterization
optional bool posterize = 7 [default = false];
// Erosion
optional bool erode = 8 [default = false];
// Salt-and-pepper noise
optional bool saltpepper = 9 [default = false];
optional SaltPepperParameter saltpepper_param = 10;
// Local histogram equalization
optional bool clahe = 11 [default = false];
// Color space conversion
optional bool convert_to_hsv = 12 [default = false];
// Color space conversion
optional bool convert_to_lab = 13 [default = false];
}
// Message that stores parameters used by data transformer for distortion policy
message DistortionParameter {
// The probability of adjusting brightness.
optional float brightness_prob = 1 [default = 0.0];
// Amount to add to the pixel values within [-delta, delta].
// The possible value is within [0, 255]. Recommend 32.
optional float brightness_delta = 2 [default = 0.0];
// The probability of adjusting contrast.
optional float contrast_prob = 3 [default = 0.0];
// Lower bound for random contrast factor. Recommend 0.5.
optional float contrast_lower = 4 [default = 0.0];
// Upper bound for random contrast factor. Recommend 1.5.
optional float contrast_upper = 5 [default = 0.0];
// The probability of adjusting hue.
optional float hue_prob = 6 [default = 0.0];
// Amount to add to the hue channel within [-delta, delta].
// The possible value is within [0, 180]. Recommend 36.
optional float hue_delta = 7 [default = 0.0];
// The probability of adjusting saturation.
optional float saturation_prob = 8 [default = 0.0];
// Lower bound for the random saturation factor. Recommend 0.5.
optional float saturation_lower = 9 [default = 0.0];
// Upper bound for the random saturation factor. Recommend 1.5.
optional float saturation_upper = 10 [default = 0.0];
// The probability of randomly order the image channels.
optional float random_order_prob = 11 [default = 0.0];
}
// Message that stores parameters used by data transformer for expansion policy
message ExpansionParameter {
//Probability of using this expansion policy
optional float prob = 1 [default = 1];
// The ratio to expand the image.
optional float max_expand_ratio = 2 [default = 1.];
}
// Message that stores parameters shared by loss layers
......@@ -496,6 +755,16 @@ message AccuracyParameter {
optional int32 ignore_label = 3;
}
message AnnotatedDataParameter {
// Define the sampler.
repeated BatchSampler batch_sampler = 1;
// Store label name and label id in LabelMap format.
optional string label_map_file = 2;
// If provided, it will replace the AnnotationType stored in each
// AnnotatedDatum.
optional AnnotatedDatum.AnnotationType anno_type = 3;
}
message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [default = false];
......@@ -519,21 +788,11 @@ message ConcatParameter {
}
message BatchNormParameter {
// If false, normalization is performed over the current mini-batch
// and global statistics are accumulated (but not yet used) by a moving
// average.
// If true, those accumulated mean and variance values are used for the
// normalization.
// By default, it is set to false when the network is in the training
// phase and true when the network is in the testing phase.
// If false, accumulate global mean/variance values via a moving average. If
// true, use those accumulated values instead of computing mean/variance
// across the batch.
optional bool use_global_stats = 1;
// What fraction of the moving average remains each iteration?
// Smaller values make the moving average decay faster, giving more
// weight to the recent values.
// Each iteration updates the moving average @f$S_{t-1}@f$ with the
// current mean @f$ Y_t @f$ by
// @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$
// is the moving_average_fraction parameter.
// How much does the moving average decay each iteration?
optional float moving_average_fraction = 2 [default = .999];
// Small value to add to the variance estimate so that we don't divide by
// zero.
......@@ -684,11 +943,100 @@ message DataParameter {
optional bool mirror = 6 [default = false];
// Force the encoded image to have 3 color channels
optional bool force_encoded_color = 9 [default = false];
// Prefetch queue (Increase if data feeding bandwidth varies, within the
// limit of device memory for GPU training)
// Prefetch queue (Number of batches to prefetch to host memory, increase if
// data access bandwidth varies).
optional uint32 prefetch = 10 [default = 4];
}
// Message that store parameters used by DetectionEvaluateLayer
message DetectionEvaluateParameter {
// Number of classes that are actually predicted. Required!
optional uint32 num_classes = 1;
// Label id for background class. Needed for sanity check so that
// background class is neither in the ground truth nor the detections.
optional uint32 background_label_id = 2 [default = 0];
// Threshold for deciding true/false positive.
optional float overlap_threshold = 3 [default = 0.5];
// If true, also consider difficult ground truth for evaluation.
optional bool evaluate_difficult_gt = 4 [default = true];
// A file which contains a list of names and sizes with same order
// of the input DB. The file is in the following format:
// name height width
// ...
// If provided, we will scale the prediction and ground truth NormalizedBBox
// for evaluation.
optional string name_size_file = 5;
// The resize parameter used in converting NormalizedBBox to original image.
optional ResizeParameter resize_param = 6;
}
message NonMaximumSuppressionParameter {
// Threshold to be used in nms.
optional float nms_threshold = 1 [default = 0.3];
// Maximum number of results to be kept.
optional int32 top_k = 2;
// Parameter for adaptive nms.
optional float eta = 3 [default = 1.0];
}
message SaveOutputParameter {
// Output directory. If not empty, we will save the results.
optional string output_directory = 1;
// Output name prefix.
optional string output_name_prefix = 2;
// Output format.
// VOC - PASCAL VOC output format.
// COCO - MS COCO output format.
optional string output_format = 3;
// If you want to output results, must also provide the following two files.
// Otherwise, we will ignore saving results.
// label map file.
optional string label_map_file = 4;
// A file which contains a list of names and sizes with same order
// of the input DB. The file is in the following format:
// name height width
// ...
optional string name_size_file = 5;
// Number of test images. It can be less than the lines specified in
// name_size_file. For example, when we only want to evaluate on part
// of the test images.
optional uint32 num_test_image = 6;
// The resize parameter used in saving the data.
optional ResizeParameter resize_param = 7;
}
// Message that store parameters used by DetectionOutputLayer
message DetectionOutputParameter {
// Number of classes to be predicted. Required!
optional uint32 num_classes = 1;
// If true, bounding box are shared among different classes.
optional bool share_location = 2 [default = true];
// Background label id. If there is no background class,
// set it as -1.
optional int32 background_label_id = 3 [default = 0];
// Parameters used for non maximum suppression.
optional NonMaximumSuppressionParameter nms_param = 4;
// Parameters used for saving detection results.
optional SaveOutputParameter save_output_param = 5;
// Type of coding method for bbox.
optional PriorBoxParameter.CodeType code_type = 6 [default = CORNER];
// If true, variance is encoded in target; otherwise we need to adjust the
// predicted offset accordingly.
optional bool variance_encoded_in_target = 8 [default = false];
// Number of total bboxes to be kept per image after nms step.
// -1 means keeping all bboxes after nms step.
optional int32 keep_top_k = 7 [default = -1];
// Only consider detections whose confidences are larger than a threshold.
// If not provided, consider all boxes.
optional float confidence_threshold = 9;
// If true, visualize the detection results.
optional bool visualize = 10 [default = false];
// The threshold used to visualize the detection results.
optional float visualize_threshold = 11;
// If provided, save outputs to video file.
optional string save_file = 12;
}
message DropoutParameter {
optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
}
......@@ -832,7 +1180,6 @@ message ImageDataParameter {
message InfogainLossParameter {
// Specify the infogain matrix source.
optional string source = 1;
optional int32 axis = 2 [default = 1]; // axis of prob
}
message InnerProductParameter {
......@@ -896,6 +1243,78 @@ message MemoryDataParameter {
optional uint32 width = 4;
}
// Message that store parameters used by MultiBoxLossLayer
message MultiBoxLossParameter {
// Localization loss type.
enum LocLossType {
L2 = 0;
SMOOTH_L1 = 1;
}
optional LocLossType loc_loss_type = 1 [default = SMOOTH_L1];
// Confidence loss type.
enum ConfLossType {
SOFTMAX = 0;
LOGISTIC = 1;
}
optional ConfLossType conf_loss_type = 2 [default = SOFTMAX];
// Weight for localization loss.
optional float loc_weight = 3 [default = 1.0];
// Number of classes to be predicted. Required!
optional uint32 num_classes = 4;
// If true, bounding box are shared among different classes.
optional bool share_location = 5 [default = true];
// Matching method during training.
enum MatchType {
BIPARTITE = 0;
PER_PREDICTION = 1;
}
optional MatchType match_type = 6 [default = PER_PREDICTION];
// If match_type is PER_PREDICTION, use overlap_threshold to
// determine the extra matching bboxes.
optional float overlap_threshold = 7 [default = 0.5];
// Use prior for matching.
optional bool use_prior_for_matching = 8 [default = true];
// Background label id.
optional uint32 background_label_id = 9 [default = 0];
// If true, also consider difficult ground truth.
optional bool use_difficult_gt = 10 [default = true];
// If true, perform negative mining.
// DEPRECATED: use mining_type instead.
optional bool do_neg_mining = 11;
// The negative/positive ratio.
optional float neg_pos_ratio = 12 [default = 3.0];
// The negative overlap upperbound for the unmatched predictions.
optional float neg_overlap = 13 [default = 0.5];
// Type of coding method for bbox.
optional PriorBoxParameter.CodeType code_type = 14 [default = CORNER];
// If true, encode the variance of prior box in the loc loss target instead of
// in bbox.
optional bool encode_variance_in_target = 16 [default = false];
// If true, map all object classes to agnostic class. It is useful for learning
// objectness detector.
optional bool map_object_to_agnostic = 17 [default = false];
// If true, ignore cross boundary bbox during matching.
// Cross boundary bbox is a bbox who is outside of the image region.
optional bool ignore_cross_boundary_bbox = 18 [default = false];
// If true, only backpropagate on corners which are inside of the image
// region when encode_type is CORNER or CORNER_SIZE.
optional bool bp_inside = 19 [default = false];
// Mining type during training.
// NONE : use all negatives.
// MAX_NEGATIVE : select negatives based on the score.
// HARD_EXAMPLE : select hard examples based on "Training Region-based Object Detectors with Online Hard Example Mining", Shrivastava et.al.
enum MiningType {
NONE = 0;
MAX_NEGATIVE = 1;
HARD_EXAMPLE = 2;
}
optional MiningType mining_type = 20 [default = MAX_NEGATIVE];
// Parameters used for non maximum suppression durig hard example mining.
optional NonMaximumSuppressionParameter nms_param = 21;
optional int32 sample_size = 22 [default = 64];
optional bool use_prior_for_nms = 23 [default = false];
}
message MVNParameter {
// This parameter can be set to false to normalize mean only
optional bool normalize_variance = 1 [default = true];
......@@ -907,10 +1326,28 @@ message MVNParameter {
optional float eps = 3 [default = 1e-9];
}
// Message that stores parameters used by NormalizeLayer
message NormalizeParameter {
optional bool across_spatial = 1 [default = true];
// Initial value of scale. Default is 1.0 for all
optional FillerParameter scale_filler = 2;
// Whether or not scale parameters are shared across channels.
optional bool channel_shared = 3 [default = true];
// Epsilon for not dividing by zero while normalizing variance
optional float eps = 4 [default = 1e-10];
}
message ParameterParameter {
optional BlobShape shape = 1;
}
message PermuteParameter {
// The new orders of the axes of data. Notice it should be with
// in the same range as the input data, and it starts from 0.
// Do not provide repeated order.
repeated uint32 order = 1;
}
message PoolingParameter {
enum PoolMethod {
MAX = 0;
......@@ -947,17 +1384,46 @@ message PowerParameter {
optional float shift = 3 [default = 0.0];
}
// Message that stores parameters used by ProposalLayer
message ProposalParameter {
optional uint32 feat_stride = 1 [default = 16];
repeated uint32 scales = 2;
repeated float ratios = 3;
}
message PSROIAlignParameter {
required float spatial_scale = 1;
required int32 output_dim = 2; // output channel number
required int32 group_size = 3; // number of groups to encode position-sensitive score maps
// Message that store parameters used by PriorBoxLayer
message PriorBoxParameter {
// Encode/decode type.
enum CodeType {
CORNER = 1;
CENTER_SIZE = 2;
CORNER_SIZE = 3;
}
// Minimum box size (in pixels). Required!
repeated float min_size = 1;
// Maximum box size (in pixels). Required!
repeated float max_size = 2;
// Various of aspect ratios. Duplicate ratios will be ignored.
// If none is provided, we use default ratio 1.
repeated float aspect_ratio = 3;
// If true, will flip each aspect ratio.
// For example, if there is aspect ratio "r",
// we will generate aspect ratio "1.0/r" as well.
optional bool flip = 4 [default = true];
// If true, will clip the prior so that it is within [0, 1]
optional bool clip = 5 [default = false];
// Variance for adjusting the prior bboxes.
repeated float variance = 6;
// By default, we calculate img_height, img_width, step_x, step_y based on
// bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely
// provided.
// Explicitly provide the img_size.
optional uint32 img_size = 7;
// Either img_size or img_h/img_w should be specified; not both.
optional uint32 img_h = 8;
optional uint32 img_w = 9;
// Explicitly provide the step size.
optional float step = 10;
// Either step or step_h/step_w should be specified; not both.
optional float step_h = 11;
optional float step_w = 12;
// Offset to the top left corner of each cell.
optional float offset = 13 [default = 0.5];
}
message PythonParameter {
......@@ -968,7 +1434,9 @@ message PythonParameter {
// string, dictionary in Python dict format, JSON, etc. You may parse this
// string in `setup` method and use it in `forward` and `backward`.
optional string param_str = 3 [default = ''];
// DEPRECATED
// Whether this PythonLayer is shared among worker solvers during data parallelism.
// If true, each worker solver sequentially run forward from this layer.
// This value should be set true if you are using it as a data layer.
optional bool share_in_parallel = 4 [default = false];
}
......@@ -1195,6 +1663,18 @@ message ThresholdParameter {
optional float threshold = 1 [default = 0]; // Strictly positive values
}
message VideoDataParameter{
enum VideoType {
WEBCAM = 0;
VIDEO = 1;
}
optional VideoType video_type = 1 [default = WEBCAM];
optional int32 device_id = 2 [default = 0];
optional string video_file = 3;
// Number of frames to be skipped before processing a frame.
optional uint32 skip_frames = 4 [default = 0];
}
message WindowDataParameter {
// Specify the data source.
optional string source = 1;
......@@ -1437,7 +1917,7 @@ message PReLUParameter {
// Initial value of a_i. Default is a_i=0.25 for all i.
optional FillerParameter filler = 1;
// Whether or not slope parameters are shared across channels.
// Whether or not slope paramters are shared across channels.
optional bool channel_shared = 2 [default = false];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册