提交 04ecb90e 编写于 作者: L luxuhui

support BatchMatMulV2 & Select ops for tensorflow

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 0f37ee96
......@@ -289,7 +289,7 @@ void OpRegistryBase::GetInOutMemoryTypes(
const std::string &op_type,
OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
op_type, " operation is not registered. op_type=", op_type);
return registry_.at(op_type)->memory_type_setter(context);
}
......
......@@ -64,6 +64,7 @@ extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry);
extern void RegisterReverse(OpRegistryBase *op_registry);
extern void RegisterScalarMath(OpRegistryBase *op_registry);
extern void RegisterSelect(OpRegistryBase *op_registry);
extern void RegisterShape(OpRegistryBase *op_registry);
extern void RegisterSlice(OpRegistryBase *op_registry);
extern void RegisterSoftmax(OpRegistryBase *op_registry);
......@@ -143,6 +144,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterResizeNearestNeighbor(this);
ops::RegisterReverse(this);
ops::RegisterScalarMath(this);
ops::RegisterSelect(this);
ops::RegisterShape(this);
ops::RegisterSlice(this);
ops::RegisterSoftmax(this);
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
namespace mace {
namespace ops {
template<DeviceType D, typename T>
class SelectOp;

// CPU / float implementation of the "Select" operation.
//
// The op has two modes, chosen by the number of inputs:
//   * one input (CONDITION): the output is a [num_true, rank] tensor that
//     lists, in row-major order, the coordinates of every true element of
//     the condition. Coordinates are stored as float.
//   * three inputs (CONDITION, X, Y): the output has the shape of X and
//     takes its values from X where the condition is true and from Y
//     otherwise. When the condition has fewer elements than X, each
//     condition element selects one contiguous block of X or Y.
template<>
class SelectOp<DeviceType::CPU, float> : public Operation {
 public:
  explicit SelectOp(OpConstructContext *context)
      : Operation(context) {}

  // Dispatches to the coordinate-listing or the data-selecting mode.
  MaceStatus Run(OpContext *context) override {
    if (this->InputSize() == 1) {
      return RunWithNoData(context);
    }
    return RunWithData(context);
  }

  // Writes the coordinates of all true condition elements to the output.
  //
  // Ranks 1 to 4 use dedicated nested loops; higher ranks fall back to a
  // generic stride-based decomposition of the flat index.
  // Returns MACE_SUCCESS, or the error reported by a failed resize/grow.
  MaceStatus RunWithNoData(OpContext *context) {
    const Tensor *condition = this->Input(CONDITION);
    Tensor *output = this->Output(OUTPUT);
    const index_t condition_rank = condition->dim_size();
    // A rank-0 condition would index div_ptr[-1] in the generic path and
    // divide by zero in the final resize, so reject it explicitly.
    MACE_CHECK(condition_rank > 0,
               "Select with a single input requires condition rank >= 1");

    // Reserve room for the worst case (every element true); the output is
    // shrunk to the real number of rows once it is known.
    MACE_RETURN_IF_ERROR(output->Resize({condition->size(), condition_rank}));
    float *output_data = output->mutable_data<float>();
    const bool *condition_data = condition->data<bool>();
    index_t i = 0;  // number of floats written so far

    if (condition_rank == 1) {
      const index_t channel = condition->dim(0);
      for (index_t c = 0; c < channel; ++c) {
        if (condition_data[c]) {
          output_data[i++] = static_cast<float>(c);
        }
      }
    } else if (condition_rank == 2) {
      const index_t width = condition->dim(0);
      const index_t channel = condition->dim(1);
      for (index_t w = 0; w < width; ++w) {
        const index_t w_base = w * channel;
        for (index_t c = 0; c < channel; ++c) {
          if (condition_data[w_base + c]) {
            output_data[i++] = static_cast<float>(w);
            output_data[i++] = static_cast<float>(c);
          }
        }
      }
    } else if (condition_rank == 3) {
      const index_t height = condition->dim(0);
      const index_t width = condition->dim(1);
      const index_t channel = condition->dim(2);
      for (index_t h = 0; h < height; ++h) {
        const index_t h_base = h * width;
        for (index_t w = 0; w < width; ++w) {
          const index_t w_base = (w + h_base) * channel;
          for (index_t c = 0; c < channel; ++c) {
            if (condition_data[w_base + c]) {
              output_data[i++] = static_cast<float>(h);
              output_data[i++] = static_cast<float>(w);
              output_data[i++] = static_cast<float>(c);
            }
          }
        }
      }
    } else if (condition_rank == 4) {
      const index_t batch = condition->dim(0);
      const index_t height = condition->dim(1);
      const index_t width = condition->dim(2);
      const index_t channel = condition->dim(3);
      for (index_t b = 0; b < batch; ++b) {
        const index_t b_base = b * height;
        for (index_t h = 0; h < height; ++h) {
          const index_t h_base = (b_base + h) * width;
          for (index_t w = 0; w < width; ++w) {
            const index_t w_base = (w + h_base) * channel;
            for (index_t c = 0; c < channel; ++c) {
              if (condition_data[w_base + c]) {
                output_data[i++] = static_cast<float>(b);
                output_data[i++] = static_cast<float>(h);
                output_data[i++] = static_cast<float>(w);
                output_data[i++] = static_cast<float>(c);
              }
            }
          }
        }
      }
    } else {
      // Generic path: div_ptr[d] holds the number of elements spanned by one
      // step along dimension d, so repeated divide/modulo of the flat index
      // yields the coordinate of each dimension in turn.
      const index_t condition_size = condition->size();
      auto div_buffer = context->device()->scratch_buffer();
      div_buffer->Rewind();
      MACE_RETURN_IF_ERROR(div_buffer->GrowSize(
          condition_rank * sizeof(index_t)));
      index_t *div_ptr = div_buffer->mutable_data<index_t>();
      div_ptr[condition_rank - 1] = 1;
      for (index_t dim = condition_rank - 1; dim > 0; --dim) {
        div_ptr[dim - 1] = div_ptr[dim] * condition->dim(dim);
      }
      for (index_t c = 0; c < condition_size; ++c) {
        if (condition_data[c]) {
          index_t remainder = c;
          for (index_t dim = 0; dim < condition_rank; ++dim) {
            output_data[i++] = static_cast<float>(remainder / div_ptr[dim]);
            remainder = remainder % div_ptr[dim];
          }
        }
      }
    }

    // Each true element produced condition_rank floats.
    MACE_RETURN_IF_ERROR(output->Resize({i / condition_rank, condition_rank}));
    return MaceStatus::MACE_SUCCESS;
  }

  // Validates the three-input form: x and y must have identical shapes, and
  // the condition's dimensions must equal the leading dimensions of x.
  // Fails through MACE_CHECK on a mismatch; otherwise returns true.
  bool CheckDataValid(const Tensor *condition,
                      const Tensor *x, const Tensor *y) {
    const index_t x_rank = x->dim_size();
    const index_t y_rank = y->dim_size();
    const index_t condition_rank = condition->dim_size();
    MACE_CHECK(condition_rank <= x_rank && x_rank == y_rank);
    for (index_t i = 0; i < condition_rank; ++i) {
      MACE_CHECK(condition->dim(i) == x->dim(i),
                 "dimensions are not equal: ",
                 MakeString(condition->shape()),
                 " vs. ",
                 MakeString(x->shape()));
    }
    for (index_t i = 0; i < x_rank; ++i) {
      MACE_CHECK(y->dim(i) == x->dim(i), "dimensions are not equal: ",
                 MakeString(y->shape()), " vs. ", MakeString(x->shape()));
    }
    return true;
  }

  // Selects between x and y according to the condition.
  //
  // If the condition has as many elements as x the selection is elementwise;
  // if it has fewer, every condition element picks a whole contiguous block
  // of x_size / condition_size values.
  // Returns MACE_SUCCESS, or the error reported by a failed resize.
  MaceStatus RunWithData(OpContext *context) {
    const Tensor *condition = this->Input(CONDITION);
    const Tensor *x = this->Input(X);
    const Tensor *y = this->Input(Y);
    // Validate with MACE_CHECK rather than MACE_ASSERT: the assert form is
    // compiled out when NDEBUG is defined (see mace/utils/logging.h), which
    // would leave the raw indexing and memcpy below unguarded against
    // mismatched shapes.
    MACE_CHECK(CheckDataValid(condition, x, y));

    Tensor *output = this->Output(OUTPUT);
    MACE_RETURN_IF_ERROR(output->Resize(x->shape()));
    float *output_data = output->mutable_data<float>();
    const bool *condition_data = condition->data<bool>();
    const float *x_data = x->data<float>();
    const float *y_data = y->data<float>();
    const index_t condition_size = condition->size();
    const index_t x_size = x->size();
    utils::ThreadPool
        &thread_pool = context->device()->cpu_runtime()->thread_pool();

    if (condition_size == x_size) {
      thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
        for (index_t k = start; k < end; k += step) {
          output_data[k] = condition_data[k] ? x_data[k] : y_data[k];
        }
      }, 0, x_size, 1);
    } else if (x_size > condition_size) {  // broadcast
      const auto block_size = x_size / condition_size;
      MACE_CHECK(
          block_size > 1 && x_size % condition_size == 0,
          "x_size should be a multiple of condition_size and greater than 1");
      const auto raw_block_size = block_size * sizeof(float);
      thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
        for (index_t k = start; k < end; k += step) {
          const auto offset = block_size * k;
          const float *src = condition_data[k] ? x_data : y_data;
          memcpy(output_data + offset, src + offset, raw_block_size);
        }
      }, 0, condition_size, 1);
    } else {
      MACE_CHECK(false, "x_size should be bigger than condition_size");
    }
    return MaceStatus::MACE_SUCCESS;
  }

 private:
  MACE_OP_INPUT_TAGS(CONDITION, X, Y);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
// Registers the CPU/float SelectOp specialization with `op_registry` under
// the op type name "Select".
void RegisterSelect(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Select", SelectOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the Select op; adds nothing on top of OpsTestBase.
class SelectOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSelect(const std::vector<index_t> &input_shape,
const std::vector<uint8_t> &input,
const std::vector<index_t> &x_shape,
const std::vector<T> &x,
const std::vector<index_t> &y_shape,
const std::vector<T> &y,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
// Construct graph
OpsTestNet net;
OpDefBuilder builder("Select", "SelectTest");
builder.Input("Input");
if (x.size() > 0) {
builder.Input("X").Input("Y");
}
builder.Output("Output").Finalize(net.NewOperatorDef());
net.AddInputFromArray<D, uint8_t>(MakeString("Input"), input_shape, input);
if (x.size() > 0) {
net.AddInputFromArray<D, T>(MakeString("X"), x_shape, x);
net.AddInputFromArray<D, T>(MakeString("Y"), y_shape, y);
}
// Run
net.RunOp();
net.AddInputFromArray<D, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
// Three-input form with a condition of the same shape as X and Y ([2, 3]):
// each output element comes from X where the condition is true, else from Y.
TEST_F(SelectOpTest, SimpleTestWithData) {
TestSelect<DeviceType::CPU, float>(
{2, 3},
{true, false, false, false, true, true},
{2, 3},
{3.0, 2.0, 3.0, 4.0, 5.0, 6.0},
{2, 3},
{3.0, -1.0, -2.0, -3.0, 8.0, 9.0},
{2, 3},
{3.0, -1.0, -2.0, -3.0, 5.0, 6.0});
}
// Three-input form with a rank-1 condition ([2]) against [2, 3] data: each
// condition element selects a whole row, so row 0 comes from X and row 1
// from Y.
TEST_F(SelectOpTest, SimpleTestWithDataBroadcast) {
TestSelect<DeviceType::CPU, float>(
{2},
{true, false},
{2, 3},
{3.0, 2.0, 3.0, 4.0, 5.0, 6.0},
{2, 3},
{3.0, -1.0, -2.0, -3.0, 8.0, 9.0},
{2, 3},
{3, 2, 3, -3, 8, 9});
}
// Condition-only form, rank 1: X and Y are left empty, so only the condition
// is fed. The single true element at index 0 yields a [1, 1] output.
TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast1) {
TestSelect<DeviceType::CPU, float>(
{2},
{true, false},
{}, {}, {}, {},
{1, 1},
{0});
}
// Condition-only form, rank 2: the three true elements of the [2, 3]
// condition are reported as (row, column) pairs in a [3, 2] output.
TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast2) {
TestSelect<DeviceType::CPU, float>(
{2, 3},
{true, false, false, false, true, true},
{}, {}, {}, {},
{3, 2},
{0, 0, 1, 1, 1, 2});
}
// Condition-only form, rank 3: six true elements of the [2, 2, 3] condition
// produce a [6, 3] output of coordinates in row-major order.
TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast3) {
TestSelect<DeviceType::CPU, float>(
{2, 2, 3},
{true, false, false, false, true, true,
true, false, false, false, true, true},
{}, {}, {}, {},
{6, 3},
{0, 0, 0, 0, 1, 1, 0, 1, 2,
1, 0, 0, 1, 1, 1, 1, 1, 2});
}
// Condition-only form, rank 4: twelve true elements of the [2, 2, 2, 3]
// condition produce a [12, 4] output of coordinates in row-major order.
TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast4) {
TestSelect<DeviceType::CPU, float>(
{2, 2, 2, 3},
{true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true},
{}, {}, {}, {},
{12, 4},
{0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 2,
0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 2,
1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 2,
1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2});
}
// Condition-only form, rank 5: with more than four dimensions the op takes
// its generic (rank > 4) path. The 24 true elements of the [2, 2, 2, 2, 3]
// condition produce a [24, 5] output of coordinates in row-major order.
TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast5) {
TestSelect<DeviceType::CPU, float>(
{2, 2, 2, 2, 3},
{true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true,
true, false, false, false, true, true},
{}, {}, {}, {},
{24, 5},
{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 2,
0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 2,
0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 2,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 2,
1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 2,
1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 2,
1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 2,
1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -136,6 +136,7 @@ MaceSupportedOps = [
'ResizeNearestNeighbor',
'Reverse',
'ScalarMath',
'Select',
'Slice',
'Splice',
'Split',
......@@ -280,7 +281,7 @@ class MaceKeyword(object):
class TransformerRule(Enum):
REMOVE_IDENTITY_OP = 1
REMOVE_USELESS_OP = 1
TRANSFORM_GLOBAL_POOLING = 2
FOLD_RESHAPE = 3
TRANSFORM_MATMUL_TO_FC = 4
......@@ -526,9 +527,9 @@ class ConverterOption(object):
else:
self._transformer_option = [
# Model structure related transformation
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_FAKE_QUANTIZE,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
......
......@@ -57,6 +57,7 @@ TFSupportedOps = [
'ArgMax',
'AvgPool',
'BatchMatMul',
'BatchMatMulV2',
'BatchToSpaceND',
'BiasAdd',
'Cast',
......@@ -105,6 +106,7 @@ TFSupportedOps = [
'ResizeNearestNeighbor',
'ReverseV2',
'Rsqrt',
'Select',
'Shape',
'Sigmoid',
'Sign',
......@@ -134,7 +136,7 @@ TFSupportedOps = [six.b(op) for op in TFSupportedOps]
TFTransformGraphOptions = [
'strip_unused_nodes',
'remove_nodes(op=Identity, op=CheckNumerics)',
'remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)',
'fold_constants(ignore_errors=true)',
'fold_batch_norms',
'fold_old_batch_norms',
......@@ -211,6 +213,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.AvgPool.name: self.convert_pooling,
TFOpType.BatchMatMul.name: self.convert_matmul,
TFOpType.BatchMatMulV2.name: self.convert_matmul,
TFOpType.BatchToSpaceND.name: self.convert_space_batch,
TFOpType.BiasAdd.name: self.convert_biasadd,
TFOpType.Cast.name: self.convert_cast,
......@@ -263,6 +266,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
TFOpType.ResizeNearestNeighbor.name: self.convert_resize_nearest_neighbor, # noqa
TFOpType.ReverseV2.name: self.convert_reverse,
TFOpType.Select.name: self.convert_select,
TFOpType.Shape.name: self.convert_shape,
TFOpType.Sigmoid.name: self.convert_activation,
TFOpType.Sign.name: self.convert_elementwise,
......@@ -993,6 +997,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Reverse.name
# Converts a TensorFlow Select node: builds the op through
# convert_general_op and sets its type to MaceOp.Select.name.
def convert_select(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Select.name
def convert_stack(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Stack.name
......
......@@ -48,7 +48,7 @@ class Transformer(base_converter.ConverterInterface):
self._registered_transformers = {
TransformerRule.TRANSFORM_FAKE_QUANTIZE:
self.transform_fake_quantize,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.REMOVE_USELESS_OP: self.remove_useless_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE:
......@@ -347,15 +347,15 @@ class Transformer(base_converter.ConverterInterface):
return False
def remove_identity_op(self):
def remove_useless_op(self):
net = self._model
for op in net.op:
if op.type == 'Identity':
print("Remove identity: %s(%s)" % (op.name, op.type))
print("Remove useless op: %s(%s)" % (op.name, op.type))
self.safe_remove_node(op,
self._producer.get(op.input[0], None))
return True
if op.type == 'Reshape' and \
elif op.type == 'Reshape' and \
op.output_shape[0].dims == \
self.get_tensor_shape(op.input[0]):
print("Remove useless reshape: %s(%s)" % (op.name, op.type))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册