提交 5fc5fbb2 编写于 作者: 李寅

Merge branch 'feature_wuch' into 'master'

fix caffe converter for neon

See merge request !355
......@@ -89,18 +89,22 @@ void DoActivation(const T *input_ptr,
template <typename T>
void PReLUActivation(const T *input_ptr,
const index_t size,
const index_t outer_size,
const index_t input_chan,
const index_t inner_size,
const T *alpha_ptr,
T *output_ptr) {
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
const index_t chan_idx = i % input_chan;
T in = input_ptr[i];
if (in < 0) {
output_ptr[i] = in * alpha_ptr[chan_idx];
#pragma omp parallel for collapse(3)
for (index_t i = 0; i < outer_size; ++i) {
for (index_t chan_idx = 0; chan_idx < input_chan; ++chan_idx) {
for (index_t j = 0; j < inner_size; ++j) {
index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
if (input_ptr[idx] < 0) {
output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
} else {
output_ptr[i] = in;
output_ptr[idx] = input_ptr[idx];
}
}
}
}
}
......@@ -120,7 +124,9 @@ class ActivationFunctor {
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const T *alpha_ptr = alpha->data<T>();
PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr,
const index_t outer_size = output->dim(0) * output->dim(1)
* output->dim(2);
PReLUActivation(input_ptr, outer_size, input->dim(3), 1, alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
......
......@@ -17,7 +17,9 @@ void ActivationFunctor<DeviceType::NEON, float>::operator()(
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const float *alpha_ptr = alpha->data<float>();
PReLUActivation(input_ptr, output->size(), input->dim(1), alpha_ptr,
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr, outer_size, input->dim(1), inner_size, alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
......
......@@ -5,6 +5,7 @@
#ifndef MACE_KERNELS_SLICE_H_
#define MACE_KERNELS_SLICE_H_
#include <functional>
#include <vector>
#include "mace/core/future.h"
......@@ -16,20 +17,34 @@
namespace mace {
namespace kernels {
struct SliceFunctorBase {
explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {}
int32_t axis_;
};
template<DeviceType D, typename T>
struct SliceFunctor {
struct SliceFunctor : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
void operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future) {
const index_t outer_size = input->dim(0) * input->dim(1) * input->dim(2);
const index_t input_channels = input->dim(3);
const index_t input_channels = input->dim(axis_);
const size_t outputs_count = output_list.size();
const index_t output_channels = input_channels / outputs_count;
std::vector<T *> output_ptrs(output_list.size(), nullptr);
std::vector<index_t> output_shape(input->shape());
output_shape[axis_] = output_channels;
std::vector<index_t> output_shape({input->dim(0), input->dim(1),
input->dim(2), output_channels});
const index_t outer_size = std::accumulate(output_shape.begin(),
output_shape.begin() + axis_,
1,
std::multiplies<index_t>());
const index_t inner_size = std::accumulate(output_shape.begin() + axis_ + 1,
output_shape.end(),
1,
std::multiplies<index_t>());
for (size_t i= 0; i < outputs_count; ++i) {
output_list[i]->Resize(output_shape);
output_ptrs[i] = output_list[i]->mutable_data<T>();
......@@ -38,25 +53,27 @@ struct SliceFunctor {
#pragma omp parallel for
for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
int input_idx = outer_idx * input_channels;
int output_idx = outer_idx * output_channels;
int input_idx = outer_idx * input_channels * inner_size;
int output_idx = outer_idx * output_channels * inner_size;
for (size_t i = 0; i < outputs_count; ++i) {
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
memcpy(output_ptrs[i]+output_idx, input_ptr+input_idx,
output_channels * sizeof(T));
output_channels * inner_size * sizeof(T));
} else {
for (index_t k = 0; k < output_channels; ++k) {
for (index_t k = 0; k < output_channels * inner_size; ++k) {
*(output_ptrs[i] + output_idx + k) = *(input_ptr + input_idx + k);
}
}
input_idx += output_channels;
input_idx += output_channels * inner_size;
}
}
}
};
template<typename T>
struct SliceFunctor<DeviceType::OPENCL, T> {
struct SliceFunctor<DeviceType::OPENCL, T> : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
void operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future);
......
......@@ -249,14 +249,26 @@ void TestSimplePrelu() {
net.RunOp(D);
}
if (D == DeviceType::NEON) {
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-14, 7, -12, 6, -15, -15, -12, -12, -6, 3, -4, 2, -3, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} else {
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
}
TEST_F(ActivationOpTest, CPUSimplePrelu) {
TestSimplePrelu<DeviceType::CPU>();
}
TEST_F(ActivationOpTest, CPUSimplePrelu) { TestSimplePrelu<DeviceType::CPU>(); }
TEST_F(ActivationOpTest, NEONSimplePrelu) {
TestSimplePrelu<DeviceType::NEON>();
}
TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::OPENCL>();
......
......@@ -24,6 +24,11 @@ void Register_Slice(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
SliceOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
SliceOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -17,14 +17,16 @@ template <DeviceType D, typename T>
class SliceOp : public Operator<D, T> {
public:
SliceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
bool Run(StatsFuture *future) override {
MACE_CHECK(this->OutputSize() >= 2)
<< "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> output_list = this->Outputs();
MACE_CHECK((input->dim(3) % this->OutputSize()) == 0)
const int32_t slice_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally.";
functor_(input, output_list, future);
......
......@@ -16,7 +16,7 @@ namespace test {
class SliceOpTest : public OpsTestBase {};
template<DeviceType D, typename T>
void RandomTest(const int num_outputs) {
void RandomTest(const int num_outputs, const int axis) {
static unsigned int seed = time(NULL);
const index_t output_channels = 4 * (1 + rand_r(&seed) % 10);
const index_t input_channels = num_outputs * output_channels;
......@@ -27,7 +27,11 @@ void RandomTest(const int num_outputs) {
// Construct graph
OpsTestNet net;
std::vector<index_t> input_shape({batch, height, width, input_channels});
std::vector<index_t> input_shape;
if (axis == 1)
input_shape = {batch, input_channels, height, width};
else if (axis == 3)
input_shape = {batch, height, width, input_channels};
const index_t input_size = std::accumulate(input_shape.begin(),
input_shape.end(),
1,
......@@ -49,7 +53,7 @@ void RandomTest(const int num_outputs) {
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
auto builder = OpDefBuilder("Slice", "SliceTest");
auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis);
builder.Input("Input");
for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("Output", i));
......@@ -70,9 +74,17 @@ void RandomTest(const int num_outputs) {
}
// Check
std::vector<index_t> expected_shape({batch, height, width, output_channels});
std::vector<index_t> expected_shape;
if (axis == 1)
expected_shape = {batch, output_channels, height, width};
else if (axis == 3)
expected_shape = {batch, height, width, output_channels};
const index_t outer_size = std::accumulate(expected_shape.begin(),
expected_shape.end() - 1,
expected_shape.begin() + axis,
1,
std::multiplies<index_t>());
const index_t inner_size = std::accumulate(expected_shape.begin() + axis + 1,
expected_shape.end(),
1,
std::multiplies<index_t>());
const float *input_ptr = input_data.data();
......@@ -83,8 +95,9 @@ void RandomTest(const int num_outputs) {
Tensor::MappingGuard output_mapper(output);
output_ptr = output->data<float>();
for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const int idx = outer_idx * input_channels + i * output_channels;
for (int j = 0; j < output_channels; ++j) {
const int idx = (outer_idx * input_channels + i * output_channels)
* inner_size;
for (int j = 0; j < output_channels * inner_size; ++j) {
ASSERT_NEAR(*output_ptr++, input_ptr[idx + j], 1e-2) << "with output "
<< i << " index " << idx + j;
}
......@@ -93,21 +106,27 @@ void RandomTest(const int num_outputs) {
}
TEST_F(SliceOpTest, CPU) {
RandomTest<DeviceType::CPU, float>(2);
RandomTest<DeviceType::CPU, float>(4);
RandomTest<DeviceType::CPU, float>(11);
RandomTest<DeviceType::CPU, float>(2, 3);
RandomTest<DeviceType::CPU, float>(4, 3);
RandomTest<DeviceType::CPU, float>(11, 3);
}
TEST_F(SliceOpTest, CPUAxis1) {
RandomTest<DeviceType::CPU, float>(2, 1);
RandomTest<DeviceType::CPU, float>(4, 1);
RandomTest<DeviceType::CPU, float>(11, 1);
}
TEST_F(SliceOpTest, OPENCLFloat) {
RandomTest<DeviceType::OPENCL, float>(2);
RandomTest<DeviceType::OPENCL, float>(4);
RandomTest<DeviceType::OPENCL, float>(11);
RandomTest<DeviceType::OPENCL, float>(2, 3);
RandomTest<DeviceType::OPENCL, float>(4, 3);
RandomTest<DeviceType::OPENCL, float>(11, 3);
}
TEST_F(SliceOpTest, OPENCLHalf) {
RandomTest<DeviceType::OPENCL, half>(2);
RandomTest<DeviceType::OPENCL, half>(4);
RandomTest<DeviceType::OPENCL, half>(11);
RandomTest<DeviceType::OPENCL, half>(2, 3);
RandomTest<DeviceType::OPENCL, half>(4, 3);
RandomTest<DeviceType::OPENCL, half>(11, 3);
}
} // namespace test
......
......@@ -68,14 +68,26 @@ def BlobToNPArray(blob):
class Shapes(object):
@staticmethod
def conv_pool_shape(input_shape, filter_shape, paddings, strides, dilations, round_func):
def conv_pool_shape(input_shape, filter_shape, paddings, strides, dilations, round_func, input_format='NHWC'):
output_shape = np.zeros_like(input_shape)
output_shape[0] = input_shape[0]
if input_format == 'NHWC':
# input format: NHWC, filter format: HWOI
output_shape[1] = int(round_func((input_shape[1] + paddings[0] - filter_shape[0]
- (filter_shape[0] - 1) * (dilations[0] - 1)) / float(strides[0]))) + 1
output_shape[2] = int(round_func((input_shape[2] + paddings[1] - filter_shape[1]
- (filter_shape[1] - 1) * (dilations[1] - 1)) / float(strides[1]))) + 1
output_shape[3] = filter_shape[2]
elif input_format == 'NCHW':
# input format: NCHW, filter format: OIHW
output_shape[1] = filter_shape[0]
output_shape[2] = int(round_func((input_shape[2] + paddings[0] - filter_shape[2]
- (filter_shape[2] - 1) * (dilations[0] - 1)) / float(strides[0]))) + 1
output_shape[3] = int(round_func((input_shape[3] + paddings[1] - filter_shape[3]
- (filter_shape[3] - 1) * (dilations[1] - 1)) / float(strides[1]))) + 1
else:
raise Exception("format %s is not supported" % input_format)
return output_shape
@staticmethod
......@@ -93,8 +105,13 @@ class Shapes(object):
return output_shape
@staticmethod
def slice_shape(input_shape, num_output):
def slice_shape(input_shape, num_output, input_format='NHWC'):
if input_format == 'NHWC':
return [input_shape[0], input_shape[1], input_shape[2], input_shape[3]/num_output]
elif input_format == 'NCHW':
return [input_shape[0], input_shape[1]/num_output, input_shape[2], input_shape[3]]
else:
raise Exception("format %s is not supported" % input_format)
# outputs' name is [op.name + '_' + #]
class CaffeConverter(object):
......@@ -168,6 +185,9 @@ class CaffeConverter(object):
arg.i = self.dt
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
if self.device == 'neon':
data_format_arg.s = 'NCHW'
else:
data_format_arg.s = 'NHWC'
op_def.name = op.name
op_def.type = mace_type
......@@ -342,6 +362,10 @@ class CaffeConverter(object):
# Add filter
weight_tensor_name = op.name + '_weight:0'
if self.device == 'neon':
weight_data = op.data[0]
else:
# OIHW -> HWOI
weight_data = op.data[0].transpose((2, 3, 0, 1))
self.add_tensor(weight_tensor_name, weight_data)
......@@ -376,10 +400,11 @@ class CaffeConverter(object):
final_op = op
self.resolved_ops.add(op.name)
input_format = 'NCHW' if self.device == 'neon' else 'NHWC'
output_shape = Shapes.conv_pool_shape(op.get_single_parent().output_shape_map[op.layer.bottom[0]],
weight_data.shape,
paddings, strides, dilations,
math.floor)
math.floor, input_format)
op.output_shape_map[op.layer.top[0]] = output_shape
if len(self.ops_map[final_op.name].children) == 1 \
......@@ -399,9 +424,13 @@ class CaffeConverter(object):
self.net_def.op.extend([op_def])
def check_winograd_conv(self, op):
# TODO: support winograd conv on neon
if self.device == 'neon':
return False
param = op.layer.convolution_param
filter_shape = np.asarray(op.data[0].shape)
filter_shape = filter_shape[[2, 3, 0, 1]]
if self.device != 'neon':
filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI
paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
dilations = [1, 1]
......@@ -411,17 +440,21 @@ class CaffeConverter(object):
elif len(param.dilation) == 2:
dilations = [param.dilation[0], param.dilation[1]]
input_format = 'NCHW' if self.device == 'neon' else 'NHWC'
output_shape = Shapes.conv_pool_shape(
op.get_single_parent().output_shape_map[op.layer.bottom[0]],
filter_shape, paddings, strides, dilations, math.floor)
filter_shape, paddings, strides, dilations, math.floor, input_format)
width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2)
return self.winograd and self.device == 'gpu' and \
filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \
dilations[0] == 1 and (dilations[0] == dilations[1]) and \
(strides[0] == 1) and (strides[0] == strides[1]) and \
if self.winograd and dilations[0] == 1 and (dilations[0] == dilations[1]) and \
(strides[0] == 1) and (strides[0] == strides[1]):
if self.device == 'gpu':
return filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \
(16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
(16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
(width < OPENCL_IMAGE_MAX_SIZE)
elif self.device == 'neon':
return filter_shape[2] == 3 and (filter_shape[2] == filter_shape[3])
return False
def convert_winograd_conv(self, op):
# Add filter
......@@ -435,11 +468,13 @@ class CaffeConverter(object):
paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
filter_shape = np.asarray(op.data[0].shape)
filter_shape = filter_shape[[2, 3, 0, 1]]
if self.device != 'neon':
filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI
input_format = 'NCHW' if self.device == 'neon' else 'NHWC'
output_shape = Shapes.conv_pool_shape(
op.get_single_parent().output_shape_map[op.layer.bottom[0]],
filter_shape, paddings, strides, [1, 1], math.floor)
filter_shape, paddings, strides, [1, 1], math.floor, input_format)
# Input transform
wt_op = mace_pb2.OperatorDef()
......@@ -455,8 +490,12 @@ class CaffeConverter(object):
wt_output_name = wt_op.name + ":0"
wt_op.output.extend([wt_output_name])
wt_output_shape = mace_pb2.OutputShape()
if self.device != 'neon':
wt_output_width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2)
wt_output_shape.dims.extend([16, filter_shape[3], wt_output_width, 1])
else:
wt_output_width = output_shape[0] * ((output_shape[2] + 1)/2) * ((output_shape[3]+1)/2)
wt_output_shape.dims.extend([16, filter_shape[1], wt_output_width, 1])
wt_op.output_shape.extend([wt_output_shape])
# MatMul
......@@ -470,7 +509,10 @@ class CaffeConverter(object):
matmul_output_name = matmul_op.name + ":0"
matmul_op.output.extend([matmul_output_name])
matmul_output_shape = mace_pb2.OutputShape()
if self.device != 'neon':
matmul_output_shape.dims.extend([16, filter_shape[2], wt_output_width, 1])
else:
matmul_output_shape.dims.extend([16, filter_shape[0], wt_output_width, 1])
matmul_op.output_shape.extend([matmul_output_shape])
# Inverse transform
......@@ -483,10 +525,10 @@ class CaffeConverter(object):
batch_arg.i = output_shape[0]
height_arg = iwt_op.arg.add()
height_arg.name = 'height'
height_arg.i = output_shape[1]
height_arg.i = output_shape[1] if self.device != 'neon' else output_shape[2]
width_arg = iwt_op.arg.add()
width_arg.name = 'width'
width_arg.i = output_shape[2]
width_arg.i = output_shape[2] if self.device != 'neon' else output_shape[3]
iwt_op.name = op.name + '_inverse_transform'
iwt_op.type = 'WinogradInverseTransform'
iwt_op.input.extend([matmul_output_name])
......@@ -589,6 +631,7 @@ class CaffeConverter(object):
weight_data = op.data[0].reshape(-1, op.data[0].shape[-1])
assert weight_data.shape[1] == (input_shape[1] * input_shape[2] * input_shape[3])
if self.device != 'neon':
weight_data = weight_data.reshape(-1, input_shape[3], input_shape[1], input_shape[2])
weight_data = weight_data.transpose((0, 2, 3, 1)).reshape(weight_data.shape[0], -1)
self.add_tensor(weight_tensor_name, weight_data)
......@@ -665,9 +708,12 @@ class CaffeConverter(object):
kernel_arg.name = 'kernels'
kernel_arg.ints.extend(kernels)
filter_shape = [kernels[0], kernels[1], input_shape[3], input_shape[3]]
filter_shape = [kernels[0], kernels[1], input_shape[3], input_shape[3]] \
if self.device != 'neon' else \
[input_shape[1], input_shape[1], kernels[0], kernels[1]]
input_format = 'NCHW' if self.device == 'neon' else 'NHWC'
output_shape = Shapes.conv_pool_shape(input_shape, filter_shape,
paddings, strides, [1, 1], math.ceil)
paddings, strides, [1, 1], math.ceil, input_format)
op.output_shape_map[op.layer.top[0]] = output_shape
op_def.output.extend([op.name + ':0'])
......@@ -720,7 +766,7 @@ class CaffeConverter(object):
op_def = self.CommonConvert(op, 'Concat')
axis_arg = op_def.arg.add()
axis_arg.name = 'axis'
axis_arg.i = 3
axis_arg.i = 3 if self.device != 'neon' else 1
try:
if op.layer.concat_param.HasFeild('axis'):
axis_arg.i = op.concat_param.axis
......@@ -766,13 +812,19 @@ class CaffeConverter(object):
if len(param.slice_point) > 0:
raise Exception('Mace do not support slice with slice_point')
axis_arg = op_def.arg.add()
axis_arg.name = 'axis'
axis_arg.i = 3 if self.device != 'neon' else 1
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
num_outputs = len(op.layer.top)
if (input_shape[3] % num_outputs) != 0 or \
(self.device == 'gpu' and ((input_shape[3] / num_outputs) % 4 != 0)) :
input_channels = input_shape[axis_arg.i]
if (input_channels % num_outputs) != 0 or \
(self.device == 'gpu' and ((input_channels / num_outputs) % 4 != 0)):
raise Exception('Mace do not support slice with input shape '
+ str(input_shape) + ' and number of output ' + str(num_outputs))
output_shape = Shapes.slice_shape(input_shape, num_outputs)
input_format = 'NCHW' if self.device == 'neon' else 'NHWC'
output_shape = Shapes.slice_shape(input_shape, num_outputs, input_format)
for i in range(len(op.layer.top)):
op.output_shape_map[op.layer.top[i]] = output_shape
self.add_output_shape(op_def, output_shape)
......@@ -790,10 +842,15 @@ class CaffeConverter(object):
self.resolved_ops.add(op.name)
def convert_reshape(self, op):
if self.device == 'neon':
op_def = self.CommonConvert(op, 'Reshape')
else:
op_def = self.CommonConvert(op, 'ReOrganize')
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
output_shape = input_shape
shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]]
shape_param = np.asarray(op.layer.reshape_param.shape.dim)
if self.device != 'neon':
shape_param = shape_param[[0, 3, 1, 2]]
for i in range(len(shape_param)):
if shape_param[i] != 0:
output_shape[i] = shape_param[i]
......@@ -867,15 +924,50 @@ class CaffeConverter(object):
assert len(input_nodes) == len(input_shapes)
for i in range(len(input_nodes)):
input_op = self.ops_map[input_nodes[i]]
input_shape = input_shapes[i] if self.device != 'neon' else \
[input_shapes[i][0], input_shapes[i][3], input_shapes[i][1], input_shapes[i][2]]
if input_op.layer is not None:
input_op.output_shape_map[input_op.layer.top[0]] = input_shapes[i]
input_op.output_shape_map[input_op.layer.top[0]] = input_shape
else:
input_op.output_shape_map[input_op.name] = input_shapes[i]
input_op.output_shape_map[input_op.name] = input_shape
def add_neon_input_transform(self, names):
for name in names:
new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add()
op_def.name = name
op_def.type = 'Transpose'
op_def.input.extend([new_input_name])
op_def.output.extend([name+':0'])
dims_arg = op_def.arg.add()
dims_arg.name = 'dims'
dims_arg.ints.extend([0, 3, 1, 2]) # NHWC -> NCHW
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
def add_neon_output_transform(self, names):
for name in names:
output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add()
op_def.name = output_name[:-2]
op_def.type = 'Transpose'
op_def.input.extend([name+':0'])
op_def.output.extend([output_name])
dims_arg = op_def.arg.add()
dims_arg.name = 'dims'
dims_arg.ints.extend([0, 2, 3, 1]) # NCHW -> NHWC
def convert(self, input_nodes, input_shapes, output_nodes):
if self.device == 'gpu':
self.add_input_transform(input_nodes)
if self.device == 'neon':
self.add_neon_input_transform(input_nodes)
assert self.ops[0].type == 'Input'
self.add_input_op_shape(input_nodes, input_shapes)
......@@ -924,6 +1016,9 @@ class CaffeConverter(object):
if self.device == 'cpu':
self.replace_in_out_name(input_nodes, output_nodes)
if self.device == 'neon':
self.add_neon_output_transform(output_nodes)
for op in self.ops:
if op.name not in self.resolved_ops:
print 'Unresolve Op: %s with type %s' % (op.name, op.type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册