提交 f3de19de 编写于 作者: Y yejianwu

support bidirection lstm for cpu, fix memory overflow when inferring output shape

上级 ac79ea41
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_EXPAND_DIMS_H_
#define MACE_KERNELS_EXPAND_DIMS_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
struct ExpandDimsBase {
explicit ExpandDimsBase(int axis) : axis_(axis) {}
int axis_;
};
template <DeviceType D, typename T>
struct ExpandDimsFunctor;
template <typename T>
struct ExpandDimsFunctor<DeviceType::CPU, T> : ExpandDimsBase {
explicit ExpandDimsFunctor(int axis) : ExpandDimsBase(axis) {}
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
index_t input_dims_size = input->dim_size();
if ( axis_ < 0 ) {
axis_ += input_dims_size + 1;
}
MACE_CHECK(axis_ >= 0 && axis_ <= input_dims_size,
"axis is out of bound: ", axis_);
const std::vector<index_t> input_shape = input->shape();
std::vector<index_t> output_shape;
output_shape.insert(output_shape.end(), input_shape.begin(),
input_shape.begin() + axis_);
output_shape.insert(output_shape.end(), 1);
output_shape.insert(output_shape.end(), input_shape.begin() + axis_,
input_shape.end());
output->ReuseTensorBuffer(*input);
output->Reshape(output_shape);
return MACE_SUCCESS;
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_EXPAND_DIMS_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_REVERSE_H_
#define MACE_KERNELS_REVERSE_H_
#include <functional>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct ReverseFunctor;
template <typename T>
struct ReverseFunctor<DeviceType::CPU, T> {
MaceStatus operator()(const Tensor *input,
const Tensor *axis,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(axis->dim_size() == 1, "Only support reverse in one axis now");
const int32_t *axis_data = axis->data<int32_t>();
const index_t reverse_dim = *axis_data >= 0 ?
*axis_data : *axis_data + input->dim_size();
MACE_CHECK(reverse_dim >= 0 && reverse_dim < input->dim_size(),
"axis must be in the range [-rank(input), rank(input))");
const std::vector<index_t> input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
index_t high_dim_elem_size =
std::accumulate(input_shape.begin(), input_shape.begin() + reverse_dim,
1, std::multiplies<index_t>());
index_t low_dim_elem_size =
std::accumulate(input_shape.begin() + reverse_dim + 1,
input_shape.end(), 1, std::multiplies<index_t>());
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
index_t reverse_size = input_shape[reverse_dim] * low_dim_elem_size;
for (index_t h = 0; h < high_dim_elem_size; ++h) {
int input_idx = h * reverse_size;
int output_idx = input_idx + reverse_size;
for (index_t i = 0; i < input_shape[reverse_dim]; ++i) {
output_idx -= low_dim_elem_size;
memcpy(output_data + output_idx, input_data + input_idx,
sizeof(T) * low_dim_elem_size);
input_idx += low_dim_elem_size;
}
}
SetFutureDefaultWaitFn(future);
return MACE_SUCCESS;
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_REVERSE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/expand_dims.h"
namespace mace {
namespace ops {
void Register_ExpandDims(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, int32_t>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_EXPAND_DIMS_H_
#define MACE_OPS_EXPAND_DIMS_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/expand_dims.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ExpandDimsOp : public Operator<D, T> {
public:
ExpandDimsOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("axis", 0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
return functor_(input, output, future);
}
private:
kernels::ExpandDimsFunctor<D, T> functor_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_EXPAND_DIMS_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ExpandDimsTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestExpandDims(const std::vector<index_t> &input_shape,
const int &axis,
const std::vector<index_t> &output_shape) {
// Construct graph
OpsTestNet net;
OpDefBuilder("ExpandDims", "ExpandDimsTest")
.Input("Input")
.AddIntArg("axis", static_cast<int>(axis))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", input_shape);
// Run
net.RunOp();
auto input = net.GetTensor("Input");
auto output = net.GetTensor("Output");
EXPECT_THAT(output->shape(), ::testing::ContainerEq(output_shape));
const T *input_ptr = input->data<T>();
const T *output_ptr = output->data<T>();
const int size = output->size();
for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]);
}
}
} // namespace
TEST_F(ExpandDimsTest, SimpleCPU) {
TestExpandDims<DeviceType::CPU, float>({3, 2, 1}, 1, {3, 1, 2, 1});
TestExpandDims<DeviceType::CPU, float>({1, 2, 3}, -1, {1, 2, 3, 1});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_ExpandDims(OperatorRegistryBase *op_registry);
extern void Register_Fill(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
......@@ -50,6 +51,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBicubic(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Reverse(OperatorRegistryBase *op_registry);
extern void Register_ScalarMath(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Split(OperatorRegistryBase *op_registry);
......@@ -90,6 +92,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_ExpandDims(this);
ops::Register_Fill(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
......@@ -106,6 +109,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Reshape(this);
ops::Register_ResizeBicubic(this);
ops::Register_ResizeBilinear(this);
ops::Register_Reverse(this);
ops::Register_ScalarMath(this);
ops::Register_Shape(this);
ops::Register_Split(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/reverse.h"
namespace mace {
namespace ops {
void Register_Reverse(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reverse")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReverseOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REVERSE_H_
#define MACE_OPS_REVERSE_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/reverse.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class ReverseOp : public Operator<D, T> {
public:
ReverseOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *axis = this->Input(AXIS);
Tensor *output = this->Output(OUTPUT);
return functor_(input, axis, output, future);
}
private:
kernels::ReverseFunctor<D, T> functor_;
protected:
MACE_OP_INPUT_TAGS(INPUT, AXIS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REVERSE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReverseOpTest : public OpsTestBase {};
namespace {
void TestReverse(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<index_t> &axis_shape,
const std::vector<int32_t> &axis,
const std::vector<float> &outputs) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("Axis", axis_shape, axis);
OpDefBuilder("Reverse", "ReverseOpTest")
.Input("Input")
.Input("Axis")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", input_shape,
outputs);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(ReverseOpTest, SimpleCPU) {
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {0},
{6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5});
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {1},
{4, 5, 2, 3, 0, 1, 10, 11, 8, 9, 6, 7});
TestReverse({2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1}, {2},
{1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -87,6 +87,7 @@ MaceSupportedOps = [
'DepthwiseConv2d',
'Dequantize',
'Eltwise',
'ExpandDims',
'FoldedBatchNorm',
'Fill',
'FullyConnected',
......@@ -104,6 +105,7 @@ MaceSupportedOps = [
'Reshape',
'ResizeBicubic',
'ResizeBilinear',
'Reverse',
'ScalarMath',
'Slice',
'Split',
......
......@@ -73,6 +73,7 @@ TFSupportedOps = [
'FusedBatchNorm',
'AvgPool',
'MaxPool',
'ExpandDims',
'Squeeze',
'MatMul',
'BatchMatMul',
......@@ -95,6 +96,7 @@ TFSupportedOps = [
'Gather',
'StridedSlice',
'Slice',
'ReverseV2',
'Stack',
'Pack',
'Unstack',
......@@ -181,6 +183,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Identity.name: self.convert_identity,
TFOpType.Reshape.name: self.convert_reshape,
TFOpType.Shape.name: self.convert_shape,
TFOpType.ExpandDims.name: self.convert_expand_dims,
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Transpose.name: self.convert_transpose,
TFOpType.Softmax.name: self.convert_softmax,
......@@ -198,6 +201,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Gather.name: self.convert_gather,
TFOpType.StridedSlice.name: self.convert_stridedslice,
TFOpType.Slice.name: self.convert_slice,
TFOpType.ReverseV2.name: self.convert_reverse,
TFOpType.Pack.name: self.convert_stack,
TFOpType.Stack.name: self.convert_stack,
TFOpType.Unpack.name: self.convert_unstack,
......@@ -225,9 +229,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._skip_tensor = set()
self._output_shape_list = []
self._output_shape_op_list = []
def run(self):
with tf.Session() as session:
self.convert_ops()
self.convert_ops(session)
self.replace_input_output_tensor_name()
return self._mace_net_def
......@@ -267,12 +274,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
else:
return tensor_name[:idx]
def convert_ops(self):
def update_output_shapes(self, sess):
output_shapes = sess.run(self._output_shape_op_list,
feed_dict=self._placeholders)
for i in range(len(self._output_shape_list)):
self._output_shape_list[i].dims.extend(output_shapes[i])
def convert_ops(self, sess):
for tf_op in self._tf_graph.get_operations():
mace_check(tf_op.type in self._op_converters,
"Mace does not support tensorflow op type %s yet"
% tf_op.type)
self._op_converters[tf_op.type](tf_op)
self.update_output_shapes(sess)
self.convert_tensors()
def convert_tensors(self):
......@@ -306,7 +320,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
# this function tries to infer tensor shape, but some dimension shape
# may be undefined due to variance of input length
def infer_tensor_shape(self, tensor):
def infer_tensor_shape(self, output_shape, tensor):
inferred_tensor_shape = tensor.shape.as_list()
inferred_success = True
for _, dim in enumerate(inferred_tensor_shape):
......@@ -314,10 +328,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
inferred_success = False
break
if inferred_success:
return inferred_tensor_shape
tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders)
return tensor_shape
output_shape.dims.extend(inferred_tensor_shape)
else:
self._output_shape_list.append(output_shape)
self._output_shape_op_list.append(tf.shape(tensor))
def convert_nop(self, tf_op):
pass
......@@ -330,7 +344,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.output.extend([tf_output.name for tf_output in tf_op.outputs])
for tf_output in tf_op.outputs:
output_shape = op.output_shape.add()
output_shape.dims.extend(self.infer_tensor_shape(tf_output))
self.infer_tensor_shape(output_shape, tf_output)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
......@@ -678,6 +692,21 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Reshape.name
def convert_expand_dims(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.ExpandDims.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
try:
axis_value = tf_op.get_attr('dim')
except ValueError:
try:
axis_value = tf_op.get_attr('axis')
except ValueError:
axis_value = 0
axis_arg.i = axis_value
def convert_squeeze(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Squeeze.name
......@@ -783,6 +812,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
arg.name = 'slice'
arg.i = 1
def convert_reverse(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Reverse.name
def convert_stack(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Stack.name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册