Commit 9ffd51d9 authored by 李寅

Merge branch 'support_bilstm' into 'master'

Support bidirectional LSTM for CPU; fix memory overflow when inferring output shapes

See merge request !785
@@ -21,6 +21,8 @@ Operator lists
"DEQUANTIZE","Y","Model quantization will be supported later." "DEQUANTIZE","Y","Model quantization will be supported later."
"ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL"
"EMBEDDING_LOOKUP","Y","" "EMBEDDING_LOOKUP","Y",""
"EXPANDDIMS","Y","Only CPU and TensorFlow is supported."
"FILL","Y","Only CPU and TensorFlow is supported."
"FULLY_CONNECTED","Y","" "FULLY_CONNECTED","Y",""
"GROUP_CONV_2D","","Caffe model with group count = channel count is supported." "GROUP_CONV_2D","","Caffe model with group count = channel count is supported."
"IDENTITY","Y","Only TensorFlow model is supported." "IDENTITY","Y","Only TensorFlow model is supported."
@@ -39,6 +41,7 @@ Operator lists
"RELUX","Y","" "RELUX","Y",""
"RESHAPE","Y","Limited support: GPU only supports softmax-like usage, CPU only supports the usage which not change the storage format." "RESHAPE","Y","Limited support: GPU only supports softmax-like usage, CPU only supports the usage which not change the storage format."
"RESIZE_BILINEAR","Y","" "RESIZE_BILINEAR","Y",""
"REVERSE","Y","Only CPU and Tensorflow is supported"
"RNN","","" "RNN","",""
"RPN_PROPOSAL_LAYER","Y","" "RPN_PROPOSAL_LAYER","Y",""
"SHAPE","Y","Only CPU and TensorFlow is supported." "SHAPE","Y","Only CPU and TensorFlow is supported."
@@ -48,6 +51,7 @@ Operator lists
"SOFTMAX","Y","" "SOFTMAX","Y",""
"SPACE_TO_BATCH_ND", "Y","" "SPACE_TO_BATCH_ND", "Y",""
"SPACE_TO_DEPTH","Y","" "SPACE_TO_DEPTH","Y",""
"SQEEZE","Y","Only CPU and TensorFlow is supported." "SQUEEZE","Y","Only CPU and TensorFlow is supported."
"TANH","Y","" "TANH","Y",""
"TRANSPOSE","Y","Only CPU and TensorFlow is supported." "TRANSPOSE","Y","Only CPU and TensorFlow is supported."
"UNSTACK","Y","Only CPU and TensorFlow is supported."
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_EXPAND_DIMS_H_
#define MACE_KERNELS_EXPAND_DIMS_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/kernel.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ExpandDimsFunctor;
template <typename T>
struct ExpandDimsFunctor<DeviceType::CPU, T> : OpKernel {
explicit ExpandDimsFunctor(OpKernelContext *context, int axis)
: OpKernel(context), axis_(axis) {}
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
index_t input_dims_size = input->dim_size();
if (axis_ < 0) {
axis_ += input_dims_size + 1;
}
MACE_CHECK(axis_ >= 0 && axis_ <= input_dims_size,
"axis is out of bound: ", axis_);
const std::vector<index_t> input_shape = input->shape();
std::vector<index_t> output_shape;
output_shape.insert(output_shape.end(), input_shape.begin(),
input_shape.begin() + axis_);
output_shape.insert(output_shape.end(), 1);
output_shape.insert(output_shape.end(), input_shape.begin() + axis_,
input_shape.end());
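// ExpandDims only rearranges shape metadata, so the output tensor reuses the
// input's underlying buffer instead of copying the data.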
output->ReuseTensorBuffer(*input);
output->Reshape(output_shape);
return MACE_SUCCESS;
}
int axis_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_EXPAND_DIMS_H_
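The shape arithmetic in the functor above is easy to check in isolation. The sketch below is not part of the MACE sources; it uses a plain std::vector<int64_t> in place of Tensor shapes and a hypothetical helper name, and reproduces the axis normalization and the insertion of a size-1 dimension, using the same cases as the SimpleCPU test further down.

// Standalone illustration of the ExpandDims output-shape computation.
// Not part of MACE; plain std::vector<int64_t> stands in for Tensor shapes.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ExpandDimsShape(const std::vector<int64_t> &in, int axis) {
  const int rank = static_cast<int>(in.size());
  if (axis < 0) axis += rank + 1;      // normalize a negative axis
  assert(axis >= 0 && axis <= rank);   // axis may point one past the last dim
  std::vector<int64_t> out(in.begin(), in.begin() + axis);
  out.push_back(1);                    // the new dimension of size 1
  out.insert(out.end(), in.begin() + axis, in.end());
  return out;
}

int main() {
  // Mirrors the SimpleCPU cases: {3, 2, 1}, axis 1  -> {3, 1, 2, 1}
  //                              {1, 2, 3}, axis -1 -> {1, 2, 3, 1}
  for (int64_t d : ExpandDimsShape({3, 2, 1}, 1)) std::cout << d << ' ';
  std::cout << '\n';
  for (int64_t d : ExpandDimsShape({1, 2, 3}, -1)) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}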
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_REVERSE_H_
#define MACE_KERNELS_REVERSE_H_
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/kernel.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReverseFunctor;
template <typename T>
struct ReverseFunctor<DeviceType::CPU, T> : OpKernel {
explicit ReverseFunctor(OpKernelContext *context) : OpKernel(context) {}
MaceStatus operator()(const Tensor *input,
const Tensor *axis,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(axis->dim_size() == 1, "Only support reverse in one axis now");
const int32_t *axis_data = axis->data<int32_t>();
const index_t reverse_dim = *axis_data >= 0 ?
*axis_data : *axis_data + input->dim_size();
MACE_CHECK(reverse_dim >= 0 && reverse_dim < input->dim_size(),
"axis must be in the range [-rank(input), rank(input))");
const std::vector<index_t> input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
index_t high_dim_elem_size =
std::accumulate(input_shape.begin(), input_shape.begin() + reverse_dim,
1, std::multiplies<index_t>());
index_t low_dim_elem_size =
std::accumulate(input_shape.begin() + reverse_dim + 1,
input_shape.end(), 1, std::multiplies<index_t>());
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
index_t reverse_size = input_shape[reverse_dim] * low_dim_elem_size;
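// Copy slices of low_dim_elem_size contiguous elements along the reversed
// axis in reverse order, one outer (higher-dimension) block at a time.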
for (index_t h = 0; h < high_dim_elem_size; ++h) {
int input_idx = h * reverse_size;
int output_idx = input_idx + reverse_size;
for (index_t i = 0; i < input_shape[reverse_dim]; ++i) {
output_idx -= low_dim_elem_size;
memcpy(output_data + output_idx, input_data + input_idx,
sizeof(T) * low_dim_elem_size);
input_idx += low_dim_elem_size;
}
}
SetFutureDefaultWaitFn(future);
return MACE_SUCCESS;
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_REVERSE_H_
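The index math in ReverseFunctor can be read as: split the flat data into high_dim_elem_size outer blocks, where each block holds input_shape[reverse_dim] slices of low_dim_elem_size contiguous elements, and copy those slices back to front. The sketch below is not MACE code; it uses a flat std::vector and a hypothetical helper name to show the same computation on the first SimpleCPU test case.

// Standalone illustration of the block-copy reverse used by ReverseFunctor.
// Not part of MACE; operates on a flat std::vector instead of a Tensor.
#include <cstring>
#include <iostream>
#include <vector>

std::vector<float> ReverseAxis(const std::vector<float> &in,
                               const std::vector<int> &shape, int axis) {
  int high = 1, low = 1;
  for (int i = 0; i < axis; ++i) high *= shape[i];   // outer block count
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i)
    low *= shape[i];                                 // contiguous slice length
  const int dim = shape[axis];
  std::vector<float> out(in.size());
  for (int h = 0; h < high; ++h) {
    int in_idx = h * dim * low;
    int out_idx = in_idx + dim * low;
    for (int i = 0; i < dim; ++i) {
      out_idx -= low;  // write slices back to front
      std::memcpy(out.data() + out_idx, in.data() + in_idx, sizeof(float) * low);
      in_idx += low;
    }
  }
  return out;
}

int main() {
  // First SimpleCPU case: shape {2, 3, 2}, reverse along axis 0.
  std::vector<float> in = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  for (float v : ReverseAxis(in, {2, 3, 2}, 0)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 6 7 8 9 10 11 0 1 2 3 4 5
  return 0;
}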
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/expand_dims.h"
namespace mace {
namespace ops {
void Register_ExpandDims(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, int32_t>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, uint8_t>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_EXPAND_DIMS_H_
#define MACE_OPS_EXPAND_DIMS_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/expand_dims.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ExpandDimsOp : public Operator<D, T> {
public:
ExpandDimsOp(const OperatorDef &op_def, OpKernelContext *context)
: Operator<D, T>(op_def, context),
functor_(context, OperatorBase::GetOptionalArg<int>("axis", 0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
return functor_(input, output, future);
}
private:
kernels::ExpandDimsFunctor<D, T> functor_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_EXPAND_DIMS_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ExpandDimsTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestExpandDims(const std::vector<index_t> &input_shape,
const int &axis,
const std::vector<index_t> &output_shape) {
// Construct graph
OpsTestNet net;
OpDefBuilder("ExpandDims", "ExpandDimsTest")
.Input("Input")
.AddIntArg("axis", static_cast<int>(axis))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", input_shape);
// Run
net.RunOp();
auto input = net.GetTensor("Input");
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape));
const T *input_ptr = input->data<T>();
const T *output_ptr = output->data<T>();
const int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]);
}
}
} // namespace
TEST_F(ExpandDimsTest, SimpleCPU) {
TestExpandDims<DeviceType::CPU, float>({3, 2, 1}, 1, {3, 1, 2, 1});
TestExpandDims<DeviceType::CPU, float>({1, 2, 3}, -1, {1, 2, 3, 1});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_ExpandDims(OperatorRegistryBase *op_registry);
extern void Register_Fill(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
@@ -50,6 +51,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Reverse(OperatorRegistryBase *op_registry);
extern void Register_ScalarMath(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Split(OperatorRegistryBase *op_registry);
@@ -90,6 +92,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_ExpandDims(this);
ops::Register_Fill(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
@@ -106,6 +109,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Reshape(this);
ops::Register_ResizeBicubic(this);
ops::Register_ResizeBilinear(this);
ops::Register_Reverse(this);
ops::Register_ScalarMath(this);
ops::Register_Shape(this);
ops::Register_Split(this);
...
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/reverse.h"
namespace mace {
namespace ops {
void Register_Reverse(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reverse")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReverseOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REVERSE_H_
#define MACE_OPS_REVERSE_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/reverse.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class ReverseOp : public Operator<D, T> {
public:
ReverseOp(const OperatorDef &operator_def, OpKernelContext *context)
: Operator<D, T>(operator_def, context), functor_(context) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *axis = this->Input(AXIS);
Tensor *output = this->Output(OUTPUT);
return functor_(input, axis, output, future);
}
private:
kernels::ReverseFunctor<D, T> functor_;
protected:
MACE_OP_INPUT_TAGS(INPUT, AXIS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REVERSE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void Reverse(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
net.AddRandomInput<D, int32_t>("Axis", {1});
OpDefBuilder("Reverse", "ReverseOpTest")
.Input("Input")
.Input("Axis")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_REVERSE_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_REVERSE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t macc = \
static_cast<int64_t>(iters) * N * C * H * W; \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Reverse<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_REVERSE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_REVERSE(N, C, H, W) \
MACE_BM_REVERSE_MACRO(N, C, H, W, float, CPU);
MACE_BM_REVERSE(1, 1, 99, 256);
MACE_BM_REVERSE(1, 30, 99, 256);
MACE_BM_REVERSE(1, 50, 99, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReverseOpTest : public OpsTestBase {};
namespace {
void TestReverse(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<index_t> &axis_shape,
const std::vector<int32_t> &axis,
const std::vector<float> &outputs) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("Axis", axis_shape, axis);
OpDefBuilder("Reverse", "ReverseOpTest")
.Input("Input")
.Input("Axis")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", input_shape,
outputs);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(ReverseOpTest, SimpleCPU) {
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {0},
{6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5});
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {1},
{4, 5, 2, 3, 0, 1, 10, 11, 8, 9, 6, 7});
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {2},
{1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -87,6 +87,7 @@ MaceSupportedOps = [
    'DepthwiseConv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'FoldedBatchNorm',
    'Fill',
    'FullyConnected',
@@ -104,6 +105,7 @@ MaceSupportedOps = [
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'Reverse',
    'ScalarMath',
    'Slice',
    'Split',
...
@@ -73,6 +73,7 @@ TFSupportedOps = [
    'FusedBatchNorm',
    'AvgPool',
    'MaxPool',
    'ExpandDims',
    'Squeeze',
    'MatMul',
    'BatchMatMul',
@@ -95,6 +96,7 @@ TFSupportedOps = [
    'Gather',
    'StridedSlice',
    'Slice',
    'ReverseV2',
    'Stack',
    'Pack',
    'Unstack',
@@ -181,6 +183,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
            TFOpType.Identity.name: self.convert_identity,
            TFOpType.Reshape.name: self.convert_reshape,
            TFOpType.Shape.name: self.convert_shape,
            TFOpType.ExpandDims.name: self.convert_expand_dims,
            TFOpType.Squeeze.name: self.convert_squeeze,
            TFOpType.Transpose.name: self.convert_transpose,
            TFOpType.Softmax.name: self.convert_softmax,
@@ -198,6 +201,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
            TFOpType.Gather.name: self.convert_gather,
            TFOpType.StridedSlice.name: self.convert_stridedslice,
            TFOpType.Slice.name: self.convert_slice,
            TFOpType.ReverseV2.name: self.convert_reverse,
            TFOpType.Pack.name: self.convert_stack,
            TFOpType.Stack.name: self.convert_stack,
            TFOpType.Unpack.name: self.convert_unstack,
@@ -225,9 +229,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
        self._skip_tensor = set()
        self._output_shape_list = []
        self._output_shape_op_list = []

    def run(self):
        with tf.Session() as session:
            self.convert_ops(session)

        self.replace_input_output_tensor_name()
        return self._mace_net_def
@@ -267,12 +274,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
        else:
            return tensor_name[:idx]
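
    # Dynamic output shapes are collected during conversion and evaluated here
    # in a single batched session.run(), instead of one tf.shape(tensor).eval()
    # per tensor as before; this is the memory-overflow fix referenced in the
    # commit message.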
    def update_output_shapes(self, sess):
        output_shapes = sess.run(self._output_shape_op_list,
                                 feed_dict=self._placeholders)
        for i in range(len(self._output_shape_list)):
            self._output_shape_list[i].dims.extend(output_shapes[i])

    def convert_ops(self, sess):
        for tf_op in self._tf_graph.get_operations():
            mace_check(tf_op.type in self._op_converters,
                       "Mace does not support tensorflow op type %s yet"
                       % tf_op.type)
            self._op_converters[tf_op.type](tf_op)

        self.update_output_shapes(sess)
        self.convert_tensors()

    def convert_tensors(self):
@@ -306,7 +320,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
    # this function tries to infer tensor shape, but some dimension shape
    # may be undefined due to variance of input length
    def infer_tensor_shape(self, output_shape, tensor):
        inferred_tensor_shape = tensor.shape.as_list()
        inferred_success = True
        for _, dim in enumerate(inferred_tensor_shape):
@@ -314,10 +328,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
                inferred_success = False
                break
        if inferred_success:
            output_shape.dims.extend(inferred_tensor_shape)
        else:
            self._output_shape_list.append(output_shape)
            self._output_shape_op_list.append(tf.shape(tensor))

    def convert_nop(self, tf_op):
        pass
@@ -330,7 +344,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
        op.output.extend([tf_output.name for tf_output in tf_op.outputs])
        for tf_output in tf_op.outputs:
            output_shape = op.output_shape.add()
            self.infer_tensor_shape(output_shape, tf_output)

        data_type_arg = op.arg.add()
        data_type_arg.name = 'T'
@@ -678,6 +692,21 @@ class TensorflowConverter(base_converter.ConverterInterface):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Reshape.name

    def convert_expand_dims(self, tf_op):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ExpandDims.name

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        try:
            axis_value = tf_op.get_attr('dim')
        except ValueError:
            try:
                axis_value = tf_op.get_attr('axis')
            except ValueError:
                axis_value = 0
        axis_arg.i = axis_value

    def convert_squeeze(self, tf_op):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Squeeze.name
@@ -783,6 +812,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
        arg.name = 'slice'
        arg.i = 1

    def convert_reverse(self, tf_op):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Reverse.name

    def convert_stack(self, tf_op):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Stack.name
...
@@ -124,7 +124,7 @@ class MemoryOptimizer(object):
    @staticmethod
    def is_memory_reuse_op(op):
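        # These ops only change shape metadata, so their output can share the
        # input tensor's buffer (ExpandDims calls ReuseTensorBuffer in its CPU
        # kernel).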
        return op.type == 'Reshape' or op.type == 'Identity' \
            or op.type == 'Squeeze' or op.type == 'ExpandDims'

    def optimize(self):
        for op in self.net_def.op:
...