Commit c180e0b4 authored by lichao18

Add SSD box predictor module

Parent 05f3dc93
......@@ -45,6 +45,7 @@ extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry);
extern void RegisterReduce(OpRegistryBase *op_registry);
extern void RegisterPriorBox(OpRegistryBase *op_registry);
extern void RegisterReshape(OpRegistryBase *op_registry);
extern void RegisterResizeBicubic(OpRegistryBase *op_registry);
extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
......@@ -103,6 +104,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterPad(this);
ops::RegisterPooling(this);
ops::RegisterReduce(this);
ops::RegisterPriorBox(this);
ops::RegisterReshape(this);
ops::RegisterResizeBicubic(this);
ops::RegisterResizeBilinear(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template<DeviceType D, typename T>
class PriorBoxOp : public Operation {
public:
explicit PriorBoxOp(OpConstructContext *context)
: Operation(context),
min_size_(Operation::GetRepeatedArgs<float>("min_size")),
max_size_(Operation::GetRepeatedArgs<float>("max_size")),
aspect_ratio_(Operation::GetRepeatedArgs<float>("aspect_ratio")),
flip_(Operation::GetOptionalArg<bool>("flip", true)),
clip_(Operation::GetOptionalArg<bool>("clip", false)),
variance_(Operation::GetRepeatedArgs<float>("variance")),
offset_(Operation::GetOptionalArg<float>("offset", 0.5)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const Tensor *data = this->Input(DATA);
Tensor *output = this->Output(OUTPUT);
const std::vector<index_t> &input_shape = input->shape();
const std::vector<index_t> &data_shape = data->shape();
const index_t input_w = input_shape[3];
const index_t input_h = input_shape[2];
const index_t image_w = data_shape[3];
const index_t image_h = data_shape[2];
float step_h = static_cast<float>(image_h) / static_cast<float>(input_h);
float step_w = static_cast<float>(image_w) / static_cast<float>(input_w);
if (Operation::GetOptionalArg<float>("step_h", 0) != 0 &&
Operation::GetOptionalArg<float>("step_w", 0) != 0) {
step_h = Operation::GetOptionalArg<float>("step_h", 0);
step_w = Operation::GetOptionalArg<float>("step_w", 0);
}
const index_t num_min_size = min_size_.size();
MACE_CHECK(num_min_size > 0, "min_size is required!");
const index_t num_max_size = max_size_.size();
const index_t num_aspect_ratio = aspect_ratio_.size();
MACE_CHECK(num_aspect_ratio > 0, "aspect_ratio is required!");
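// Caffe's prior count: one square box per min_size, one extra box per
// (min_size, max_size) pair, one box per (min_size, aspect_ratio) pair,
// and the reciprocal ratios when flip is set.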
index_t num_prior = num_min_size * num_aspect_ratio +
num_min_size + num_max_size;
if (flip_)
num_prior += num_min_size * num_aspect_ratio;
index_t dim = 4 * input_w * input_h * num_prior;
std::vector<index_t> output_shape = {1, 2, dim};
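// Output layout mirrors Caffe: {1, 2, dim}, where channel 0 holds the
// (xmin, ymin, xmax, ymax) corners of every prior, normalized by the image
// size, and channel 1 holds the four per-coordinate variances.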
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < input_h; ++i) {
for (index_t j = 0; j < input_w; ++j) {
// Declared per-iteration so the collapsed OpenMP loop stays race-free.
float box_w, box_h;
// Each (i, j) feature-map cell owns a contiguous block of num_prior boxes.
index_t idx = (i * input_w + j) * num_prior * 4;
float center_y = (offset_ + i) * step_h;
float center_x = (offset_ + j) * step_w;
for (index_t k = 0; k < num_min_size; ++k) {
float min_s = min_size_[k];
// First prior: a square with side min_size.
box_w = box_h = min_s * 0.5f;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
if (num_max_size > 0) {
float max_s = max_size_[k];
// Square prior whose side is the geometric mean of min_size and max_size.
box_w = box_h = std::sqrt(max_s * min_s) * 0.5f;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
}
for (index_t l = 0; l < num_aspect_ratio; ++l) {
float ar = aspect_ratio_[l];
box_w = min_s * std::sqrt(ar) * 0.5f;
box_h = min_s / std::sqrt(ar) * 0.5f;
output_data[idx + 0] = (center_x - box_w) / image_w;
output_data[idx + 1] = (center_y - box_h) / image_h;
output_data[idx + 2] = (center_x + box_w) / image_w;
output_data[idx + 3] = (center_y + box_h) / image_h;
idx += 4;
if (flip_) {
output_data[idx + 0] = (center_x - box_h) / image_w;
output_data[idx + 1] = (center_y - box_w) / image_h;
output_data[idx + 2] = (center_x + box_h) / image_w;
output_data[idx + 3] = (center_y + box_w) / image_h;
idx += 4;
}
}
}
}
}
if (clip_) {
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < dim; ++i) {
T min = 0;
T max = 1;
output_data[i] = std::min(std::max(output_data[i], min), max);
}
}
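// Channel 1: every prior carries the same four per-coordinate variances
// (variance_ is expected to hold exactly four values).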
output_data += dim;
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < dim / 4; ++i) {
index_t index = i * 4;
output_data[0 + index] = variance_[0];
output_data[1 + index] = variance_[1];
output_data[2 + index] = variance_[2];
output_data[3 + index] = variance_[3];
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<float> min_size_;
std::vector<float> max_size_;
std::vector<float> aspect_ratio_;
bool flip_;
bool clip_;
std::vector<float> variance_;
const float offset_;
private:
MACE_OP_INPUT_TAGS(INPUT, DATA);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
void RegisterPriorBox(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "PriorBox", PriorBoxOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void PriorBox(
int iters, float min_size, float max_size, float aspect_ratio, int flip,
int clip, float variance0, float variance1, float offset, int h) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("INPUT", {1, h, 1, 1});
net.AddRandomInput<D, float>("DATA", {1, 3, 300, 300});
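// Only the shapes matter here: INPUT fixes the feature-map size and DATA
// the image size, so random contents are fine.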
OpDefBuilder("PriorBox", "PriorBoxBM")
.Input("INPUT")
.Input("DATA")
.Output("OUTPUT")
.AddFloatsArg("min_size", {min_size})
.AddFloatsArg("max_size", {max_size})
.AddFloatsArg("aspect_ratio", {aspect_ratio})
.AddIntArg("flip", flip)
.AddIntArg("clip", clip)
.AddFloatsArg("variance", {variance0, variance0, variance1, variance1})
.AddFloatArg("offset", offset)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
const int64_t tot = static_cast<int64_t>(iters) * (300 * 300 * 3 + h);
mace::testing::MaccProcessed(tot);
mace::testing::BytesProcessed(tot * sizeof(T));
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
} // namespace
#define MACE_BM_PRIOR_BOX(MIN, MAX, AR, FLIP, CLIP, V0, V1, OFFSET, H) \
static void MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_##V0##\
_##V1##_##OFFSET##_##H(int iters) { \
PriorBox<DeviceType::CPU, float>(iters, MIN, MAX, AR, FLIP, CLIP, V0, V1, \
OFFSET, H); \
} \
MACE_BENCHMARK(MACE_BM_PRIOR_BOX_##MIN##_##MAX##_##AR##_##FLIP##_##CLIP##_## \
V0##_##V1##_##OFFSET##_##H)
MACE_BM_PRIOR_BOX(285, 300, 2, 1, 0, 1, 2, 1, 128);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class PriorBoxOpTest : public OpsTestBase {};
TEST_F(PriorBoxOpTest, Simple) {
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("INPUT", {1, 128, 1, 1});
net.AddRandomInput<DeviceType::CPU, float>("DATA", {1, 3, 300, 300});
OpDefBuilder("PriorBox", "PriorBoxTest")
.Input("INPUT")
.Input("DATA")
.Output("OUTPUT")
.AddFloatsArg("min_size", {285})
.AddFloatsArg("max_size", {300})
.AddFloatsArg("aspect_ratio", {2, 3})
.AddIntArg("flip", 1)
.AddIntArg("clip", 0)
.AddFloatsArg("variance", {0.1, 0.1, 0.2, 0.2})
.AddFloatArg("offset", 0.5)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
// Check
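// With a 1 x 1 feature map on a 300 x 300 image the box center is (150, 150);
// the min_size 285 prior spans (150 - 142.5) / 300 = 0.025 to
// (150 + 142.5) / 300 = 0.975, the first four values below.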
auto expected_tensor = net.CreateTensor<float>({1, 2, 24},
{0.025, 0.025, 0.975, 0.975,
0.012660282759551838, 0.012660282759551838,
0.9873397172404482, 0.9873397172404482,
-0.17175144212722018, 0.16412427893638995,
1.1717514421272204, 0.8358757210636101,
0.16412427893638995, -0.17175144212722018,
0.8358757210636101, 1.1717514421272204,
-0.3227241335952166, 0.22575862213492773,
1.3227241335952165, 0.7742413778650723,
0.22575862213492773, -0.3227241335952166,
0.7742413778650723, 1.3227241335952165,
0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2,
0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2});
ExpectTensorNear<float>(*expected_tensor, *net.GetTensor("OUTPUT"));
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -28,6 +28,11 @@ class ReshapeOp : public Operation {
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
int axis = Operation::GetOptionalArg<int>("reshape_axis", 0);
int num_axes = Operation::GetOptionalArg<int>("num_axes", -1);
MACE_CHECK(axis == 0 && num_axes == -1,
"Only support axis = 0 and num_axes = -1");
const Tensor *shape = this->Input(SHAPE);
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
Tensor::MappingGuard shape_guard(shape);
......@@ -43,8 +48,12 @@ class ReshapeOp : public Operation {
MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
unknown_idx = i;
out_shape.push_back(1);
} else if (shape_data[i] == 0) {
out_shape.push_back(input_shape[i]);
product *= input_shape[i];
} else {
MACE_CHECK(shape_data[i] > 0, "Shape must be positive: ",
shape_data[i]);
......
......@@ -92,9 +92,17 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
}
} // k
} // b
} else if (input->dim_size() == 2 || input->dim_size() == 3) {
// normal 2d softmax and 3d softmax (dim(0) is batch)
index_t class_size = 0;
index_t class_count = 0;
if (input->dim_size() == 2) {
class_size = input->dim(0);
class_count = input->dim(1);
} else {
class_size = input->dim(0) * input->dim(1);
class_count = input->dim(2);
}
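// Folding the leading batch dim into class_size lets the same loop handle
// 3-D inputs (e.g. an SSD confidence tensor of shape
// [batch, num_prior, num_classes]) without a separate code path.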
#pragma omp parallel for schedule(runtime)
for (index_t k = 0; k < class_size; ++k) {
const float *input_ptr = input_data + k * class_count;
......
......@@ -123,6 +123,7 @@ MaceSupportedOps = [
'MatMul',
'Pad',
'Pooling',
'PriorBox',
'Proposal',
'Quantize',
'Reduce',
......@@ -174,8 +175,11 @@ class MaceKeyword(object):
mace_space_batch_block_shape_str = 'block_shape'
mace_space_depth_block_size_str = 'block_size'
mace_constant_value_str = 'constant_value'
mace_dim_str = 'dim'
mace_dims_str = 'dims'
mace_axis_str = 'axis'
mace_end_axis_str = 'end_axis'
mace_num_axes_str = 'num_axes'
mace_num_split_str = 'num_split'
mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape'
......@@ -205,6 +209,14 @@ class MaceKeyword(object):
mace_reduce_type_str = 'reduce_type'
mace_argmin_str = 'argmin'
mace_round_mode_str = 'round_mode'
mace_min_size_str = 'min_size'
mace_max_size_str = 'max_size'
mace_aspect_ratio_str = 'aspect_ratio'
mace_flip_str = 'flip'
mace_clip_str = 'clip'
mace_variance_str = 'variance'
mace_step_h_str = 'step_h'
mace_step_w_str = 'step_w'
class TransformerRule(Enum):
......@@ -243,6 +255,7 @@ class TransformerRule(Enum):
FOLD_SQRDIFF_MEAN = 33
TRANSPOSE_MATMUL_WEIGHT = 34
FOLD_EMBEDDING_LOOKUP = 35
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
class ConverterInterface(object):
......@@ -433,6 +446,7 @@ class ConverterOption(object):
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
......
......@@ -188,6 +188,10 @@ class CaffeConverter(base_converter.ConverterInterface):
'Crop': self.convert_crop,
'Scale': self.convert_scale,
'ShuffleChannel': self.convert_channel_shuffle,
'Permute': self.convert_permute,
'Flatten': self.convert_flatten,
'PriorBox': self.convert_prior_box,
'Reshape': self.convert_reshape,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -565,8 +569,6 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg.i = param.axis
elif param.HasField('concat_dim'):
axis_arg.i = param.concat_dim
def convert_slice(self, caffe_op):
op = self.convert_general_op(caffe_op)
......@@ -668,3 +670,107 @@ class CaffeConverter(base_converter.ConverterInterface):
group_arg.i = 1
if param.HasField('group'):
group_arg.i = param.group
def convert_permute(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.permute_param
op.type = MaceOp.Transpose.name
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend(list(param.order))
def convert_flatten(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.flatten_param
op.type = MaceOp.Reshape.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
if param.HasField('axis'):
axis_arg.i = param.axis
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
end_axis_arg = op.arg.add()
end_axis_arg.name = MaceKeyword.mace_end_axis_str
end_axis_arg.i = -1
if param.HasField('end_axis'):
end_axis_arg.i = param.end_axis
def convert_prior_box(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.prior_box_param
op.type = MaceOp.PriorBox.name
min_size_arg = op.arg.add()
min_size_arg.name = MaceKeyword.mace_min_size_str
min_size_arg.floats.extend(list(param.min_size))
max_size_arg = op.arg.add()
max_size_arg.name = MaceKeyword.mace_max_size_str
max_size_arg.floats.extend(list(param.max_size))
aspect_ratio_arg = op.arg.add()
aspect_ratio_arg.name = MaceKeyword.mace_aspect_ratio_str
aspect_ratio_arg.floats.extend(list(param.aspect_ratio))
flip_arg = op.arg.add()
flip_arg.name = MaceKeyword.mace_flip_str
flip_arg.i = 1
if param.HasField('flip'):
flip_arg.i = int(param.flip)
clip_arg = op.arg.add()
clip_arg.name = MaceKeyword.mace_clip_str
clip_arg.i = 0
if param.HasField('clip'):
clip_arg.i = int(param.clip)
variance_arg = op.arg.add()
variance_arg.name = MaceKeyword.mace_variance_str
variance_arg.floats.extend(list(param.variance))
offset_arg = op.arg.add()
offset_arg.name = MaceKeyword.mace_offset_str
offset_arg.f = 0.5
if param.HasField('offset'):
offset_arg.f = param.offset
step_h_arg = op.arg.add()
step_h_arg.name = MaceKeyword.mace_step_h_str
step_h_arg.f = 0
if param.HasField('step_h'):
mace_check(not param.HasField('step'),
"Either step or step_h/step_w should be specified; not both.") # noqa
step_h_arg.f = param.step_h
mace_check(step_h_arg.f > 0, "step_h should be larger than 0.")
step_w_arg = op.arg.add()
step_w_arg.name = MaceKeyword.mace_step_w_str
step_w_arg.f = 0
if param.HasField('step_w'):
mace_check(not param.HasField('step'),
"Either step or step_h/step_w should be specified; not both.") # noqa
step_w_arg.f = param.step_w
mace_check(step_w_arg.f > 0, "step_w should be larger than 0.")
if param.HasField('step'):
mace_check(not param.HasField('step_h') and not param.HasField('step_w'), # noqa
"Either step or step_h/step_w should be specified; not both.") # noqa
mace_check(param.step > 0, "step should be larger than 0.")
step_h_arg.f = param.step
step_w_arg.f = param.step
def convert_reshape(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.reshape_param
op.type = MaceOp.Reshape.name
dim_arg = op.arg.add()
dim_arg.name = MaceKeyword.mace_dim_str
dim_arg.ints.extend(list(param.shape.dim))
axis_arg = op.arg.add()
axis_arg.name = 'reshape_' + MaceKeyword.mace_axis_str
axis_arg.i = 0
if param.HasField('axis'):
axis_arg.i = param.axis
num_axes_arg = op.arg.add()
num_axes_arg.name = MaceKeyword.mace_num_axes_str
num_axes_arg.i = -1
if param.HasField('num_axes'):
num_axes_arg.i = param.num_axes
......@@ -49,6 +49,9 @@ class ShapeInference(object):
MaceOp.Crop.name: self.infer_shape_crop,
MaceOp.BiasAdd.name: self.infer_shape_general,
MaceOp.ChannelShuffle.name: self.infer_shape_channel_shuffle,
MaceOp.Transpose.name: self.infer_shape_permute,
MaceOp.PriorBox.name: self.infer_shape_prior_box,
MaceOp.Reshape.name: self.infer_shape_reshape,
}
self._net = net
......@@ -190,12 +193,14 @@ class ShapeInference(object):
self.add_output_shape(op, [output_shape])
def infer_shape_concat(self, op):
output_shape = list(self._output_shape_cache[op.input[0]])
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
if axis < 0:
axis = len(output_shape) + axis
output_shape[axis] = 0
for input_node in op.input:
input_shape = list(self._output_shape_cache[input_node])
output_shape[axis] = output_shape[axis] + input_shape[axis]
self.add_output_shape(op, [output_shape])
def infer_shape_slice(self, op):
......@@ -225,3 +230,64 @@ class ShapeInference(object):
def infer_shape_channel_shuffle(self, op):
output_shape = self._output_shape_cache[op.input[0]]
self.add_output_shape(op, [output_shape])
def infer_shape_permute(self, op):
output_shape = list(self._output_shape_cache[op.input[0]])
dims = ConverterUtil.get_arg(op, MaceKeyword.mace_dims_str).ints
for i in range(len(dims)):
output_shape[i] = self._output_shape_cache[op.input[0]][dims[i]]
self.add_output_shape(op, [output_shape])
def infer_shape_prior_box(self, op):
output_shape = [1, 2, 1]
input_shape = list(self._output_shape_cache[op.input[0]])
input_w = input_shape[3]
input_h = input_shape[2]
min_size = ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats # noqa
max_size = ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats # noqa
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
flip = ConverterUtil.get_arg(op, MaceKeyword.mace_flip_str).i # noqa
num_prior = (len(min_size) * len(aspect_ratio) +
len(min_size) + len(max_size))
if flip:
num_prior = num_prior + len(min_size) * len(aspect_ratio)
output_shape[2] = num_prior * input_h * input_w * 4
self.add_output_shape(op, [output_shape])
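# e.g. one min_size, one max_size and two aspect_ratios with flip enabled
# yield 1 * 2 + 1 + 1 + 1 * 2 = 6 priors per feature-map cell.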
def infer_shape_reshape(self, op):
if ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str) is not None:
dim = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str).ints
output_shape = list(dim)
product = input_size = 1
idx = -1
for i in range(len(self._output_shape_cache[op.input[0]])):
input_size *= self._output_shape_cache[op.input[0]][i]
for i in range(len(dim)):
if dim[i] == 0:
output_shape[i] = self._output_shape_cache[op.input[0]][i]
product *= self._output_shape_cache[op.input[0]][i]
elif dim[i] == -1:
idx = i
output_shape[i] = 1
else:
output_shape[i] = dim[i]
product *= dim[i]
if idx != -1:
output_shape[idx] = input_size // product
self.add_output_shape(op, [output_shape])
else:
output_shape = list(self._output_shape_cache[op.input[0]])
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i # noqa
if end_axis < 0:
end_axis = len(output_shape) + end_axis
dim = 1
for i in range(0, axis):
output_shape[i] = self._output_shape_cache[op.input[0]][i]
for i in range(axis, end_axis + 1):
dim *= self._output_shape_cache[op.input[0]][i]
output_shape[i] = 1
for i in range(end_axis + 1, len(output_shape)):
output_shape[i] = self._output_shape_cache[op.input[0]][i]
output_shape[axis] = dim
self.add_output_shape(op, [output_shape])
......@@ -95,6 +95,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
TransformerRule.CHECK_QUANTIZE_INFO:
self.check_quantize_info,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN:
self.transform_caffe_reshape_and_flatten,
}
self._option = option
......@@ -979,8 +981,9 @@ class Transformer(base_converter.ConverterInterface):
elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if (ConverterUtil.data_format(op) == DataFormat.NCHW
and self._target_data_format == DataFormat.NHWC
and len(op.output_shape[0].dims) == 4):
print("Transpose concat/split args: %s(%s)"
% (op.name, op.type))
if arg.i == 1:
......@@ -1231,6 +1234,7 @@ class Transformer(base_converter.ConverterInterface):
# transform input(4D) -> reshape(2D) -> matmul to fc
# work for TensorFlow
if op.type == MaceOp.Reshape.name and \
len(op.input) == 2 and \
op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \
filter_format == FilterFormat.HWIO and \
......@@ -1714,3 +1718,35 @@ class Transformer(base_converter.ConverterInterface):
output_info.name = check_node.name
output_info.dims.extend(check_node.output_shape[0].dims)
output_info.data_type = mace_pb2.DT_FLOAT
def transform_caffe_reshape_and_flatten(self):
net = self._model
for op in net.op:
if op.type == MaceOp.Reshape.name and \
len(op.input) == 1:
print("Transform Caffe Reshape")
if op.arg[3].name == 'dim':
shape_tensor = net.tensors.add()
shape_tensor.name = op.name + '_shape'
shape_tensor.dims.append(len(op.output_shape[0].dims))
shape_tensor.data_type = mace_pb2.DT_INT32
shape_tensor.int32_data.extend(op.arg[3].ints)
op.input.append(shape_tensor.name)
else:
axis = op.arg[3].i
dims = [1] * len(op.output_shape[0].dims)
end_axis = op.arg[4].i
end_axis = end_axis if end_axis >= 0 else end_axis + len(dims) # noqa
for i in range(0, axis):
dims[i] = 0
for i in range(axis + 1, end_axis + 1):
dims[i] = 1
for i in range(end_axis + 1, len(dims)):
dims[i] = 0
dims[axis] = -1
shape_tensor = net.tensors.add()
shape_tensor.name = op.name + '_shape'
shape_tensor.dims.append(len(dims))
shape_tensor.data_type = mace_pb2.DT_INT32
shape_tensor.int32_data.extend(dims)
op.input.append(shape_tensor.name)
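# Illustrative trace (shapes hypothetical): Flatten(axis=1, end_axis=-1)
# on a [1, 24, 19, 19] blob has output_shape [1, 8664], so the loop above
# builds dims = [0, -1]: 0 copies the batch dim and -1 absorbs the
# flattened 24 * 19 * 19 elements, turning the op into a two-input Reshape.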
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "mace/utils/logging.h"
namespace mace {
struct BBox {
float xmin;
float ymin;
float xmax;
float ymax;
int label;
float confidence;
};
namespace {
inline float overlap(const BBox &a, const BBox &b) {
if (a.xmin > b.xmax || a.xmax < b.xmin ||
a.ymin > b.ymax || a.ymax < b.ymin) {
return 0.f;
}
float overlap_w = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin);
float overlap_h = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin);
return overlap_w * overlap_h;
}
void NmsSortedBboxes(const std::vector<BBox> &bboxes,
const float nms_threshold,
const int top_k,
std::vector<BBox> *sorted_boxes) {
const int n = std::min(top_k, static_cast<int>(bboxes.size()));
std::vector<int> picked;
std::vector<float> areas(n);
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < n; ++i) {
const BBox &r = bboxes[i];
float width = std::max(0.f, r.xmax - r.xmin);
float height = std::max(0.f, r.ymax - r.ymin);
areas[i] = width * height;
}
for (int i = 0; i < n; ++i) {
const BBox &a = bboxes[i];
int keep = 1;
for (size_t j = 0; j < picked.size(); ++j) {
const BBox &b = bboxes[picked[j]];
float inter_area = overlap(a, b);
float union_area = areas[i] + areas[picked[j]] - inter_area;
MACE_CHECK(union_area > 0, "union_area should be greater than 0");
if (inter_area / union_area > nms_threshold) {
keep = 0;
break;
}
}
if (keep) {
picked.push_back(i);
sorted_boxes->push_back(bboxes[i]);
}
}
}
inline bool cmp(const BBox &a, const BBox &b) {
return a.confidence > b.confidence;
}
} // namespace
int DetectionOutput(const float *loc_ptr,
const float *conf_ptr,
const float *pbox_ptr,
const int num_prior,
const int num_classes,
const float nms_threshold,
const int top_k,
const int keep_top_k,
const float confidence_threshold,
std::vector<BBox> *bbox_rects) {
MACE_CHECK(keep_top_k > 0, "keep_top_k should be greater than 0");
std::vector<float> bboxes(4 * num_prior);
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < num_prior; ++i) {
int index = i * 4;
const float *lc = loc_ptr + index;
const float *pb = pbox_ptr + index;
const float *var = pb + num_prior * 4;
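// Caffe-style CENTER_SIZE decoding: each prior supplies corner coordinates,
// and the location output supplies variance-scaled offsets of the box
// center plus log-space width/height scales.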
float pb_w = pb[2] - pb[0];
float pb_h = pb[3] - pb[1];
float pb_cx = (pb[0] + pb[2]) * 0.5f;
float pb_cy = (pb[1] + pb[3]) * 0.5f;
float bbox_cx = var[0] * lc[0] * pb_w + pb_cx;
float bbox_cy = var[1] * lc[1] * pb_h + pb_cy;
float bbox_w = std::exp(var[2] * lc[2]) * pb_w;
float bbox_h = std::exp(var[3] * lc[3]) * pb_h;
bboxes[0 + index] = bbox_cx - bbox_w * 0.5f;
bboxes[1 + index] = bbox_cy - bbox_h * 0.5f;
bboxes[2 + index] = bbox_cx + bbox_w * 0.5f;
bboxes[3 + index] = bbox_cy + bbox_h * 0.5f;
}
// start from 1 to ignore background class
for (int i = 1; i < num_classes; ++i) {
// filter by confidence threshold
std::vector<BBox> class_bbox_rects;
for (int j = 0; j < num_prior; ++j) {
float confidence = conf_ptr[j * num_classes + i];
if (confidence > confidence_threshold) {
BBox c = {bboxes[0 + j * 4], bboxes[1 + j * 4], bboxes[2 + j * 4],
bboxes[3 + j * 4], i, confidence};
class_bbox_rects.push_back(c);
}
}
std::sort(class_bbox_rects.begin(), class_bbox_rects.end(), cmp);
// apply nms
std::vector<BBox> sorted_boxes;
NmsSortedBboxes(class_bbox_rects,
nms_threshold,
std::min(top_k,
static_cast<int>(class_bbox_rects.size())),
&sorted_boxes);
// gather
bbox_rects->insert(bbox_rects->end(), sorted_boxes.begin(),
sorted_boxes.end());
}
std::sort(bbox_rects->begin(), bbox_rects->end(), cmp);
// output
int num_detected = keep_top_k < static_cast<int>(bbox_rects->size()) ?
keep_top_k : static_cast<int>(bbox_rects->size());
bbox_rects->resize(num_detected);
return num_detected;
}
} // namespace mace
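A minimal, hypothetical call site (tensor names, prior count and thresholds are illustrative, not part of this commit):
std::vector<mace::BBox> detections;
// loc/conf/pbox point at the flattened SSD head outputs for one image.
int num_detected = mace::DetectionOutput(loc_data, conf_data, pbox_data,
1917,   // num_prior
21,     // num_classes
0.45f,  // nms_threshold
100,    // top_k kept per class before NMS
100,    // keep_top_k kept over all classes
0.25f,  // confidence_threshold
&detections);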
This diff is collapsed.