提交 021e1e25 编写于 作者: 叶剑武

support tf.tile

fix type error

fix type warning

add TileOp to MaceTransposableDataFormatOps and reduce memcpy in TileOp
上级 107b956a
......@@ -74,6 +74,7 @@ extern void RegisterStack(OpRegistryBase *op_registry);
extern void RegisterStridedSlice(OpRegistryBase *op_registry);
extern void RegisterSumGroup(OpRegistryBase *op_registry);
extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
extern void RegisterTile(OpRegistryBase *op_registry);
extern void RegisterTranspose(OpRegistryBase *op_registry);
extern void RegisterUnstack(OpRegistryBase *op_registry);
extern void RegisterUnsqueeze(OpRegistryBase *op_registry);
......@@ -148,6 +149,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterSqueeze(this);
ops::RegisterSumGroup(this);
ops::RegisterTargetRMSNorm(this);
ops::RegisterTile(this);
ops::RegisterTranspose(this);
ops::RegisterUnstack(this);
ops::RegisterUnsqueeze(this);
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include <vector>
#include "mace/core/operator.h"
#include "mace/utils/memory.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class TileOp : public Operation {
public:
explicit TileOp(OpConstructContext *context)
: Operation(context),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0)) {
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *multiples = this->Input(1);
const index_t input_dims = input->dim_size();
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard multiples_guard(multiples);
const T *input_data = input->data<T>();
const int32_t *multiples_data = multiples->data<int32_t>();
MACE_CHECK(multiples->dim_size() == 1, "multiples must be 1-dimensional. ",
multiples->dim_size());
MACE_CHECK(input_dims == multiples->size(),
"multiples length must be the same as the dim_size of input",
input_dims, " vs. ", multiples->size());
std::vector<int32_t> multiples_vec(multiples_data,
multiples_data + multiples->size());
if (has_data_format_ && input_dims) {
int32_t h = multiples_vec[1];
int32_t w = multiples_vec[2];
int32_t c = multiples_vec[3];
multiples_vec[1] = c;
multiples_vec[2] = h;
multiples_vec[3] = w;
}
Tensor *output = this->Output(0);
std::vector<index_t> output_shape = {};
for (index_t i = 0; i < input_dims; ++i) {
output_shape.push_back(input->dim(i) * multiples_vec[i]);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
const index_t output_byte_size = PadAlignSize(output->size() * sizeof(T));
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(output_byte_size);
Tensor fake_input(scratch->Scratch(output_byte_size),
DataTypeToEnum<T>::value);
T *fake_input_data = fake_input.mutable_data<T>();
std::memcpy(fake_input_data, input_data, input->size() * sizeof(T));
index_t inner_dim = 1;
index_t outer_dim = input->size();
index_t acc_multiples = 1;
const index_t total_multiples =
std::accumulate(multiples_vec.begin(), multiples_vec.end(), 1,
std::multiplies<index_t>());
for (int64_t i = input_dims - 1; ; --i) {
inner_dim *= input->dim(i);
outer_dim /= input->dim(i);
for (int64_t o = 0; o < outer_dim; ++o) {
for (int64_t m = 0; m < multiples_vec[i]; ++m) {
std::memcpy(output_data + (o * multiples_vec[i] + m) * inner_dim,
fake_input_data + o * inner_dim, inner_dim * sizeof(T));
}
}
acc_multiples *= multiples_vec[i];
if (acc_multiples == total_multiples) {
break;
}
std::memcpy(fake_input_data, output_data,
input->size() * acc_multiples * sizeof(T));
inner_dim *= multiples_vec[i];
}
return MaceStatus::MACE_SUCCESS;
}
private:
int has_data_format_;
};
void RegisterTile(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Tile", TileOp, DeviceType::CPU, float);
MACE_REGISTER_OP_CONDITION(
op_registry, OpConditionBuilder("Tile").SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU};
}
return {DeviceType::CPU};
}));
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/benchmark_utils/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void BMTileHelper(int iters, const std::vector<index_t> &input_shape) {
mace::testing::StopTiming();
// Construct graph
OpsTestNet net;
net.AddRandomInput<D, T>("Input", input_shape);
std::vector<int32_t> multiples = {};
for (size_t i = 0; i < input_shape.size(); ++i) {
multiples.push_back(2);
}
net.AddInputFromArray<D, int32_t>(
"Multiples", {static_cast<int64_t>(multiples.size())}, multiples);
OpDefBuilder("Tile", "TileBM")
.Input("Input")
.Input("Multiples")
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_TILE_MACRO(N, H, W, C, TYPE, DEVICE) \
static void MACE_BM_TILE_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMTileHelper<DEVICE, TYPE>(iters, {N, H, W, C}); \
} \
MACE_BENCHMARK(MACE_BM_TILE_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
#define MACE_BM_TILE(N, H, W, C) MACE_BM_TILE_MACRO(N, H, W, C, float, CPU);
MACE_BM_TILE(1, 32, 32, 5);
MACE_BM_TILE(1, 32, 32, 7);
MACE_BM_TILE(1, 32, 32, 3);
MACE_BM_TILE(1, 128, 128, 9);
MACE_BM_TILE(1, 128, 128, 7);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class TileOpTest : public OpsTestBase {};
namespace {
void TestTile(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &multiples,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>(
"Multiples", {static_cast<int32_t>(multiples.size())}, multiples);
OpDefBuilder("Tile", "TileOpTest")
.Input("Input")
.Input("Multiples")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
void TestTileWithDataFormat(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &multiples,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>(
"Multiples", {static_cast<int32_t>(multiples.size())}, multiples);
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Tile", "TileOpTest")
.Input("InputNCHW")
.Input("Multiples")
.Output("OutputNCHW")
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TileOpTest, SimpleTest) {
TestTile({2, 3}, {0, 1, 2, 3, 4, 5}, {2, 3}, {4, 9},
{0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5,
0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5});
TestTile({2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 1, 2},
{2, 2, 6}, {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5,
6, 7, 8, 6, 7, 8, 9, 10, 11, 9, 10, 11});
TestTile({2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {2, 1, 2},
{4, 2, 6}, {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
7, 8, 9, 10, 11, 9, 10, 11, 0, 1, 2, 0, 1, 2, 3, 4,
5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 9, 10, 11, 9, 10, 11});
TestTile({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 1, 1, 2}, {2, 2, 2, 6},
{0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
7, 8, 9, 10, 11, 9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16,
17, 15, 16, 17, 18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23});
TestTile({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2, 2, 1}, {2, 4, 4, 3},
{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0, 1,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23});
}
TEST_F(TileOpTest, TestTileWithDataFormat) {
TestTileWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 1, 1, 2}, {2, 2, 2, 6},
{0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
7, 8, 9, 10, 11, 9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16,
17, 15, 16, 17, 18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23});
TestTileWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2, 2, 1}, {2, 4, 4, 3},
{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0, 1,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -153,6 +153,7 @@ MaceSupportedOps = [
'TargetRMSNorm',
'Transpose',
'Cumsum',
'Tile',
]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
......@@ -185,7 +186,8 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
MaceOp.Softmax,
MaceOp.Split,
MaceOp.Squeeze,
MaceOp.SqrDiffMean]
MaceOp.SqrDiffMean,
MaceOp.Tile]
class MaceKeyword(object):
......
......@@ -126,6 +126,7 @@ TFSupportedOps = [
'MirrorPad',
'Cumsum',
'OneHot',
'Tile',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
......@@ -276,6 +277,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Cumsum.name: self.convert_cumsum,
TFOpType.OneHot.name: self.convert_one_hot,
TFOpType.Sum.name: self.convert_reduce,
TFOpType.Tile.name: self.convert_tile,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -1050,6 +1052,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
del op.input[0]
self._skip_tensor.add(tf_op.inputs[0].name)
def convert_tile(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Tile.name
def convert_fake_quantize(self, tf_op):
op = self.convert_general_op(tf_op)
min_arg = op.arg.add()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册