Commit d4fb1af2 authored by 李滨

Merge branch 'mnmt' into 'master'

Support NLP models and ops

See merge request !602
...@@ -65,6 +65,8 @@ Configurations
- The shapes of the input tensors, in NHWC order.
* - output_shapes
- The shapes of the output tensors, in NHWC order.
* - input_ranges
- The numerical ranges of the input tensors, defaulting to [-1, 1]. Only used for testing.
* - validation_inputs_data
- [optional] Specify NumPy validation inputs. When not provided, random values in [-1, 1] will be used.
* - runtime
......
...@@ -217,9 +217,6 @@ MaceStatus MaceEngine::Impl::Run(
    << "' does not belong to model's inputs: "
    << MakeString(MapKeys(input_info_map_));
}
MACE_CHECK(input.second.shape().size() == 4,
"The Inputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions");
Tensor *input_tensor =
    ws_->GetTensor(MakeString("mace_input_node_", input.first));
MACE_RETURN_IF_ERROR(input_tensor->Resize(input.second.shape()));
......
...@@ -63,7 +63,9 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
  auto &op = *iter;
  MACE_LATENCY_LOGGER(2, "Running operator ", op->debug_def().name(), "(",
                      op->debug_def().type(), "), mem_id: ",
                      MakeListString(op->debug_def().mem_id().data(),
                                     op->debug_def().mem_id().size()));
  bool future_wait = (device_type_ == DeviceType::GPU &&
                      (run_metadata != nullptr ||
                       std::distance(iter, operators_.end()) == 1));
......
...@@ -79,6 +79,7 @@ extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
extern void Register_BiasAdd(OperatorRegistry *op_registry);
extern void Register_Cast(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
...@@ -127,6 +128,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this);
ops::Register_Cast(this);
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
......
...@@ -15,6 +15,7 @@
#ifndef MACE_KERNELS_BIAS_ADD_H_
#define MACE_KERNELS_BIAS_ADD_H_
#include <functional>
#include <memory>
#include <vector>
...@@ -29,20 +30,27 @@
namespace mace {
namespace kernels {

struct BiasAddFunctorBase {
explicit BiasAddFunctorBase(const DataFormat data_format) {
data_format_ = data_format;
}
DataFormat data_format_;
};
template <DeviceType D, typename T>
struct BiasAddFunctor;

template <>
struct BiasAddFunctor<DeviceType::CPU, float> : BiasAddFunctorBase {
explicit BiasAddFunctor(const DataFormat data_format)
: BiasAddFunctorBase(data_format) {}
MaceStatus operator()(const Tensor *input,
                      const Tensor *bias,
                      Tensor *output,
                      StatsFuture *future) {
  MACE_UNUSED(future);
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard bias_mapper(bias);
...@@ -52,12 +60,31 @@ struct BiasAddFunctor<DeviceType::CPU, float> {
const float *bias_ptr = bias->data<float>();
float *output_ptr = output->mutable_data<float>();
if (input->dim_size() == 4 && data_format_ == NCHW) {
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height_width = input->dim(2) * input->dim(3);
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
  for (index_t c = 0; c < channels; ++c) {
    for (index_t hw = 0; hw < height_width; ++hw) {
      index_t pos = (n * channels + c) * height_width + hw;
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
}
}
}
} else {
const std::vector<index_t> &shape = input->shape();
const index_t fused_batch = std::accumulate(
shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
const index_t channels = *shape.rbegin();
#pragma omp parallel for
for (index_t n = 0; n < fused_batch; ++n) {
index_t pos = n * channels;
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
}
}
...@@ -67,12 +94,14 @@ struct BiasAddFunctor<DeviceType::CPU, float> {
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct BiasAddFunctor<DeviceType::GPU, T> : BiasAddFunctorBase {
explicit BiasAddFunctor(const DataFormat data_format)
: BiasAddFunctorBase(data_format) {}
MaceStatus operator()(const Tensor *input,
                      const Tensor *bias,
                      Tensor *output,
                      StatsFuture *future);

cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
......
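For context on the new fallback branch above: when the input is not 4-D NCHW, all leading dimensions are fused into a single batch and the bias is applied per channel along the innermost axis. A standalone sketch of that logic (hypothetical names, not MACE code):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Channels-last bias add: for shape {d0, ..., dn-1, C}, fuse the leading
// dims into fused_batch rows of C channels and add bias[c] to each element.
void BiasAddChannelsLast(const std::vector<int64_t> &shape,
                         const float *input, const float *bias,
                         float *output) {
  const int64_t channels = shape.back();
  const int64_t fused_batch =
      std::accumulate(shape.begin(), shape.end() - 1, int64_t{1},
                      std::multiplies<int64_t>());
  for (int64_t n = 0; n < fused_batch; ++n) {
    for (int64_t c = 0; c < channels; ++c) {
      output[n * channels + c] = input[n * channels + c] + bias[c];
    }
  }
}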
...@@ -19,6 +19,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <functional>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
...@@ -44,6 +45,157 @@ enum EltwiseType {
NONE = 10,
};
inline index_t GetIndex(const std::vector<index_t> &shape,
const std::vector<index_t> &index) {
index_t idx = 0;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] > 1) {
idx = idx * shape[i] + index[i];
}
}
return idx;
}
inline void IncreaseIndex(const std::vector<index_t> &shape,
std::vector<index_t> *index) {
for (index_t i = static_cast<index_t>(shape.size()) - 1; i >= 0; --i) {
++(*index)[i];
if ((*index)[i] >= shape[i]) {
(*index)[i] -= shape[i];
} else {
break;
}
}
}
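These two helpers implement broadcasting without materializing strides: GetIndex skips axes of size 1, so a broadcast input keeps re-reading the same elements, while IncreaseIndex advances the output coordinate like an odometer. A minimal standalone demo (assumed shapes, not MACE code):

#include <cstdio>
#include <vector>

int main() {
  // Walk a {2, 3} output against a broadcast {1, 3} input.
  std::vector<long long> out_shape{2, 3}, in_shape{1, 3};
  std::vector<long long> index(out_shape.size(), 0);
  for (int i = 0; i < 2 * 3; ++i) {
    // GetIndex: axes of size 1 contribute nothing, so both output rows
    // map onto the single input row.
    long long idx = 0;
    for (size_t d = 0; d < in_shape.size(); ++d) {
      if (in_shape[d] > 1) idx = idx * in_shape[d] + index[d];
    }
    printf("out(%lld, %lld) <- in[%lld]\n", index[0], index[1], idx);
    // IncreaseIndex: bump the last axis and carry leftwards on overflow.
    for (int d = static_cast<int>(out_shape.size()) - 1; d >= 0; --d) {
      if (++index[d] < out_shape[d]) break;
      index[d] = 0;
    }
  }
  return 0;
}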
inline void TensorGeneralBroadcastEltwise(
    const EltwiseType type,
    const float *input0,
    const float *input1,
    const std::vector<float> &coeff,
    const bool swapped,
    const std::vector<index_t> &input0_shape,
    const std::vector<index_t> &input1_shape,
    const std::vector<index_t> &output_shape,
    float *output) {
const index_t output_size = std::accumulate(output_shape.begin(),
output_shape.end(),
1,
std::multiplies<index_t>());
std::vector<index_t> out_index(output_shape.size(), 0);
switch (type) {
case SUM:
if (coeff.empty()) {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input0[idx0] + input1[idx1];
IncreaseIndex(output_shape, &out_index);
}
} else {
std::vector<float> coeff_copy = coeff;
if (swapped) {
std::swap(coeff_copy[0], coeff_copy[1]);
}
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] =
input0[idx0] * coeff_copy[0] + input1[idx1] * coeff_copy[1];
IncreaseIndex(output_shape, &out_index);
}
}
break;
case SUB:
if (!swapped) {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input0[idx0] - input1[idx1];
IncreaseIndex(output_shape, &out_index);
}
} else {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input1[idx1] - input0[idx0];
IncreaseIndex(output_shape, &out_index);
}
}
break;
case PROD:
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input0[idx0] * input1[idx1];
IncreaseIndex(output_shape, &out_index);
}
break;
case DIV:
if (!swapped) {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input0[idx0] / input1[idx1];
IncreaseIndex(output_shape, &out_index);
}
} else {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input1[idx1] / input0[idx0];
IncreaseIndex(output_shape, &out_index);
}
}
break;
case MIN:
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = std::min(input1[idx1], input0[idx0]);
IncreaseIndex(output_shape, &out_index);
}
break;
case MAX:
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = std::max(input1[idx1], input0[idx0]);
IncreaseIndex(output_shape, &out_index);
}
break;
case SQR_DIFF:
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = std::pow(input1[idx1] - input0[idx0], 2.f);
IncreaseIndex(output_shape, &out_index);
}
break;
case POW:
if (!swapped) {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = std::pow(input0[idx0], input1[idx1]);
IncreaseIndex(output_shape, &out_index);
}
} else {
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = std::pow(input1[idx1], input0[idx0]);
IncreaseIndex(output_shape, &out_index);
}
}
break;
default:
LOG(FATAL) << "Eltwise op does not support type " << type;
}
}
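A hedged usage sketch of the routine above (assuming the mace::kernels declarations are in scope), mirroring the SUM case of the unit test added later in this merge request: a {1, 1, 2, 1} tensor broadcasts across the last axis of a {1, 1, 2, 3} tensor.

std::vector<float> in0{1, 2, 3, 4, 5, 6};  // shape {1, 1, 2, 3}
std::vector<float> in1{1, 2};              // shape {1, 1, 2, 1}
std::vector<float> out(6);
TensorGeneralBroadcastEltwise(SUM, in0.data(), in1.data(),
                              /*coeff=*/{}, /*swapped=*/false,
                              {1, 1, 2, 3}, {1, 1, 2, 1}, {1, 1, 2, 3},
                              out.data());
// out == {2, 3, 4, 6, 7, 8}: row 0 gains +1, row 1 gains +2.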
inline void TensorBroadcastEltwise(const EltwiseType type,
                                   const float *input0,
                                   const float *input1,
...@@ -662,40 +814,71 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
               && input1->dim(0) == input0->dim(1))),
           "only support broadcast channel dimension");
} else {
  for (uint32_t i = 0; i < input1->dim_size(); ++i) {
    MACE_CHECK(input0->dim(rank_diff + i) == 1
                   || input1->dim(i) == 1
                   || input0->dim(rank_diff + i) == input1->dim(i),
               "Element-Wise op only supports tail-dimension broadcast");
  }
}
index_t common_size = input1->size();
index_t diff_size = input0->size() / common_size;
MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1);
Tensor::MappingGuard output_guard(output);
const float *input0_ptr = input0->data<float>();
const float *input1_ptr = input1->data<float>();
float *output_ptr = output->mutable_data<float>();
if (data_format_ == NCHW && input1->dim_size() > 0 &&
    input1->size() < input0->size()) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
Tensor::MappingGuard output_guard(output);
float *output_ptr = output->mutable_data<float>();
TensorEltwisePerChannel(
    type_, input0_ptr, input1_ptr, coeff_, input0->dim(0),
    input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1),
    input0->dim(2) * input0->dim(3), swapped, output_ptr);
} else {
const std::vector<index_t> &input0_shape = input0->shape();
std::vector<index_t> input1_shape(rank_diff, 1);
input1_shape.insert(input1_shape.end(),
input1->shape().begin(),
input1->shape().end());
std::vector<index_t> output_shape(input0->dim_size(), 0);
for (unsigned int i = 0; i < input0_shape.size(); ++i) {
output_shape[i] = std::max(input0_shape[i], input1_shape[i]);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
float *output_ptr = output->mutable_data<float>();
bool need_general_broadcast = false;
for (uint32_t i = 0; i < input1->dim_size(); ++i) {
if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1)
|| (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) {
need_general_broadcast = true;
break;
}
}
if (need_general_broadcast) {
TensorGeneralBroadcastEltwise(type_,
input0_ptr,
input1_ptr,
coeff_,
swapped,
input0_shape,
input1_shape,
output_shape,
output_ptr);
} else if (input1->size() == input0->size()) {
TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(),
              swapped, output_ptr);
} else if (input1->size() < input0->size()) {
  if (input1->size() > 1) {
index_t common_size = input1->size();
index_t diff_size = input0->size() / common_size;
TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_,
                       diff_size, common_size, swapped, output_ptr);
} else {
......
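The dispatch above only falls back to TensorGeneralBroadcastEltwise when some axis mixes a size of 1 on one side with a size greater than 1 on the other; a pure rank difference still takes the cheaper tail-broadcast kernel. A standalone sketch of that test (hypothetical helper, not MACE code):

#include <cstdio>
#include <vector>

// Mirrors the loop above: only input1's trailing axes are inspected,
// after aligning them against input0's tail via rank_diff.
bool NeedGeneralBroadcast(const std::vector<long long> &shape0,
                          const std::vector<long long> &shape1) {
  const size_t rank_diff = shape0.size() - shape1.size();
  for (size_t i = 0; i < shape1.size(); ++i) {
    if ((shape0[rank_diff + i] == 1 && shape1[i] > 1) ||
        (shape0[rank_diff + i] > 1 && shape1[i] == 1)) {
      return true;
    }
  }
  return false;
}

int main() {
  // Mixed sizes (3 vs 1) on the last axis: general broadcast required.
  printf("%d\n", NeedGeneralBroadcast({1, 1, 2, 3}, {1, 1, 2, 1}));  // 1
  // Pure leading-rank difference: tail broadcast suffices.
  printf("%d\n", NeedGeneralBroadcast({2, 2, 3}, {2, 3}));           // 0
  return 0;
}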
...@@ -26,6 +26,9 @@ MaceStatus BiasAddFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
    const Tensor *bias,
    Tensor *output,
    StatsFuture *future) {
MACE_CHECK(input->dim_size() == 4 && data_format_ == NHWC,
           "GPU only supports BiasAdd for 4-dimensional NHWC tensors");
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
......
...@@ -25,14 +25,14 @@ template <DeviceType D, class T>
class BiasAddOp : public Operator<D, T> {
 public:
  BiasAddOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
functor_(static_cast<DataFormat>(OperatorBase::GetOptionalArg<int>(
"data_format", NHWC))) {}
MaceStatus Run(StatsFuture *future) override {
  const Tensor *input = this->Input(INPUT);
  const Tensor *bias = this->Input(BIAS);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
           bias->dim_size());
......
...@@ -42,6 +42,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
OpDefBuilder("BiasAdd", "BiasAddBM")
    .Input("Input")
    .Input("Bias")
.AddIntArg("data_format", NCHW)
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
......
...@@ -37,6 +37,7 @@ void BiasAddSimple() {
OpDefBuilder("BiasAdd", "BiasAddTest")
    .Input("InputNCHW")
    .Input("Bias")
.AddIntArg("data_format", NCHW)
.Output("OutputNCHW") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -99,6 +100,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -99,6 +100,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Bias") .Input("Bias")
.AddIntArg("data_format", NCHW)
.Output("OutputNCHW") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
...@@ -155,6 +157,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -155,6 +157,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Bias") .Input("Bias")
.AddIntArg("data_format", NCHW)
.Output("OutputNCHW") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/cast.h"
namespace mace {
namespace ops {
void Register_Cast(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
CastOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
CastOp<DeviceType::CPU, int32_t>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_CAST_H_
#define MACE_OPS_CAST_H_
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename DT>
class CastOp : public Operator<D, DT> {
public:
CastOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, DT>(op_def, ws) {}
MaceStatus Run(StatsFuture *future) override {
MACE_UNUSED(future);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
auto src_dtype = input->dtype();
auto output_data = output->mutable_data<DT>();
#define MACE_CAST_COPY \
auto input_data = input->data<T>(); \
for (index_t i = 0; i < output->size(); ++i) { \
output_data[i] = static_cast<DT>(input_data[i]); \
}
MACE_RUN_WITH_TYPE_ENUM(src_dtype, MACE_CAST_COPY);
return MACE_SUCCESS;
}
private:
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_CAST_H_
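For clarity on the Run method above: MACE_RUN_WITH_TYPE_ENUM selects the source type at runtime and MACE_CAST_COPY then performs an element-wise static_cast. A standalone sketch of the same semantics using a template instead of the type-enum macro (hypothetical helper, not the MACE API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Element-wise cast, as in CastOp: each value is converted with
// static_cast, so float -> int32 truncates toward zero.
template <typename SrcType, typename DstType>
std::vector<DstType> Cast(const std::vector<SrcType> &input) {
  std::vector<DstType> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = static_cast<DstType>(input[i]);
  }
  return output;
}

int main() {
  std::vector<float> src{1.1f, 2.9f, -3.7f};
  std::vector<int32_t> dst = Cast<float, int32_t>(src);
  printf("%d %d %d\n", dst[0], dst[1], dst[2]);  // prints: 1 2 -3
  return 0;
}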
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class CastOpTest : public OpsTestBase {};
namespace {
template <typename SrcType, typename DstType>
void TestCast(const std::vector<index_t> &input_shape,
const std::vector<SrcType> &input) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Cast", "CastTest")
.Input("Input")
.Output("Output")
.AddIntArg("T", DataTypeToEnum<DstType>::v())
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, SrcType>("Input", input_shape, input);
// Run
net.RunOp();
auto input_tensor = net.GetTensor("Input");
auto output_tensor = net.GetTensor("Output");
EXPECT_THAT(output_tensor->shape(), ::testing::ContainerEq(input_shape));
Expector<SrcType, DstType, true>::Near(
    *input_tensor, *output_tensor, 1e-5, 1.f);
}
} // namespace
TEST_F(CastOpTest, TestCastFromFloatToInt32) {
TestCast<float, int32_t>({1, 2, 3}, {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f});
}
TEST_F(CastOpTest, TestCastFromInt32ToFloat) {
TestCast<int32_t, float>({1, 2, 3}, {1, 2, 3, 4, 5, 6});
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -23,7 +23,11 @@ void Register_Concat(OperatorRegistry *op_registry) {
    .TypeConstraint<float>("T")
    .Build(),
ConcatOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
ConcatOp<DeviceType::CPU, int32_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
    .Device(DeviceType::GPU)
......
...@@ -135,6 +135,41 @@ void SimpleTensorEltwise(const kernels::EltwiseType type,
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
template<DeviceType D, typename T>
void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type,
const std::vector<index_t> &shape0,
const std::vector<float> &input0,
const std::vector<index_t> &shape1,
const std::vector<float> &input1,
const std::vector<index_t> &output_shape,
const std::vector<float> &output,
const std::vector<float> &coeff = {}) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input0", shape0, input0);
net.AddInputFromArray<D, float>("Input1", shape1, input1);
if (D == DeviceType::CPU) {
auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input0")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", coeff)
.Output("Output");
op_builder.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
MACE_NOT_IMPLEMENTED;
}
auto expected = CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace

TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) {
...@@ -559,6 +594,30 @@ TEST_F(EltwiseOpTest, RandomTensorTensorHalf) {
    {3, 31, 37, 17});
}
TEST_F(EltwiseOpTest, TensorGeneralBroadcast) {
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {2, 3, 4, 6, 7, 8});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {0, 1, 2, 2, 3, 4});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 8, 10, 12});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 2, 3, 2, 2.5, 3});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 1, 1, 2, 2, 2});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>(
kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 4, 4, 9, 16});
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -23,6 +23,12 @@ void Register_Reshape(OperatorRegistry *op_registry) {
    .TypeConstraint<float>("T")
    .Build(),
ReshapeOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
ReshapeOp<DeviceType::CPU, int32_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
......
...@@ -23,6 +23,11 @@ void Register_Stack(OperatorRegistry *op_registry) {
    .TypeConstraint<float>("T")
    .Build(),
StackOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
StackOp<DeviceType::CPU, int32_t>);
}

} // namespace ops
......
...@@ -23,6 +23,11 @@ void Register_StridedSlice(OperatorRegistry *op_registry) {
    .TypeConstraint<float>("T")
    .Build(),
StridedSliceOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
StridedSliceOp<DeviceType::CPU, int32_t>);
}

} // namespace ops
......
...@@ -108,7 +108,10 @@ def main(unused_args):
print("%s does not support dsp runtime yet." % FLAGS.platform)
sys.exit(-1)
else:
if FLAGS.transformers:
    option = cvt.ConverterOption(FLAGS.transformers.split(','))
else:
    option = cvt.ConverterOption()
option.winograd_enabled = bool(FLAGS.winograd)

input_node_names = FLAGS.input_node.split(',')
...@@ -285,6 +288,11 @@ def parse_args():
    type=str,
    default="fp16_fp32",
    help="fp16_fp32/fp32_fp32")
parser.add_argument(
"--transformers",
type=str,
default="",
help="model transformers")
return parser.parse_known_args()
......
...@@ -74,6 +74,7 @@ MaceSupportedOps = [
'BatchNorm',
'BatchToSpaceND',
'BiasAdd',
'Cast',
'ChannelShuffle',
'Concat',
'Conv2D',
...@@ -177,9 +178,10 @@ class TransformerRule(Enum):
TRANSPOSE_DATA_FORMAT = 15
TRANSFORM_GLOBAL_CONV_TO_FC = 16
TRANSFORM_BUFFER_IMAGE = 17
ADD_DEVICE = 18
SORT_BY_EXECUTION = 19
ADD_IN_OUT_TENSOR_INFO = 20
ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
class ConverterInterface(object):
...@@ -219,34 +221,39 @@ class NodeInfo(object):
class ConverterOption(object):
    """A class for specifying options passed to converter tool"""
    def __init__(self, transformers=None):
        self._input_nodes = {}
        self._output_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd_enabled = False
        if transformers:
            self._transformer_option = [TransformerRule[transformer]
                                        for transformer in transformers]
        else:
            self._transformer_option = [
                TransformerRule.REMOVE_IDENTITY_OP,
                TransformerRule.TRANSFORM_GLOBAL_POOLING,
                TransformerRule.FOLD_RESHAPE,
                TransformerRule.TRANSFORM_MATMUL_TO_FC,
                TransformerRule.FOLD_BATCHNORM,
                TransformerRule.FOLD_CONV_AND_BN,
                TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
                TransformerRule.TRANSFORM_GPU_WINOGRAD,
                TransformerRule.TRANSFORM_ADD_TO_BIASADD,
                TransformerRule.FOLD_BIASADD,
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_ACTIVATION,
                TransformerRule.TRANSPOSE_FILTERS,
                TransformerRule.TRANSPOSE_DATA_FORMAT,
                TransformerRule.ADD_IN_OUT_TENSOR_INFO,
                TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
                TransformerRule.RESHAPE_FC_WEIGHT,
                TransformerRule.TRANSFORM_BUFFER_IMAGE,
                TransformerRule.ADD_DEVICE,
                TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
                TransformerRule.SORT_BY_EXECUTION,
            ]
    @property
    def input_nodes(self):
......
...@@ -341,6 +341,10 @@ class CaffeConverter(base_converter.ConverterInterface):
op.input.extend(caffe_op.layer.bottom)
op.output.extend(caffe_op.layer.top)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)

return op
......
...@@ -92,6 +92,7 @@ TFSupportedOps = [
'Slice',
'Stack',
'Pack',
'Cast',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
...@@ -181,7 +182,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.StridedSlice.name: self.convert_stridedslice,
TFOpType.Slice.name: self.convert_slice,
TFOpType.Pack.name: self.convert_stack,
TFOpType.Stack.name: self.convert_stack,
TFOpType.Cast.name: self.convert_cast
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
...@@ -300,6 +302,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
output_shape = op.output_shape.add()
output_shape.dims.extend(self.infer_tensor_shape(tf_output))
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
try:
dtype = tf_op.get_attr('T')
if dtype == tf.int32:
data_type_arg.i = mace_pb2.DT_INT32
elif dtype == tf.float32:
data_type_arg.i = self._option.data_type
else:
mace_check(False, "data type %s not supported" % dtype)
except ValueError:
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)

return op
...@@ -367,7 +382,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_arg.f = tf_op.inputs[0].eval().astype(np.float32)
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
elif len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_value_str
value_arg.f = tf_op.inputs[1].eval().astype(np.float32)
...@@ -655,6 +670,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
def convert_slice(self, tf_op):
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.StridedSlice.name
arg = op.arg.add()
arg.name = 'slice'
arg.i = 1
def convert_stack(self, tf_op):
    op = self.convert_general_op(tf_op)
...@@ -666,3 +684,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
    axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str)
except ValueError:
    axis_arg.i = 0
def convert_cast(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Cast.name
data_type_arg = ConverterUtil.get_arg(op, 'T')
try:
dtype = tf_op.get_attr('DstT')
if dtype == tf.int32:
data_type_arg.i = mace_pb2.DT_INT32
elif dtype == tf.float32:
data_type_arg.i = self._option.data_type
else:
mace_check(False, "data type %s not supported" % dtype)
except ValueError:
data_type_arg.i = self._option.data_type
...@@ -53,30 +53,6 @@ class Transformer(base_converter.ConverterInterface):
"""

def __init__(self, option, model):
# DO NOT reorder the following transformers' order
self._registered_transformers_order = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE,
TransformerRule.SORT_BY_EXECUTION,
]
self._registered_transformers = {
    TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
    TransformerRule.TRANSFORM_GLOBAL_POOLING:
...@@ -105,8 +81,10 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
TransformerRule.TRANSFORM_BUFFER_IMAGE:
    self.transform_buffer_image,
TransformerRule.ADD_DEVICE:
    self.add_device,
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES:
self.add_mace_input_and_output_nodes,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
}
...@@ -119,18 +97,18 @@ class Transformer(base_converter.ConverterInterface):
self._consumers = {}
self._producer = {}
self._target_data_format = DataFormat.NHWC
self._input_output_added = False
if self._option.device == DeviceType.CPU.value:
    self._target_data_format = DataFormat.NCHW

def run(self):
    for key in self._option.transformer_option:
        transformer = self._registered_transformers[key]
        while True:
            self.construct_ops_and_consumers()
            changed = transformer()
            if not changed:
                break

    return self._model
...@@ -900,6 +878,8 @@ class Transformer(base_converter.ConverterInterface):
else:
    op.type = MaceOp.Identity.name
self._input_output_added = True
return False

def transpose_filters(self):
...@@ -1060,6 +1040,8 @@ class Transformer(base_converter.ConverterInterface):
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
self._input_output_added = True
return False

def fold_reshape(self):
...@@ -1164,16 +1146,13 @@ class Transformer(base_converter.ConverterInterface):
in_channels * filter_width
* filter_height][:]
def add_device(self):
    # TODO(liuqi) add device definition in OperatorDef
    net = self._model
    for op in net.op:
        arg = op.arg.add()
        arg.name = MaceKeyword.mace_device
        arg.i = self._option.device
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
return False
...@@ -1188,6 +1167,37 @@ class Transformer(base_converter.ConverterInterface):
        self.sort_dfs(producer_op, visited, sorted_nodes)
    sorted_nodes.append(op)
def add_mace_input_and_output_nodes(self):
if self._input_output_added:
return
print("add mace input and output nodes")
for input_node in self._option.input_nodes.values():
new_input_name = MaceKeyword.mace_input_node_name \
+ '_' + input_node.name
op_def = self._model.op.add()
op_def.name = self.normalize_op_name(input_node.name)
op_def.type = MaceOp.Identity.name
op_def.input.extend([new_input_name])
op_def.output.extend([input_node.name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(input_node.shape)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
op_def = self._model.op.add()
op_def.name = self.normalize_op_name(output_name)
op_def.type = MaceOp.Identity.name
op_def.input.extend([output_node.name])
op_def.output.extend([output_name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
def sort_by_execution(self):
    print("Sort by execution")
    net = self._model
......
...@@ -24,7 +24,8 @@ class MemoryOptimizer(object):
self.op_mem = {}  # op_name->mem_id
self.mem_block = {}  # mem_id->[size] or mem_id->[x, y]
self.total_mem_count = 0
self.input_ref_counter = {}
self.mem_ref_counter = {}
consumers = {}
for op in net_def.op:
...@@ -41,9 +42,10 @@
for output in op.output:
    tensor_name = output
    if tensor_name in consumers:
        self.input_ref_counter[tensor_name] = \
            len(consumers[tensor_name])
    else:
        self.input_ref_counter[tensor_name] = 0
def op_need_optimize_memory(self, op):
    return True
...@@ -93,8 +95,8 @@ class MemoryOptimizer(object):
if not self.op_need_optimize_memory(op):
    continue
if not op.output_shape:
    print("WARNING: There is no output shape information to "
          "do memory optimization. %s (%s)" % (op.name, op.type))
    return
if len(op.output_shape) != len(op.output):
    print('WARNING: the number of output shape is not equal to '
...@@ -146,16 +148,23 @@ class MemoryOptimizer(object):
if mem_id != -1:
    op.mem_id.extend([mem_id])
    self.op_mem[op.output[i]] = mem_id
if mem_id not in self.mem_ref_counter:
self.mem_ref_counter[mem_id] = 1
else:
self.mem_ref_counter[mem_id] += 1
# de-ref input tensor mem
for idx in xrange(len(op.input)):
    ipt = op.input[idx]
    if ipt in self.input_ref_counter:
        self.input_ref_counter[ipt] -= 1
        if self.input_ref_counter[ipt] == 0 \
                and ipt in self.op_mem:
            mem_id = self.op_mem[ipt]
            self.mem_ref_counter[mem_id] -= 1
            if self.mem_ref_counter[mem_id] == 0:
                self.idle_mem.add(self.op_mem[ipt])
    elif self.input_ref_counter[ipt] < 0:
        raise Exception('ref count is less than 0')
self.add_net_mem_blocks()
......
...@@ -54,10 +54,9 @@ std::string MakeString(const Args &... args) {
}

template <typename T>
std::string MakeListString(const T *args, size_t size) {
  std::stringstream ss;
  ss << "[";
const size_t size = args.size();
  for (size_t i = 0; i < size; ++i) {
    ss << args[i];
    if (i < size - 1) {
...@@ -68,6 +67,11 @@ std::string MakeString(const std::vector<T> &args) {
  return ss.str();
}
template <typename T>
std::string MakeString(const std::vector<T> &args) {
return MakeListString(args.data(), args.size());
}
// Specializations for already-a-string types.
template <>
inline std::string MakeString(const std::string &str) {
......
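The refactor above turns MakeString(vector) into a thin wrapper over the new pointer-based MakeListString, which is what SerialNet's latency logger uses to print an op's mem_id list. A hedged usage sketch (assuming the header above is included):

std::vector<int> mem_ids{3, 5, 7};
std::string a = MakeString(mem_ids);
std::string b = MakeListString(mem_ids.data(), mem_ids.size());
// a and b are identical bracketed lists, e.g. "[3, 5, 7]".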
...@@ -136,6 +136,7 @@ class YAMLKeyword(object):
subgraphs = 'subgraphs'
input_tensors = 'input_tensors'
input_shapes = 'input_shapes'
input_ranges = 'input_ranges'
output_tensors = 'output_tensors'
output_shapes = 'output_shapes'
runtime = 'runtime'
...@@ -145,6 +146,7 @@ class YAMLKeyword(object):
obfuscate = 'obfuscate'
winograd = 'winograd'
validation_inputs_data = 'validation_inputs_data'
transformers = 'transformers' # keep it private for now
class ModuleName(object):
...@@ -640,7 +642,8 @@ def convert_model(configs):
model_config[YAMLKeyword.winograd],
model_config[YAMLKeyword.obfuscate],
configs[YAMLKeyword.build_type],
data_type,
",".join(model_config.get(YAMLKeyword.transformers, [])))
if configs[YAMLKeyword.build_type] == BuildType.proto:
    sh.mv("-f",
...@@ -732,7 +735,8 @@ def build_specific_lib(target_abi, target_soc, serial_num,
model_output_dir,
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
device_type = parse_device_type(RuntimeType.gpu)
sh_commands.tuning_run(
...@@ -975,7 +979,8 @@ def run_specific_target(flags, configs, target_abi,
model_output_dir,
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
runtime_list = []
if target_abi == ABIType.host:
    runtime_list.extend([RuntimeType.cpu])
...@@ -1123,7 +1128,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
model_output_dir,
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None))
runtime_list = []
if target_abi == ABIType.host:
    runtime_list.extend([RuntimeType.cpu])
......
...@@ -23,25 +23,34 @@ import common
# python generate_data.py \
#     --input_node input_node \
#     --input_shape 1,64,64,3 \
#     --input_file input_file \
#     --input_ranges -1,1
def generate_data(name, shape, input_file, tensor_range):
    np.random.seed()
    data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \
        + tensor_range[0]
    input_file_name = common.formatted_file_name(input_file, name)
    print 'Generate input file: ', input_file_name
    data.astype(np.float32).tofile(input_file_name)
def generate_input_data(input_file, input_node, input_shape, input_ranges):
    input_names = [name for name in input_node.split(',')]
    input_shapes = [shape for shape in input_shape.split(':')]
    if input_ranges:
        input_ranges = [r for r in input_ranges.split(':')]
    else:
        input_ranges = None
    assert len(input_names) == len(input_shapes)
    for i in range(len(input_names)):
        shape = [int(x) for x in input_shapes[i].split(',')]
        if input_ranges:
            input_range = [float(x) for x in input_ranges[i].split(',')]
        else:
            input_range = [-1, 1]
        generate_data(input_names[i], shape, input_file, input_range)
print "Generate input file done." print "Generate input file done."
...@@ -55,10 +64,13 @@ def parse_args():
    "--input_node", type=str, default="input_node", help="input node")
parser.add_argument(
    "--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument(
"--input_ranges", type=str, default="-1,1", help="input range.")
return parser.parse_known_args()

if __name__ == '__main__':
    FLAGS, unparsed = parse_args()
    generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape,
                        FLAGS.input_ranges)