提交 34c487cd 编写于 作者: 李寅

Support reshape with shape as a variable tensor

上级 6a154ca8
......@@ -237,11 +237,6 @@ MaceStatus MaceEngine::Impl::Run(
<< "' is not belong to model's outputs: "
<< MakeString(MapKeys(output_info_map_));
}
if (device_type_ == DeviceType::GPU) {
MACE_CHECK(output.second.shape().size() == 4,
"The outputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions");
}
Tensor *output_tensor =
ws_->GetTensor(MakeString("mace_output_node_", output.first));
output_tensors.push_back(output_tensor);
......
......@@ -90,6 +90,7 @@ extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Gather(OperatorRegistry *op_registry);
extern void Register_Identity(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
......@@ -107,6 +108,7 @@ extern void Register_Stack(OperatorRegistry *op_registry);
extern void Register_StridedSlice(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_Squeeze(OperatorRegistry *op_registry);
extern void Register_Transpose(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
......@@ -135,6 +137,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
......@@ -152,6 +155,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Squeeze(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
......
......@@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include <unordered_set>
#include <utility>
#include "mace/core/arg_helper.h"
......@@ -22,6 +23,15 @@
namespace mace {
namespace {
// Returns true unless the op's type is one of "Reshape", "Identity" or
// "Squeeze". Those types are listed as buffer-reusing ops, so no output
// memory is preallocated for them.
bool ShouldPreallocateMemoryForOp(const OperatorDef &op) {
  static const std::unordered_set<std::string> kBufferReusingOps{
      "Reshape", "Identity", "Squeeze"};
  const bool reuses_input_buffer = kBufferReusingOps.count(op.type()) != 0;
  return !reuses_input_buffer;
}
} // namespace
Workspace::Workspace() : host_scratch_buffer_(new ScratchBuffer(
GetDeviceAllocator(DeviceType::CPU))) {}
......@@ -177,7 +187,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()
&& op.type() != "Reshape") {
&& ShouldPreallocateMemoryForOp(op)) {
auto mem_ids = op.mem_id();
int count = mem_ids.size();
for (int i = 0; i < count; ++i) {
......
......@@ -136,7 +136,19 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else if (type == ARGUMENT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
} else {
} else if (type == IN_OUT_CHANNEL) {
if (buffer->dim_size() == 4) { // NHWC
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else if (buffer->dim_size() == 2) { // NC
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
} else {
MACE_NOT_IMPLEMENTED;
}
} else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
if (buffer->dim_size() < 4) {
......@@ -144,6 +156,10 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
b2f_kernel.setArg(idx++, *(image->opencl_image()));
......
......@@ -28,10 +28,15 @@ namespace {
// Computes the 2-D image shape used for an input/output tensor.
//   rank 4 (NHWC): [(C + 3) / 4 * W, N * H]
//   rank 2 (NC):   [(C + 3) / 4, N]
// Any other rank fails the check below. The rank is validated before any
// element of |shape| is read, so a rank-2 shape is never indexed at [2]/[3].
void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
                           std::vector<size_t> *image_shape) {
  MACE_CHECK(shape.size() == 4 || shape.size() == 2);
  image_shape->resize(2);
  if (shape.size() == 4) {
    (*image_shape)[0] = RoundUpDiv4(shape[3]) * shape[2];
    (*image_shape)[1] = shape[0] * shape[1];
  } else if (shape.size() == 2) {
    (*image_shape)[0] = RoundUpDiv4(shape[1]);
    (*image_shape)[1] = shape[0];
  }
}
// [Ic, H * W * (Oc + 3) / 4]
......
......@@ -123,7 +123,19 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else {
} else if (type == IN_OUT_CHANNEL) {
if (buffer->dim_size() == 4) { // NHWC
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else if (buffer->dim_size() == 2) { // NC
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
} else {
MACE_NOT_IMPLEMENTED;
}
} else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
if (buffer->dim_size() < 4) {
......@@ -131,6 +143,10 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
}
b2f_kernel.setArg(idx++, *(image->opencl_image()));
......
......@@ -45,10 +45,26 @@ template <typename T>
MaceStatus SoftmaxFunctor<DeviceType::GPU, T>::operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future) {
const index_t batch = logits->dim(0);
const index_t height = logits->dim(1);
const index_t width = logits->dim(2);
const index_t channels = logits->dim(3);
index_t batch = 0;
index_t height = 0;
index_t width = 0;
index_t channels = 0;
if (logits->dim_size() == 2) {
batch = logits->dim(0);
height = 1;
width = 1;
channels = logits->dim(1);
} else if (logits->dim_size() == 4) {
batch = logits->dim(0);
height = logits->dim(1);
width = logits->dim(2);
channels = logits->dim(3);
} else {
MACE_NOT_IMPLEMENTED;
}
const index_t channel_blocks = RoundUpDiv4(channels);
const int remain_channels = channel_blocks * 4 - channels;
......@@ -103,8 +119,7 @@ MaceStatus SoftmaxFunctor<DeviceType::GPU, T>::operator()(const Tensor *logits,
std::vector<uint32_t> lws = LocalWS(gws, kwg_size_);
std::string tuning_key =
Concat("softmax_opencl_kernel", output->dim(0), output->dim(1),
output->dim(2), output->dim(3));
Concat("softmax_opencl_kernel", batch, height, width, channels);
TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/identity.h"
namespace mace {
namespace ops {
// Registers the "Identity" operator with |op_registry|.
// A CPU/float variant is always registered. When MACE_ENABLE_OPENCL is
// defined, GPU variants for float and half are registered as well.
void Register_Identity(OperatorRegistry *op_registry) {
// CPU, float.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
IdentityOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
// GPU, float.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
IdentityOp<DeviceType::GPU, float>);
// GPU, half.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
IdentityOp<DeviceType::GPU, half>);
#endif
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_IDENTITY_H_
#define MACE_OPS_IDENTITY_H_
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
// Pass-through operator: OUTPUT is bound to INPUT's buffer through
// Tensor::ReuseTensorBuffer instead of being produced by a kernel.
// NOTE(review): presumably ReuseTensorBuffer aliases the input storage
// without copying and carries the shape along -- confirm in Tensor.
template <DeviceType D, typename T>
class IdentityOp : public Operator<D, T> {
 public:
  IdentityOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws) {}

  // Binds OUTPUT to INPUT's buffer and returns MACE_SUCCESS.
  // |future| is not used by this op.
  MaceStatus Run(StatsFuture *future) override {
    // Silences the unused-parameter warning, as SqueezeOp::Run does.
    MACE_UNUSED(future);
    const Tensor *input = this->Input(INPUT);
    Tensor *output = this->Output(OUTPUT);
    output->ReuseTensorBuffer(*input);
    return MACE_SUCCESS;
  }

 private:
  MACE_OP_INPUT_TAGS(INPUT);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_IDENTITY_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class IdentityOpTest : public OpsTestBase {};
namespace {
// Runs a single "Identity" op on a random CPU float tensor of |shape| and
// verifies that the output has the same shape and the same element values.
void TestIdentity(const std::vector<index_t> &shape) {
  // Build a net that contains only the op under test.
  OpsTestNet net;
  OpDefBuilder("Identity", "IdentityTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Feed random data and execute.
  net.AddRandomInput<DeviceType::CPU, float>("Input", shape);
  net.RunOp();

  auto in_tensor = net.GetTensor("Input");
  auto out_tensor = net.GetTensor("Output");

  // The shape must be carried through unchanged.
  EXPECT_THAT(out_tensor->shape(), ::testing::ContainerEq(shape));

  // Every element must match the corresponding input element.
  const float *in_data = in_tensor->data<float>();
  const float *out_data = out_tensor->data<float>();
  const int element_count = out_tensor->size();
  for (int idx = 0; idx < element_count; ++idx) {
    ASSERT_EQ(in_data[idx], out_data[idx]);
  }
}
} // namespace
// Exercises Identity with a rank-4 shape, a rank-1 shape and an empty
// shape vector.
TEST_F(IdentityOpTest, TestIdentity) {
TestIdentity({1, 2, 3, 4});
TestIdentity({1});
TestIdentity({});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -23,6 +23,20 @@ void Register_Reshape(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T")
.Build(),
ReshapeOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
ReshapeOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
ReshapeOp<DeviceType::GPU, half>);
#endif
}
} // namespace ops
......
......@@ -27,26 +27,29 @@ template <DeviceType D, typename T>
class ReshapeOp : public Operator<D, T> {
public:
ReshapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgs<int64_t>("shape")) {}
: Operator<D, T>(op_def, ws) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
const Tensor *shape = this->Input(SHAPE);
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
Tensor::MappingGuard shape_guard(shape);
const int32_t *shape_data = shape->data<int32_t>();
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
for (int i = 0; i < num_dims; ++i) {
if (shape_[i] == -1) {
if (shape_data[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1";
unknown_idx = i;
out_shape.push_back(1);
} else {
MACE_CHECK(shape_[i] >= 0) << "Shape must be non-negative: "
<< shape_[i];
out_shape.push_back(shape_[i]);
product *= shape_[i];
MACE_CHECK(shape_data[i] >= 0) << "Shape must be non-negative: "
<< shape_data[i];
out_shape.push_back(shape_data[i]);
product *= shape_data[i];
}
}
......@@ -65,11 +68,10 @@ class ReshapeOp : public Operator<D, T> {
}
private:
std::vector<int64_t> shape_;
kernels::ReshapeFunctor<D, T> functor_;
private:
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_INPUT_TAGS(INPUT, SHAPE);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
......
......@@ -30,12 +30,15 @@ void TestReshape(const std::vector<index_t> &org_shape,
OpsTestNet net;
OpDefBuilder("Reshape", "ReshapeTest")
.Input("Input")
.Input("Shape")
.Output("Output")
.AddIntsArg("shape", output_shape)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", org_shape);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Shape",
{output_shape.size()},
output_shape);
// Run
net.RunOp();
......
......@@ -93,17 +93,25 @@ void Complex(const std::vector<index_t> &logits_shape) {
// Add input data
net.AddRandomInput<D, float>("Input", logits_shape);
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Run on cpu
net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
}
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -142,6 +150,11 @@ TEST_F(SoftmaxOpTest, OPENCLUnAligned) {
Complex<DeviceType::GPU>({5, 211, 107, 1});
}
// Runs the GPU softmax comparison on rank-2 logits, with the first
// dimension equal to 1 and to 3.
// NOTE(review): despite "Aligned" in the test name, the last dimension
// (1001) is not a multiple of 4 -- confirm the intended name.
TEST_F(SoftmaxOpTest, OPENCLAlignedRank2) {
Complex<DeviceType::GPU>({1, 1001});
Complex<DeviceType::GPU>({3, 1001});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/squeeze.h"
namespace mace {
namespace ops {
// Registers the "Squeeze" operator with |op_registry|.
// A CPU/float variant is always registered. When MACE_ENABLE_OPENCL is
// defined, GPU variants for float and half are registered as well.
void Register_Squeeze(OperatorRegistry *op_registry) {
// CPU, float.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
SqueezeOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
// GPU, float.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
SqueezeOp<DeviceType::GPU, float>);
// GPU, half.
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
SqueezeOp<DeviceType::GPU, half>);
#endif
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_SQUEEZE_H_
#define MACE_OPS_SQUEEZE_H_
#include <vector>
#include <unordered_set>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
// Removes size-1 dimensions from the input shape.
//
// The optional repeated int argument "axis" selects which dimensions may be
// removed; when it is empty, every size-1 dimension is removed. Negative
// axis values count from the end (axis + rank). A selected dimension whose
// size is not exactly 1 is kept, so the element count never changes.
// OUTPUT is bound to INPUT's buffer through Tensor::ReuseTensorBuffer and
// then reshaped to the squeezed shape.
template<DeviceType D, typename T>
class SqueezeOp : public Operator<D, T> {
 public:
  SqueezeOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws),
        axis_(OperatorBase::GetRepeatedArgs<int>("axis", {})) {}

  // Computes the squeezed shape and applies it to OUTPUT.
  // |future| is not used by this op. Returns MACE_SUCCESS.
  MaceStatus Run(StatsFuture *future) override {
    MACE_UNUSED(future);

    const Tensor *input = this->Input(INPUT);
    Tensor *output = this->Output(OUTPUT);

    const int rank = static_cast<int>(input->dim_size());

    // Normalize negative axes so they can be compared with dim indices.
    std::unordered_set<int> axis_set;
    for (const int axis : axis_) {
      axis_set.insert(axis < 0 ? axis + rank : axis);
    }

    std::vector<index_t> output_shape;
    for (int i = 0; i < rank; ++i) {
      // Only dimensions of size exactly 1 can be dropped; in particular a
      // 0-sized dimension must be kept or the element count would change.
      const bool is_unit_dim = input->dim(i) == 1;
      const bool is_selected = axis_set.empty() || axis_set.count(i) != 0;
      if (!(is_unit_dim && is_selected)) {
        output_shape.push_back(input->dim(i));
      }
    }

    output->ReuseTensorBuffer(*input);
    output->Reshape(output_shape);

    return MACE_SUCCESS;
  }

 private:
  std::vector<int> axis_;

  MACE_OP_INPUT_TAGS(INPUT);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_SQUEEZE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SqueezeOpTest : public OpsTestBase {};
namespace {
// Runs a single "Squeeze" op with the given |axis| argument on a random CPU
// float tensor of |org_shape|, then verifies that the output shape equals
// |res_shape| and that the element values are unchanged.
void TestSqueeze(const std::vector<index_t> &org_shape,
                 const std::vector<int> &axis,
                 const std::vector<index_t> &res_shape) {
  // Build a net that contains only the op under test.
  OpsTestNet net;
  OpDefBuilder("Squeeze", "SqueezeTest")
      .Input("Input")
      .AddIntsArg("axis", axis)
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Feed random data and execute.
  net.AddRandomInput<DeviceType::CPU, float>("Input", org_shape);
  net.RunOp();

  auto in_tensor = net.GetTensor("Input");
  auto out_tensor = net.GetTensor("Output");

  // The resulting shape must be the expected squeezed shape.
  EXPECT_THAT(out_tensor->shape(), ::testing::ContainerEq(res_shape));

  // Squeezing must not alter any element value.
  const float *in_data = in_tensor->data<float>();
  const float *out_data = out_tensor->data<float>();
  const int element_count = out_tensor->size();
  for (int idx = 0; idx < element_count; ++idx) {
    ASSERT_EQ(in_data[idx], out_data[idx]);
  }
}
} // namespace
// Cases, in order:
//  - no axis: every size-1 dimension is removed;
//  - axis {1}: dimension 1 has size 2, so the shape is unchanged;
//  - axis {2}: only dimension 2 is removed, dimension 0 stays;
//  - shape {1} with no axis: the result is an empty shape.
TEST_F(SqueezeOpTest, TestSqueeze) {
TestSqueeze({1, 2, 1, 4}, {}, {2, 4});
TestSqueeze({1, 2, 1, 4}, {1}, {1, 2, 1, 4});
TestSqueeze({1, 2, 1, 4}, {2}, {1, 2, 4});
TestSqueeze({1}, {}, {});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -84,6 +84,8 @@ MaceSupportedOps = [
'Eltwise',
'FoldedBatchNorm',
'FullyConnected',
'Gather',
'Identity',
'LocalResponseNorm',
'MatMul',
'Pad',
......@@ -95,6 +97,10 @@ MaceSupportedOps = [
'Reshape',
'ResizeBilinear',
'Slice',
'Shape',
'Squeeze',
'Stack',
'StridedSlice',
'Softmax',
'SpaceToBatchND',
'SpaceToDepth',
......@@ -144,7 +150,6 @@ class MaceKeyword(object):
class TransformerRule(Enum):
REMOVE_USELESS_RESHAPE_OP = 0
REMOVE_IDENTITY_OP = 1
TRANSFORM_GLOBAL_POOLING = 2
FOLD_RESHAPE = 3
......@@ -212,7 +217,6 @@ class ConverterOption(object):
self._winograd_enabled = False
self._transformer_option = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
......
......@@ -41,6 +41,8 @@ tf_kernel_str = 'ksize'
tf_epsilon_str = 'epsilon'
tf_align_corners = 'align_corners'
tf_block_size = 'block_size'
tf_squeeze_dims = 'squeeze_dims'
tf_axis = 'axis'
TFSupportedOps = [
'Conv2D',
......@@ -149,7 +151,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.MatMul.name: self.convert_matmul,
TFOpType.Identity.name: self.convert_identity,
TFOpType.Reshape.name: self.convert_reshape,
TFOpType.Shape.name: self.convert_nop,
TFOpType.Shape.name: self.convert_shape,
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Transpose.name: self.convert_transpose,
TFOpType.Softmax.name: self.convert_softmax,
TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
......@@ -257,6 +260,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
tensor.data_type = data_type
tensor.float_data.extend(value.flat)
@staticmethod
def infer_tensor_shape(tensor):
    """Return the shape of `tensor` as a list of ints.

    Some dimensions may be undefined because the input length can vary;
    such dimensions (reported as None by `tensor.shape.as_list()`) are
    replaced with -1. A known dimension of size 0 is returned as 0.
    """
    # Compare against None explicitly: a plain truthiness test would also
    # rewrite a legitimate 0-sized dimension to -1.
    return [-1 if dim is None else dim for dim in tensor.shape.as_list()]
def convert_nop(self, tf_op):
    """Do nothing for `tf_op`; no op is produced and None is returned."""
    return None
......@@ -268,7 +281,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.output.extend([tf_output.name for tf_output in tf_op.outputs])
for tf_output in tf_op.outputs:
output_shape = op.output_shape.add()
output_shape.dims.extend(tf_output.shape.as_list())
output_shape.dims.extend(self.infer_tensor_shape(tf_output))
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
......@@ -481,24 +494,29 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
op.type = MaceOp.MatMul.name
def convert_shape(self, tf_op):
    """Convert `tf_op` into a MACE Shape op with a DT_INT32 output type."""
    shape_op = self.convert_general_op(tf_op)
    shape_op.type = MaceOp.Shape.name
    int32_output_types = [mace_pb2.DT_INT32]
    shape_op.output_type.extend(int32_output_types)
def convert_reshape(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Reshape.name
del op.input[1:]
shape_arg = op.arg.add()
shape_arg.name = MaceKeyword.mace_shape_str
shape_value = []
if tf_op.inputs[1].op.type == TFOpType.Const.name:
shape_value = list(tf_op.inputs[1].eval().astype(np.int32))
for i in xrange(len(shape_value)):
if shape_value[i] == -1:
shape_value[i] = 1
self._skip_tensor.add(tf_op.inputs[-1].name)
elif tf_op.inputs[1].op.type == TFOpType.Shape.name:
shape_value = list(tf_op.inputs[1].op.inputs[0].shape.as_list())
shape_arg.ints.extend(shape_value)
def convert_squeeze(self, tf_op):
    """Convert `tf_op` into a MACE Squeeze op carrying an axis argument.

    The axes are read from the op attribute named by `tf_squeeze_dims`,
    falling back to the one named by `tf_axis`. When neither attribute
    can be read (get_attr raises ValueError), the axis list is left empty.
    """
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.Squeeze.name

    axis_arg = op.arg.add()
    axis_arg.name = MaceKeyword.mace_axis_str

    axis_value = []
    # Try the attribute names in priority order; the first one found wins.
    for attr_name in (tf_squeeze_dims, tf_axis):
        try:
            axis_value = tf_op.get_attr(attr_name)
        except ValueError:
            continue
        break

    axis_arg.ints.extend(axis_value)
def convert_transpose(self, tf_op):
perm = tf_op.inputs[1].eval().astype(np.int32)
......
......@@ -56,7 +56,6 @@ class Transformer(base_converter.ConverterInterface):
# DO NOT reorder the following transformers' order
self._registered_transformers_order = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
......@@ -81,8 +80,6 @@ class Transformer(base_converter.ConverterInterface):
self._registered_transformers = {
TransformerRule.ADD_IN_OUT_TENSOR_INFO:
self.add_in_out_tensor_info,
TransformerRule.REMOVE_USELESS_RESHAPE_OP:
self.remove_useless_reshape_op,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
......@@ -289,19 +286,6 @@ class Transformer(base_converter.ConverterInterface):
return False
def remove_useless_reshape_op(self):
net = self._model
for op in net.op:
if op.type == MaceOp.Reshape.name:
shape = list(ConverterUtil.get_arg(
op, MaceKeyword.mace_shape_str).ints)
if shape == self.get_tensor_shape(op.input[0]):
print("Remove useless reshape: %s(%s)"
% (op.name, op.type))
op.type = 'Identity'
return False
def remove_identity_op(self):
net = self._model
for op in net.op:
......@@ -791,6 +775,26 @@ class Transformer(base_converter.ConverterInterface):
"channel dimension")
arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if ConverterUtil.data_format(
op) == DataFormat.NHWC \
and self._target_data_format == DataFormat.NCHW: # noqa
print("Transpose squeeze args: %s(%s)"
% (op.name, op.type))
mace_check(list(arg.ints) == [1, 2],
'only support squeeze at at [1, 2]')
arg.ints[:] = [2, 3]
elif ConverterUtil.data_format(
op) == DataFormat.NCHW \
and self._target_data_format == DataFormat.NHWC: # noqa
print("Transpose squeeze args: %s(%s)"
% (op.name, op.type))
mace_check(list(arg.ints) == [2, 3],
'only support squeeze at at [2, 3]')
arg.ints[:] = [1, 2]
# transpose op output shape
data_format = ConverterUtil.data_format(op)
if data_format is not None \
......@@ -818,16 +822,19 @@ class Transformer(base_converter.ConverterInterface):
+ '_' + input_node.name
op = net.op.add()
op.name = self.normalize_op_name(input_node.name)
op.type = MaceOp.Transpose.name
op.input.extend([new_input_name])
op.output.extend([input_node.name])
output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape)
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
if len(output_shape.dims) == 4:
op.type = MaceOp.Transpose.name
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2])
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2])
else:
op.type = MaceOp.Identity.name
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
......@@ -836,19 +843,22 @@ class Transformer(base_converter.ConverterInterface):
+ '_' + output_node.name
op = self._model.op.add()
op.name = self.normalize_op_name(output_name)
op.type = MaceOp.Transpose.name
op.input.extend([output_node.name])
op.output.extend([output_name])
output_shape = op.output_shape.add()
output_shape.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
self.transpose_shape(output_shape.dims, [0, 2, 3, 1])
if len(output_shape.dims) == 4:
op.type = MaceOp.Transpose.name
self.transpose_shape(output_shape.dims, [0, 2, 3, 1])
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1])
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1])
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
else:
op.type = MaceOp.Identity.name
return False
......@@ -1003,36 +1013,51 @@ class Transformer(base_converter.ConverterInterface):
return False
def fold_reshape(self):
changed = False
net = self._model
for op in net.op:
if op.type == MaceOp.Softmax.name or op.type == MaceOp.MatMul.name:
print("Fold reshape: %s(%s)" % (op.name, op.type))
if self.consumer_count(op.output[0]) == 1:
if op.type == MaceOp.Softmax.name:
# see if possible to fold
# Reshape(xd->2d) + Softmax(2d) + Reshape(xd) to Softmax(xd)
should_fold = False
if op.input[0] in self._producer \
and self._producer[op.input[0]].type \
== MaceOp.Reshape.name \
and len(op.output_shape[0].dims) == 2 \
and self.consumer_count(op.output[0]) == 1:
producer = self._producer[op.input[0]]
consumer = self._consumers[op.output[0]][0]
if consumer.type == MaceOp.Reshape.name:
shape = ConverterUtil.get_arg(consumer,
MaceKeyword.mace_shape_str).ints # noqa
del op.output_shape[0].dims[:]
op.output_shape[0].dims.extend(shape)
self.safe_remove_node(consumer, op)
changed = True
if (consumer.type == MaceOp.Reshape.name
and op.output_shape[0].dims[-1]
== consumer.output_shape[0].dims[-1]
and op.output_shape[0].dims[-1] != -1
and self.get_tensor_shape(producer.input[0])
== consumer.output_shape[0].dims):
should_fold = True
if should_fold:
print(
"Fold reshape and softmax: %s(%s)"
% (op.name, op.type))
producer = self._producer[op.input[0]]
if producer.type == MaceOp.Reshape.name:
self.safe_remove_node(producer,
self._producer[
producer.input[0]])
changed = True
if len(op.output_shape[0].dims) < 4:
shape = ([1, 1, 1, 1] + list(op.output_shape[0].dims))[-4:]
op.output_shape[0].dims[:] = shape[:]
changed = True
if changed:
return True
consumer = self._consumers[op.output[0]][0]
op.output_shape[0].dims[:] = self.get_tensor_shape(
producer.input[0])
# if there is a shape op, remove it too
if (consumer.input[1] in self._producer
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None)
# remove consumer reshape
self.safe_remove_node(consumer, op)
# remove producer reshape
self.safe_remove_node(producer,
self._producer.get(producer.input[0],
None))
return True
return False
def transform_matmul_to_fc(self):
......
......@@ -83,6 +83,11 @@ class MemoryOptimizer(object):
optimized_mem_size += self.mem_size(self.mem_block[mem])
return optimized_mem_size
@staticmethod
def is_memory_reuse_op(op):
    """Return True when `op.type` is 'Reshape', 'Identity' or 'Squeeze'."""
    return op.type in ('Reshape', 'Identity', 'Squeeze')
def optimize(self):
for op in self.net_def.op:
if not self.op_need_optimize_memory(op):
......@@ -96,51 +101,59 @@ class MemoryOptimizer(object):
'the number of output.')
return
for i in range(len(op.output)):
op_mem_block = self.get_op_mem_block(op.type,
op.output_shape[i].dims)
mem_id = -1
if len(self.idle_mem) > 0:
best_mem_add_size = sys.maxint
best_mem_waste_size = sys.maxint
for mid in self.idle_mem:
old_mem_block = self.mem_block[mid]
new_mem_block = self.resize_mem_block(
old_mem_block, op_mem_block)
add_mem_size = self.sub_mem_block(new_mem_block,
old_mem_block)
waste_mem_size = self.sub_mem_block(new_mem_block,
op_mem_block)
# minimize add_mem_size; if best_mem_add_size is 0,
# then minimize waste_mem_size
if (best_mem_add_size > 0 and
add_mem_size < best_mem_add_size) \
or (best_mem_add_size == 0 and
waste_mem_size < best_mem_waste_size):
best_mem_id = mid
best_mem_add_size = add_mem_size
best_mem_waste_size = waste_mem_size
best_mem_block = new_mem_block
# if add mem size < op mem size, then reuse it
if best_mem_add_size <= self.mem_size(op_mem_block):
self.mem_block[best_mem_id] = best_mem_block
mem_id = best_mem_id
self.idle_mem.remove(mem_id)
if mem_id == -1:
mem_id = self.mem_id_base() + self.total_mem_count
self.total_mem_count += 1
self.mem_block[mem_id] = op_mem_block
op.mem_id.extend([mem_id])
self.op_mem[op.output[i]] = mem_id
if self.is_memory_reuse_op(op):
# make these ops reuse memory of input tensor
mem_id = self.op_mem.get(op.input[0], -1)
else:
op_mem_block = self.get_op_mem_block(
op.type,
op.output_shape[i].dims)
mem_id = -1
if len(self.idle_mem) > 0:
best_mem_add_size = sys.maxint
best_mem_waste_size = sys.maxint
for mid in self.idle_mem:
old_mem_block = self.mem_block[mid]
new_mem_block = self.resize_mem_block(
old_mem_block, op_mem_block)
add_mem_size = self.sub_mem_block(new_mem_block,
old_mem_block)
waste_mem_size = self.sub_mem_block(new_mem_block,
op_mem_block)
# minimize add_mem_size; if best_mem_add_size is 0,
# then minimize waste_mem_size
if (best_mem_add_size > 0 and
add_mem_size < best_mem_add_size) \
or (best_mem_add_size == 0 and
waste_mem_size < best_mem_waste_size):
best_mem_id = mid
best_mem_add_size = add_mem_size
best_mem_waste_size = waste_mem_size
best_mem_block = new_mem_block
# if add mem size < op mem size, then reuse it
if best_mem_add_size <= self.mem_size(op_mem_block):
self.mem_block[best_mem_id] = best_mem_block
mem_id = best_mem_id
self.idle_mem.remove(mem_id)
if mem_id == -1:
mem_id = self.mem_id_base() + self.total_mem_count
self.total_mem_count += 1
self.mem_block[mem_id] = op_mem_block
if mem_id != -1:
op.mem_id.extend([mem_id])
self.op_mem[op.output[i]] = mem_id
# de-ref input tensor mem
for ipt in op.input:
for idx in xrange(len(op.input)):
ipt = op.input[idx]
if ipt in self.ref_counter:
self.ref_counter[ipt] -= 1
if self.ref_counter[ipt] == 0:
if self.ref_counter[ipt] == 0 and \
(idx > 0 or not self.is_memory_reuse_op(op)):
self.idle_mem.add(self.op_mem[ipt])
elif self.ref_counter[ipt] < 0:
raise Exception('ref count is less than 0')
......@@ -170,8 +183,10 @@ class GPUMemoryOptimizer(MemoryOptimizer):
mem_block[0] = output_shape[2]
mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
else:
mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
mem_block[1] = output_shape[0] * output_shape[1]
padded_output_shape = ([1, 1, 1, 1] + list(output_shape))[-4:]
mem_block[0] = padded_output_shape[2] * int(
(padded_output_shape[3] + 3) / 4)
mem_block[1] = padded_output_shape[0] * padded_output_shape[1]
return mem_block
def mem_size(self, memory_block):
......
......@@ -482,6 +482,7 @@ def print_configuration(flags, configs):
def download_model_files(model_file_path,
model_output_dir,
weight_file_path=""):
MaceLogger.info("Downloading model, please wait ...")
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb"
......@@ -491,6 +492,7 @@ def download_model_files(model_file_path,
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
MaceLogger.info("Model downloaded successfully.")
def get_model_files_path(model_file_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册