Commit 8911cffb authored by liutuo

add splice, affine, pnorm, sumgroup ops for kaldi models

support converting kaldi-onnx models
support validating output data from a URL

Parent: f2f05c0d
@@ -258,11 +258,12 @@ MaceStatus Gemv::Compute(const OpContext *context,
         ++rhs_ptr;
       }
+      float32x4_t vbias = vdupq_n_f32(0);
       if (bias) {
-        float32x4_t vbias = vdupq_n_f32(0);
         vbias = vld1q_f32(bias_data + h_start);
-        vo = vaddq_f32(vo, vbias);
       }
+      vo = vaddq_f32(vo, vbias);
       vst1q_f32(ret_ptr, vo);
     } else {  // h_block_len < 4
 #endif  // MACE_GEMV_UNROLL
...
@@ -376,11 +376,11 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
         ++rhs_ptr;
       }
+      int32x4_t vbias = vdupq_n_s32(0);
       if (bias) {
-        int32x4_t vbias = vdupq_n_s32(0);
         vbias = vld1q_s32(bias_data + h_offset);
-        vo = vaddq_s32(vo, vbias);
       }
+      vo = vaddq_s32(vo, vbias);
       if (is_output_type_uint8) {
         int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
...
@@ -25,6 +25,7 @@ namespace ops {
 template <DeviceType D, class T>
 class LSTMCellOp;
+#ifdef MACE_ENABLE_OPENCL
 template <typename T>
 class LSTMCellOp<DeviceType::GPU, T> : public Operation {
  public:
@@ -88,6 +89,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation {
   MACE_OP_INPUT_TAGS(INPUT, PRE_OUTPUT, WEIGHT, BIAS, PRE_CELL);
   MACE_OP_OUTPUT_TAGS(CELL, OUTPUT);
 };
+#endif
 void RegisterLSTMCell(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "LSTMCell", LSTMCellOp,
...
@@ -43,6 +43,7 @@ extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
 extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
 extern void RegisterMatMul(OpRegistryBase *op_registry);
 extern void RegisterPad(OpRegistryBase *op_registry);
+extern void RegisterPNorm(OpRegistryBase *op_registry);
 extern void RegisterPooling(OpRegistryBase *op_registry);
 extern void RegisterReduce(OpRegistryBase *op_registry);
 extern void RegisterPriorBox(OpRegistryBase *op_registry);
@@ -53,14 +54,19 @@ extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry);
 extern void RegisterReverse(OpRegistryBase *op_registry);
 extern void RegisterScalarMath(OpRegistryBase *op_registry);
 extern void RegisterShape(OpRegistryBase *op_registry);
+extern void RegisterSlice(OpRegistryBase *op_registry);
 extern void RegisterSoftmax(OpRegistryBase *op_registry);
 extern void RegisterSpaceToBatchND(OpRegistryBase *op_registry);
 extern void RegisterSpaceToDepth(OpRegistryBase *op_registry);
+extern void RegisterSplice(OpRegistryBase *op_registry);
 extern void RegisterSplit(OpRegistryBase *op_registry);
 extern void RegisterSqrDiffMean(OpRegistryBase *op_registry);
 extern void RegisterSqueeze(OpRegistryBase *op_registry);
 extern void RegisterStack(OpRegistryBase *op_registry);
 extern void RegisterStridedSlice(OpRegistryBase *op_registry);
+extern void RegisterSumGroup(OpRegistryBase *op_registry);
+extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
+extern void RegisterTimeOffset(OpRegistryBase *op_registry);
 extern void RegisterTranspose(OpRegistryBase *op_registry);
 extern void RegisterUnstack(OpRegistryBase *op_registry);
@@ -103,6 +109,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterLocalResponseNorm(this);
   ops::RegisterMatMul(this);
   ops::RegisterPad(this);
+  ops::RegisterPNorm(this);
   ops::RegisterPooling(this);
   ops::RegisterReduce(this);
   ops::RegisterPriorBox(this);
@@ -113,14 +120,19 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterReverse(this);
   ops::RegisterScalarMath(this);
   ops::RegisterShape(this);
+  ops::RegisterSlice(this);
   ops::RegisterSoftmax(this);
   ops::RegisterSpaceToBatchND(this);
   ops::RegisterSpaceToDepth(this);
+  ops::RegisterSplice(this);
   ops::RegisterSplit(this);
   ops::RegisterStack(this);
   ops::RegisterStridedSlice(this);
   ops::RegisterSqrDiffMean(this);
   ops::RegisterSqueeze(this);
+  ops::RegisterSumGroup(this);
+  ops::RegisterTargetRMSNorm(this);
+  ops::RegisterTimeOffset(this);
   ops::RegisterTranspose(this);
   ops::RegisterUnstack(this);
...
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for PNormComponent in Kaldi.
// Each input row is divided into output-dim groups,
// so input-dim must be divisible by output-dim.
// For each row, with group = input-dim / output-dim:
// p is 0: output[i] = count of j where abs(input[i * group + j]) > 0
// p is 1: output[i] = sum(abs(input[i * group + j]))
// p is 2: output[i] = sqrt(sum(input[i * group + j] * input[i * group + j])),
// for j = (0 : group - 1)
// p's default value is 2.
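// For example, the row [1, 2, 3, 4] with output_dim = 2 forms the groups
// [1, 2] and [3, 4]; with p = 2 the output row is
// [sqrt(1 + 4), sqrt(9 + 16)] = [2.236..., 5].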
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class PNormOp;
template <typename T>
class PNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit PNormOp(OpConstructContext *context)
: Operation(context),
p_(Operation::GetOptionalArg<int>("p", 2)),
output_dim_(Operation::GetOptionalArg<int>("output_dim", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input_shape.size();
MACE_CHECK(dim_size >= 1, "PNorm only supports input dim size >= 1");
std::vector<index_t> output_shape(input_shape);
const index_t input_dim = input_shape[dim_size - 1];
MACE_CHECK(output_dim_ > 0,
"Output dim should be greater than zero.");
MACE_CHECK(input_dim % output_dim_ == 0 && output_dim_ < input_dim,
"PNorm's input dim should be a multiple of, and greater than, its output dim.");
const index_t group_size = input_dim / output_dim_;
output_shape[dim_size - 1] = output_dim_;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
const index_t bh =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
if (p_ == 0) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
T value =
(std::fabs(in_base[g])
> std::numeric_limits<float>::epsilon()) ? 1.0f : 0.0f;
temp_result += value;
}
out_base[j] = temp_result;
}
}
} else if (p_ == 1) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
temp_result += std::abs(in_base[g]);
}
out_base[j] = temp_result;
}
}
} else if (p_ == 2) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim_; ++j) {
const T *in_base = input_data + i * input_dim + j * group_size;
T *out_base = output_data + i * output_dim_;
T temp_result = 0;
for (index_t g = 0; g < group_size; ++g) {
temp_result += in_base[g] * in_base[g];
}
out_base[j] = std::sqrt(temp_result);
}
}
} else {
LOG(FATAL) << "PNorm's p should be 0, 1 or 2, here p is: " << p_;
}
return MaceStatus::MACE_SUCCESS;
}
private:
int p_;
int output_dim_;
};
void RegisterPNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "PNorm", PNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void PNormBenchmark(int iters, int n, int h, int w, int p, int ow) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
OpDefBuilder("PNorm", "PNormBM")
.Input("Input")
.AddIntArg("p", p)
.AddIntArg("output_dim", ow)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_PNORM_MACRO(N, H, W, P, OW, TYPE, DEVICE) \
static void \
MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
PNormBenchmark<DEVICE, TYPE>(iters, N, H, W, P, OW); \
} \
MACE_BENCHMARK( \
MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE)
#define MACE_BM_PNORM(N, H, W, P, OW) \
MACE_BM_PNORM_MACRO(N, H, W, P, OW, float, CPU);
MACE_BM_PNORM(1, 10, 256, 0, 128);
MACE_BM_PNORM(1, 20, 128, 1, 64);
MACE_BM_PNORM(1, 10, 128, 2, 64);
MACE_BM_PNORM(1, 16, 256, 0, 128);
MACE_BM_PNORM(1, 32, 128, 1, 64);
MACE_BM_PNORM(1, 10, 512, 2, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class PNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestPNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int p,
const int output_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("PNorm", "PNormTest")
.Input("Input")
.AddIntArg("p", p)
.AddIntArg("output_dim", output_dim)
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(PNormOpTest, SimpleTest) {
TestPNorm<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
2, 5,
{1, 5, 5},
{2.236067977, 5, 7.810249676, 10.630145813, 13.453624047,
5, 7.810249676, 10.630145813, 13.453624047, 16.278820596,
7.810249676, 10.630145813, 13.453624047, 16.278820596, 19.104973175,
10.630145813, 13.453624047, 16.278820596, 19.104973175, 21.931712199,
13.453624047, 16.278820596, 19.104973175, 21.931712199, 24.758836806});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
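// This Op slices the last dimension of the input into the
// range [starts[0], ends[0]). Only slicing along a single axis,
// the last one, is supported.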
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SliceOp;
template <typename T>
class SliceOp<DeviceType::CPU, T> : public Operation {
public:
explicit SliceOp(OpConstructContext *context)
: Operation(context),
axes_(Operation::GetRepeatedArgs<int>("axes")),
starts_(Operation::GetRepeatedArgs<int>("starts")),
ends_(Operation::GetRepeatedArgs<int>("ends")) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const index_t rank = input->dim_size();
MACE_CHECK(rank >= 1)
<< "The input's rank should be >= 1.";
MACE_CHECK(starts_.size() == 1 && ends_.size() == 1 && axes_.size() == 1,
"Only slicing along a single axis is supported.");
MACE_CHECK(axes_[0] == -1 || axes_[0] == rank - 1,
"Only slicing along the last axis is supported.");
const index_t input_dim = input->dim(rank - 1);
const index_t offset = starts_[0];
const index_t output_dim = ends_[0] - starts_[0];
MACE_CHECK(output_dim >= 0, "output_dim should be >= 0.");
MACE_CHECK(starts_[0] < input_dim
&& output_dim <= input_dim
&& ends_[0] <= input_dim)
<< "The given starts and ends are out of the input's range.";
const index_t frames =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
std::vector<index_t> output_shape = input->shape();
output_shape[rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
const T *input_base =
input_data + i * input_dim + offset;
T *output_base =
output_data + i * output_dim;
memcpy(output_base, input_base, output_dim * sizeof(T));
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> axes_;
std::vector<int> starts_;
std::vector<int> ends_;
};
void RegisterSlice(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Slice", SliceOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SliceOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int offset,
const int output_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("Slice", "SliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("axes", {-1})
.AddIntsArg("starts", {offset})
.AddIntsArg("ends", {offset + output_dim})
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SliceOpTest, Simple2Dim) {
TestSlice<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2, 3, {3, 3},
{3, 4, 5, 8, 9, 10, 13, 14, 15});
}
TEST_F(SliceOpTest, Simple3Dim) {
TestSlice<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1, 2, {2, 3, 2},
{2, 3, 7, 8, 12, 13, 2, 3, 7, 8, 12, 13});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for SpliceComponent in Kaldi.
// It splices a context window of frames together over time
// (copies and appends the frames whose time-indices are in context_).
// The context_ values indicate which frames (over time) to splice.
// If a context value is less than the first time-index,
// the first frame's data is copied and appended;
// if a context value is larger than the last time-index,
// the last frame's data is copied and appended.
// E.g., given input data [[1, 2, 3], [4, 5, 6]],
// with input-dim = 3, frame count = 2 and context = [-1, 0, 1],
// the output should be:
// [1, 2, 3, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 4, 5, 6]
// If const_component_dim_ != 0, const_dim_ will be used to determine which
// row of "in" we copy the last part of each row of "out" from (this part is
// not subject to splicing; it's assumed constant for each frame of "input").
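// For example (values as in the unit test below), with input-dim = 10,
// const_component_dim = 7 and context = [-2, -1, 0, 1, 2], each output row
// holds five spliced 3-dim chunks followed by the 7 constant dims, which are
// copied from the row at time (t + context[0]), clamped to the first frame.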
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SpliceOp;
template <typename T>
class SpliceOp<DeviceType::CPU, T> : public Operation {
public:
explicit SpliceOp(OpConstructContext *context)
: Operation(context),
context_(Operation::GetRepeatedArgs<int>("context")),
const_dim_(
Operation::GetOptionalArg<int>("const_component_dim", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
MACE_CHECK(context_.size() > 0)
<< "The context param should not be empty in Splice Op.";
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t frames =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
std::multiplies<index_t>());
const index_t rank = input->dim_size();
const index_t input_dim = input_shape[rank - 1];
const index_t num_splice = static_cast<index_t>(context_.size());
const index_t dim = input_dim - const_dim_;
MACE_CHECK(input_dim > const_dim_,
"input dim should be greater than const dim.");
const index_t output_dim = dim * num_splice + const_dim_;
std::vector<index_t> output_shape = input->shape();
output_shape[rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
for (index_t c = 0; c < num_splice; ++c) {
const index_t offset =
Clamp<index_t>(context_[c] + i, 0, frames - 1);
T *output_base = output_data + i * output_dim + c * dim;
const T *input_base = input_data + offset * input_dim;
memcpy(output_base, input_base, dim * sizeof(T));
}
}
if (const_dim_ > 0) {
const index_t output_offset = output_dim - const_dim_;
const index_t input_offset = dim;
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < frames; ++i) {
index_t offset = (i + context_[0] >= 0) ? i + context_[0] : 0;
T *output_base = output_data + i * output_dim;
const T *input_base = input_data + offset * input_dim;
memcpy(output_base + output_offset,
input_base + input_offset,
const_dim_ * sizeof(T));
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> context_;
int const_dim_;
};
void RegisterSplice(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Splice", SpliceOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void BMSpliceHelper(int iters,
const std::vector<index_t> &input_shape,
const index_t left_context,
const index_t right_context,
const int const_component_dim) {
mace::testing::StopTiming();
// Construct graph
OpsTestNet net;
const int num_splice = left_context + right_context + 1;
std::vector<int> contexts(num_splice);
for (int i = 0; i < num_splice; ++i) {
contexts[i] = -left_context + i;  // window spans [-left_context, right_context]
}
const index_t input_size = std::accumulate(input_shape.begin(),
input_shape.end(),
1,
std::multiplies<index_t>());
std::vector<float> input_data(input_size);
GenerateRandomRealTypeData(input_shape, &input_data);
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
OpDefBuilder("Splice", "SpliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("context", contexts)
.AddIntArg("const_component_dim", const_component_dim)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, TYPE, DEVICE) \
static void \
MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::MacsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
BMSpliceHelper<DEVICE, TYPE>(iters, {N, H, W}, L, R, C); \
} \
MACE_BENCHMARK( \
MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE)
#define MACE_BM_SPLICE(N, H, W, L, R, C) \
MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, float, CPU);
MACE_BM_SPLICE(1, 32, 32, 5, 5, 10);
MACE_BM_SPLICE(1, 32, 32, 7, 7, 5);
MACE_BM_SPLICE(1, 32, 32, 3, 3, 20);
MACE_BM_SPLICE(1, 128, 128, 9, 9, 100);
MACE_BM_SPLICE(1, 128, 128, 7, 7, 100);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SpliceOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSplice(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const std::vector<int> &context,
const int const_dim,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("Splice", "SpliceTest")
.Input("Input")
.Output("Output")
.AddIntsArg("context", context)
.AddIntArg("const_component_dim", const_dim)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SpliceOpTest, WithoutConstDim) {
TestSplice<DeviceType::CPU, float>(
{1, 7, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 0,
{1, 7, 10},
{1, 2, 1, 2, 1, 2, 3, 4, 5, 6,
1, 2, 1, 2, 3, 4, 5, 6, 7, 8,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
7, 8, 9, 10, 11, 12, 13, 14, 13, 14,
9, 10, 11, 12, 13, 14, 13, 14, 13, 14});
}
TEST_F(SpliceOpTest, WithConstDim) {
TestSplice<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 7,
{1, 5, 22},
{1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for SumGroupComponent in Kaldi.
// It's used to sum up groups of posteriors,
// and to introduce a kind of Gaussian-mixture-model-like
// idea into neural nets.
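// For example, with sizes = [2, 1, 2], the input row [1, 2, 3, 4, 5]
// is summed group-wise into [1 + 2, 3, 4 + 5] = [3, 3, 9].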
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SumGroupOp;
template <typename T>
class SumGroupOp<DeviceType::CPU, T> : public Operation {
public:
explicit SumGroupOp(OpConstructContext *context)
: Operation(context) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
MACE_CHECK(this->InputSize() >= 2,
"SumGroup should have at least 2 inputs.");
const Tensor *input = this->Input(0);
// The second input ("Sizes") is a vector giving,
// for each output dim, how many consecutive
// input dims are summed over.
const Tensor *sizes = this->Input(1);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() >= 1,
"SumGroup's input's rank should be >= 1.");
MACE_CHECK(sizes->dim_size() == 1,
"SumGroup's sizes input should be a vector.");
const std::vector<index_t> &input_shape = input->shape();
const index_t bh =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
std::vector<index_t> output_shape(input_shape);
const index_t output_dim = sizes->dim(0);
const index_t dim_size = input->dim_size();
const index_t input_dim = input_shape[dim_size - 1];
output_shape[dim_size - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_sizes(sizes);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
const int *sizes_data = sizes->data<int>();
T *output_data = output->mutable_data<T>();
std::vector<std::pair<int, int>>
sum_indexes(static_cast<size_t>(output_dim));
int cur_index = 0;
for (index_t i = 0; i < output_dim; ++i) {
int size_value = sizes_data[i];
MACE_CHECK(size_value > 0, "size value should be > 0");
sum_indexes[i].first = cur_index;
cur_index += size_value;
sum_indexes[i].second = cur_index;
MACE_CHECK(cur_index <= input_dim)
<< "The sum of size values exceeds the input dim: "
<< cur_index << " > " << input_dim;
}
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim; ++j) {
int start_col = sum_indexes[j].first;
int end_col = sum_indexes[j].second;
T sum = 0;
for (int src_col = start_col; src_col < end_col; ++src_col) {
sum += input_data[i * input_dim + src_col];
}
output_data[i * output_dim + j] = sum;
}
}
return MaceStatus::MACE_SUCCESS;
}
};
void RegisterSumGroup(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "SumGroup", SumGroupOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void SumGroupBenchmark(int iters, int n, int h, int w) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
net.AddRepeatedInput<D, int>("Sizes",
{w / 2},
2);
OpDefBuilder("SumGroup", "SumGroupBM")
.Input("Input")
.Input("Sizes")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_SUMGROUP_MACRO(N, H, W, TYPE, DEVICE) \
static void \
MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
SumGroupBenchmark<DEVICE, TYPE>(iters, N, H, W); \
} \
MACE_BENCHMARK( \
MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_SUMGROUP(N, H, W) \
MACE_BM_SUMGROUP_MACRO(N, H, W, float, CPU);
MACE_BM_SUMGROUP(1, 10, 256);
MACE_BM_SUMGROUP(1, 20, 128);
MACE_BM_SUMGROUP(1, 10, 128);
MACE_BM_SUMGROUP(1, 20, 512);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SumGroupOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestSumGroup(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const std::vector<int> &sizes,
const std::vector<index_t> &output_shape,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
const index_t output_dim = sizes.size();
net.AddInputFromArray<CPU, int>(MakeString("Sizes"),
{output_dim},
sizes);
OpDefBuilder("SumGroup", "SumGroupTest")
.Input("Input")
.Input("Sizes")
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", output_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(SumGroupOpTest, SimpleTest) {
TestSumGroup<DeviceType::CPU, float>(
{1, 5, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{2, 1, 2, 3, 2},
{1, 5, 5},
{3, 3, 9, 21, 19,
5, 4, 11, 24, 21,
7, 5, 13, 27, 23,
9, 6, 15, 30, 25,
11, 7, 17, 33, 27});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This op is implemented for Kaldi's NormalizeComponent.
// The output is y_i = scale * x_i,
// and we want the RMS value of the y_i to equal target_rms,
// so y^t y = Dim * target_rms^2 (if y is one row of the input).
// Dim is the length of a row.
// Hence the scale = 1.0 / sqrt(x^t x / (Dim * target_rms^2)).
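// For example, for the row x = [1, 2, 3] with target_rms = 1:
// x^t x = 14, Dim = 3, so scale = 1 / sqrt(14 / 3) ~= 0.46291
// and y ~= [0.46291, 0.92582, 1.38873].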
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class TargetRMSNormOp;
template <typename T>
class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit TargetRMSNormOp(OpConstructContext *context)
: Operation(context),
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)) {}
// Calculate the square sum of an array
float SquareSum(const float *data, const index_t data_len) {
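// Split the accumulation into four independent partial sums (presumably
// to let the floating-point additions pipeline); the data_len % 4 tail
// elements are folded into the last partial sum.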
const int num_parts = 4;
float result = 0.0f;
if (data_len <= 2 * num_parts) {
for (index_t i = 0; i < data_len; ++i) {
result += data[i] * data[i];
}
} else {
const index_t part_len = data_len / num_parts;
const index_t left_len = data_len % num_parts;
float results[4] = {0.f, 0.f, 0.f, 0.f};
for (index_t i = 0; i < num_parts; ++i) {
for (index_t j = 0; j < part_len; ++j) {
results[i] += data[i * part_len + j] * data[i * part_len + j];
}
}
for (index_t k = 0; k < left_len; ++k) {
float d = data[num_parts * part_len + k];
results[3] += d * d;
}
for (index_t i = 0; i < num_parts; ++i) {
result += results[i];
}
}
return result;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input->dim_size();
MACE_CHECK(dim_size >= 1,
"TargetRMSNorm's input dim size should be >= 1.");
const index_t dim = input_shape[dim_size - 1];
MACE_CHECK(dim > 0 && target_rms_ > 0,
"Both input dim and target rms should be greater than zero.");
const index_t bh =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
const float d_scale = dim * target_rms_ * target_rms_;
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < bh; ++i) {
float scale = SquareSum(input_data + i * dim, dim);
scale = static_cast<float>(1.0 / std::sqrt(scale / d_scale));
for (index_t j = 0; j < dim; ++j) {
output_data[i * dim + j] = input_data[i * dim + j] * scale;
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
float target_rms_;
};
void RegisterTargetRMSNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "TargetRMSNorm", TargetRMSNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void TargetRMSNormBenchmark(int iters, int n, int h, int w, float target_rms) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {n, h, w});
OpDefBuilder("TargetRMSNorm", "TargetRMSNormBM")
.Input("Input")
.AddFloatArg("target_rms", target_rms)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Sync();
}
}
} // namespace
#define MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, TYPE, DEVICE) \
static void \
MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W; \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
TargetRMSNormBenchmark<DEVICE, TYPE>(iters, N, H, W, RMS); \
} \
MACE_BENCHMARK( \
MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_TARGETRMSNORM(N, H, W, RMS) \
MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, float, CPU);
MACE_BM_TARGETRMSNORM(1, 10, 256, 1.0);
MACE_BM_TARGETRMSNORM(1, 20, 128, 2.0);
MACE_BM_TARGETRMSNORM(1, 10, 128, 0.5);
MACE_BM_TARGETRMSNORM(1, 20, 512, 1.0);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class TargetRMSNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestTargetRMSNorm(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const float target_rms,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("TargetRMSNorm", "TargetRMSNormTest")
.Input("Input")
.AddFloatArg("target_rms", target_rms)
.Output("Output")
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TargetRMSNormOpTest, SimpleTest) {
TestTargetRMSNorm<DeviceType::CPU, float>(
{1, 3, 3},
{1, 2, 3,
2, 3, 4,
3, 4, 5},
1.0,
{0.46291, 0.92582, 1.38873,
0.64327, 0.9649, 1.28654,
0.734847, 0.979796, 1.224745});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for the Offset descriptor in Kaldi.
// It shifts the input along the time dimension,
// clamping out-of-range time-indices to the first or last frame.
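// For example, with offset = -1 the frames [f0, f1, f2] become
// [f0, f0, f1]; with offset = 2 they become [f2, f2, f2].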
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class TimeOffsetOp;
template <typename T>
class TimeOffsetOp<DeviceType::CPU, T> : public Operation {
public:
explicit TimeOffsetOp(OpConstructContext *context)
: Operation(context),
offset_(Operation::GetOptionalArg<int>("offset", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should be >= 2.");
const std::vector<index_t> &input_shape = input->shape();
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
const index_t frames = input_shape[rank - 2];
const index_t input_dim = input_shape[rank - 1];
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < batch; ++i) {
for (index_t j = 0; j < frames; ++j) {
index_t time_index = offset_ + j;
index_t index = Clamp<index_t>(time_index, 0, frames - 1);
T *output_base = output_data + (i * frames + j) * input_dim;
const T *input_base = input_data + (i * frames + index) * input_dim;
memcpy(output_base, input_base, input_dim * sizeof(T));
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int offset_;
};
void RegisterTimeOffset(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "TimeOffset", TimeOffsetOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template<DeviceType D, typename T>
void TimeOffsetBenchmark(int iters,
std::vector<index_t> shape,
int offset) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", shape);
OpDefBuilder("TimeOffset", "TimeOffsetBM")
.Input("Input")
.Output("Output")
.AddIntArg("offset", offset)
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_TIMEOFFSET2D_MACRO(H, W, TYPE, DEVICE) \
static void MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * H * W; \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
TimeOffsetBenchmark<DEVICE, TYPE>(iters, {H, W}, 1); \
} \
MACE_BENCHMARK(MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_TIMEOFFSET2D(H, W) \
MACE_BM_TIMEOFFSET2D_MACRO(H, W, float, CPU);
MACE_BM_TIMEOFFSET2D(20, 128);
MACE_BM_TIMEOFFSET2D(40, 512);
MACE_BM_TIMEOFFSET2D(1, 1024);
MACE_BM_TIMEOFFSET2D(20, 2048);
MACE_BM_TIMEOFFSET2D(20, 512);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class TimeOffsetOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestTimeOffset(const std::vector<index_t> &input_shape,
const std::vector<T> &input,
const int offset,
const std::vector<T> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, T>(MakeString("Input"),
input_shape,
input);
OpDefBuilder("TimeOffset", "TimeOffsetTest")
.Input("Input")
.Output("Output")
.AddIntArg("offset", offset)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, T>("ExpectedOutput", input_shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(TimeOffsetOpTest, Simple2Dim) {
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-2,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-1,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
0,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1,
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2,
{11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
}
TEST_F(TimeOffsetOpTest, Simple3Dim) {
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-2,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
-1,
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
0,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
1,
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
TestTimeOffset<DeviceType::CPU, float>(
{2, 3, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
2,
{11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -98,6 +98,7 @@ class FrameworkType(Enum):
 MaceSupportedOps = [
     'Activation',
     'AddN',
+    'Affine',
     'ArgMax',
     'BatchNorm',
     'BatchToSpaceND',
@@ -121,8 +122,10 @@ MaceSupportedOps = [
     'InferConv2dShape',
     'LocalResponseNorm',
     'LSTMCell',
+    # 'LstmNonlinear',
     'MatMul',
     'Pad',
+    'PNorm',
     'Pooling',
     'PriorBox',
     'Proposal',
@@ -134,6 +137,8 @@ MaceSupportedOps = [
     'ResizeNearestNeighbor',
     'Reverse',
     'ScalarMath',
+    'Slice',
+    'Splice',
     'Split',
     'Shape',
     'Squeeze',
@@ -144,6 +149,9 @@ MaceSupportedOps = [
     'SpaceToBatchND',
     'SpaceToDepth',
     'SqrDiffMean',
+    'SumGroup',
+    'TargetRMSNorm',
+    'TimeOffset',
     'Transpose',
     'WinogradInverseTransform',
     'WinogradTransform',
@@ -159,6 +167,7 @@ class MaceKeyword(object):
     mace_buffer_type = 'buffer_type'
     # arg related str
     mace_padding_str = 'padding'
+    mace_padding_type_str = 'padding'
     mace_padding_values_str = 'padding_values'
     mace_strides_str = 'strides'
     mace_dilations_str = 'dilations'
@@ -473,6 +482,7 @@ class ConverterOption(object):
             # Model data format related transformation
             TransformerRule.TRANSPOSE_FILTERS,
             TransformerRule.TRANSPOSE_DATA_FORMAT,
+            TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
             # Add winograd argument
             TransformerRule.ADD_WINOGRAD_ARG,
             # Mace model structure related transformation
...
...@@ -33,21 +33,23 @@ from mace.python.tools.converter_tool.base_converter import MaceKeyword ...@@ -33,21 +33,23 @@ from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check from mace.python.tools.convert_util import mace_check
import numpy as np
import onnx import onnx
import onnx.utils import onnx.utils
from onnx import helper, shape_inference, numpy_helper, optimizer from onnx import mapping, numpy_helper, TensorProto
import numpy as np
from onnx import mapping
from onnx import TensorProto
from numbers import Number from numbers import Number
IS_PYTHON3 = sys.version_info > (3,)
OnnxSupportedOps = [ OnnxSupportedOps = [
'Abs', 'Abs',
# 'Acos', # 'Acos',
# 'Acosh', # 'Acosh',
'Add', 'Add',
'Affine',
# 'And', # 'And',
'Append',
'ArgMax', 'ArgMax',
'ArgMin', 'ArgMin',
# 'Asin', # 'Asin',
...@@ -68,6 +70,7 @@ OnnxSupportedOps = [ ...@@ -68,6 +70,7 @@ OnnxSupportedOps = [
# 'Cos', # 'Cos',
# 'Cosh', # 'Cosh',
'DepthToSpace', 'DepthToSpace',
'DimRange',
'Div', 'Div',
'Dropout', 'Dropout',
'Elu', 'Elu',
...@@ -88,10 +91,12 @@ OnnxSupportedOps = [ ...@@ -88,10 +91,12 @@ OnnxSupportedOps = [
# 'Hardmax', # 'Hardmax',
'Identity', 'Identity',
# 'If', # 'If',
'IfDefined',
'ImageScaler', 'ImageScaler',
# 'InstanceNormalization', # 'InstanceNormalization',
# 'LRN', # 'LRN',
# 'LSTM', 'LSTM',
# 'LstmNonlinear',
'LeakyRelu', 'LeakyRelu',
# 'Less', # 'Less',
# 'Log', # 'Log',
...@@ -109,11 +114,15 @@ OnnxSupportedOps = [ ...@@ -109,11 +114,15 @@ OnnxSupportedOps = [
'Mul', 'Mul',
# 'Multinomial', # 'Multinomial',
'Neg', 'Neg',
'Normalize',
# 'Not', # 'Not',
'Offset',
# 'OneHot', # 'OneHot',
# 'Or', # 'Or',
'PRelu', 'PRelu',
'Pad', # 'Pad',
'Padding',
'PNorm',
'Pow', 'Pow',
# 'RNN', # 'RNN',
# 'RandomNormal', # 'RandomNormal',
...@@ -133,6 +142,7 @@ OnnxSupportedOps = [ ...@@ -133,6 +142,7 @@ OnnxSupportedOps = [
# 'ReduceSumSquare', # 'ReduceSumSquare',
'Relu', 'Relu',
'Reshape', 'Reshape',
'Scale',
# 'Scan', # 'Scan',
# 'Selu', # 'Selu',
'Shape', 'Shape',
...@@ -140,18 +150,21 @@ OnnxSupportedOps = [ ...@@ -140,18 +150,21 @@ OnnxSupportedOps = [
# 'Sin', # 'Sin',
# 'Sinh', # 'Sinh',
# 'Size', # 'Size',
# 'Slice', 'Slice',
'Softmax', 'Softmax',
# 'Softplus', # 'Softplus',
# 'Softsign', # 'Softsign',
'SpaceToDepth', 'SpaceToDepth',
'Splice',
'Split', 'Split',
'Sqrt', 'Sqrt',
'Squeeze', 'Squeeze',
'Sub', 'Sub',
'Sum', 'Sum',
'SumGroup',
# 'Tan', # 'Tan',
'Tanh', 'Tanh',
'TargetRMSNorm',
# 'Tile', # 'Tile',
# 'TopK', # 'TopK',
'Transpose', 'Transpose',
@@ -188,7 +201,7 @@ def convert_onnx_attribute_proto(attr_proto):
         return attr_proto.i
     elif attr_proto.HasField('s'):
         return str(attr_proto.s, 'utf-8')\
-            if sys.version_info.major == 3 else attr_proto.s
+            if IS_PYTHON3 else attr_proto.s
     elif attr_proto.HasField('t'):
         return attr_proto.t  # this is a proto!
     elif attr_proto.floats:
@@ -273,6 +286,7 @@ class OnnxConverter(base_converter.ConverterInterface):
         OnnxOpType.Equal.name: EltwiseType.EQUAL,
         OnnxOpType.Sqrt.name: EltwiseType.POW,
         OnnxOpType.Reciprocal.name: EltwiseType.POW,
+        OnnxOpType.Scale.name: EltwiseType.PROD,
     }

     reduce_type = {
@@ -296,6 +310,8 @@ class OnnxConverter(base_converter.ConverterInterface):
         self._op_converters = {
             OnnxOpType.Abs.name: self.convert_eltwise,
             OnnxOpType.Add.name: self.convert_eltwise,
+            OnnxOpType.Affine.name: self.convert_affine,
+            OnnxOpType.Append.name: self.convert_concat,
             OnnxOpType.ArgMax.name: self.convert_argmax,
             OnnxOpType.ArgMin.name: self.convert_argmax,
             OnnxOpType.AveragePool.name: self.convert_pooling,
@@ -306,6 +322,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.ConvTranspose.name: self.convert_deconv,
             OnnxOpType.DepthToSpace.name: self.convert_depth_space,
             OnnxOpType.Dropout.name: self.convert_identity,
+            OnnxOpType.DimRange.name: self.convert_dim_range,
             OnnxOpType.Div.name: self.convert_eltwise,
             OnnxOpType.Equal.name: self.convert_eltwise,
             OnnxOpType.Gather.name: self.convert_gather,
@@ -313,47 +330,71 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
             OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
             OnnxOpType.Identity.name: self.convert_identity,
+            OnnxOpType.IfDefined.name: self.convert_identity,
             OnnxOpType.ImageScaler.name: self.convert_imagescaler,
             OnnxOpType.LeakyRelu.name: self.convert_activation,
+            # OnnxOpType.LogSoftmax.name: self.convert_softmax,
+            OnnxOpType.LSTM.name: self.convert_lstm,
+            # OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
             OnnxOpType.Max.name: self.convert_eltwise,
             OnnxOpType.MaxPool.name: self.convert_pooling,
             OnnxOpType.MatMul.name: self.convert_matmul,
             OnnxOpType.Min.name: self.convert_eltwise,
             OnnxOpType.Mul.name: self.convert_eltwise,
             OnnxOpType.Neg.name: self.convert_eltwise,
-            OnnxOpType.Pad.name: self.convert_pad,
+            OnnxOpType.Normalize.name: self.convert_normalize,
+            OnnxOpType.Offset.name: self.convert_timeoffset,
+            OnnxOpType.Padding.name: self.convert_identity,
+            OnnxOpType.PNorm.name: self.convert_pnorm,
             OnnxOpType.Pow.name: self.convert_eltwise,
             OnnxOpType.PRelu.name: self.convert_activation,
             OnnxOpType.Relu.name: self.convert_activation,
             OnnxOpType.Reshape.name: self.convert_reshape,
             OnnxOpType.Reciprocal.name: self.convert_eltwise,
+            OnnxOpType.Scale.name: self.convert_eltwise,
             OnnxOpType.Sigmoid.name: self.convert_activation,
+            OnnxOpType.Slice.name: self.convert_slice,
             OnnxOpType.Softmax.name: self.convert_softmax,
             OnnxOpType.SpaceToDepth.name: self.convert_depth_space,
+            OnnxOpType.Splice.name: self.convert_splice,
             OnnxOpType.Split.name: self.convert_split,
             OnnxOpType.Sqrt.name: self.convert_eltwise,
             OnnxOpType.Squeeze.name: self.convert_squeeze,
             OnnxOpType.Sub.name: self.convert_eltwise,
             OnnxOpType.Sum.name: self.convert_eltwise,
+            OnnxOpType.SumGroup.name: self.convert_sum_group,
             OnnxOpType.Tanh.name: self.convert_activation,
+            OnnxOpType.TargetRMSNorm.name: self.convert_target_rms_norm,
             OnnxOpType.Transpose.name: self.convert_transpose,
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
+        self._data_format = DataFormat.NCHW
         ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW)
         onnx_model = onnx.load(src_model_file)
-        polished_model = onnx.utils.polish_model(onnx_model)
+        ir_version = onnx_model.ir_version
+        opset_imp = onnx_model.opset_import

-        print "onnx model IR version: ", onnx_model.ir_version
-        print "onnx model opset import: ", onnx_model.opset_import
+        polish_available = True
+        print "onnx model IR version: ", ir_version
+        for imp in opset_imp:
+            domain = imp.domain
+            version = imp.version
+            print "contains ops domain: ", domain, "version:", version
+            if 'kaldi2onnx' in domain:
+                polish_available = False
+                self._data_format = DataFormat.DF_NONE
+        if polish_available:
+            onnx_model = onnx.utils.polish_model(onnx_model)
+        self._onnx_model = onnx_model

-        self._onnx_model = shape_inference.infer_shapes(polished_model)
         self._graph_shapes_dict = {}
         self._consts = {}
         self._replace_tensors = {}

-    def print_graph_info(self, graph):
+    @staticmethod
+    def print_graph_info(graph):
         for value_info in graph.value_info:
             print "value info:", value_info
         for value_info in graph.input:
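The constructor change above is the heart of the Kaldi support: models exported through a kaldi2onnx tool mark their custom ops with a dedicated opset domain, and the converter uses that marker to skip `onnx.utils.polish_model` (whose optimizer and shape inference do not understand the custom ops) and to drop the NCHW data-format assumption. A minimal sketch, not from this commit, of a model carrying such a domain, built with the standard `onnx.helper` API:

```python
import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 40])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 40])
node = helper.make_node('Identity', ['x'], ['y'])
graph = helper.make_graph([node], 'kaldi_like_graph', [x], [y])
# the 'kaldi2onnx' domain below is exactly what the constructor keys on
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('kaldi2onnx', 1),
                          helper.make_opsetid('', 9)])

for imp in model.opset_import:
    print('domain: %s, version: %d' % (imp.domain, imp.version))
```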
@@ -368,12 +409,12 @@ class OnnxConverter(base_converter.ConverterInterface):
             if t:
                 shape_dict[value_info.name] = t

-        for value_info in graph.value_info:
-            extract_value_info(self._graph_shapes_dict, value_info)
-        for value_info in graph.input:
-            extract_value_info(self._graph_shapes_dict, value_info)
-        for value_info in graph.output:
-            extract_value_info(self._graph_shapes_dict, value_info)
+        for vi in graph.value_info:
+            extract_value_info(self._graph_shapes_dict, vi)
+        for vi in graph.input:
+            extract_value_info(self._graph_shapes_dict, vi)
+        for vi in graph.output:
+            extract_value_info(self._graph_shapes_dict, vi)

     def add_tensor(self, name, shape, data_type, value):
         tensor = self._mace_net_def.tensors.add()
@@ -387,11 +428,6 @@ class OnnxConverter(base_converter.ConverterInterface):
         self.extract_shape_info(graph_def)
         self.convert_tensors(graph_def)
         self.convert_ops(graph_def)
-        # self.print_graph_info(graph_def)
-        # shape_inferer = mace_shape_inference.ShapeInference(
-        #     self._mace_net_def,
-        #     self._option.input_nodes.values())
-        # shape_inferer.run()
         return self._mace_net_def

     def add_stride_pad_kernel_arg(self, attrs, op_def):
@@ -435,6 +471,32 @@ class OnnxConverter(base_converter.ConverterInterface):
             padding_arg.name = MaceKeyword.mace_padding_values_str
             padding_arg.ints.extend(pad)

+    def remove_node(self, node):
+        input_name = node.inputs[0]
+        output_name = node.outputs[0]
+        self._replace_tensors[output_name] = input_name
+
+    @staticmethod
+    def squeeze_shape(shape, axis):
+        new_shape = []
+        if len(axis) > 0:
+            for i in range(len(shape)):
+                if i not in axis:
+                    new_shape.append(shape[i])
+        else:
+            new_shape = shape
+        return new_shape
+
+    @staticmethod
+    def transpose_const(tensor):
+        shape = tensor.dims
+        mace_check(len(shape) == 2, "gemm only supports 2-dim input.")
+        tensor_data = np.array(tensor.float_data).reshape(
+            shape[0], shape[1])
+        tensor_data = tensor_data.transpose(1, 0)
+        tensor.float_data[:] = tensor_data.flat
+        tensor.dims[:] = tensor_data.shape
+
     def convert_ops(self, graph_def):
         for n in graph_def.node:
             node = OnnxNode(n)
@@ -471,7 +533,7 @@ class OnnxConverter(base_converter.ConverterInterface):
                        "Not supported tensor type: %s" % data_type)
         self._consts[tensor.name] = tensor

-    def convert_general_op(self, node):
+    def convert_general_op(self, node, with_shape=True):
         op = self._mace_net_def.op.add()
         op.name = node.name
@@ -481,9 +543,11 @@ class OnnxConverter(base_converter.ConverterInterface):
             op.input.append(input)
         for output in node.outputs:
             op.output.append(output)
-            output_shape = op.output_shape.add()
-            shape_info = self._graph_shapes_dict[output]
-            output_shape.dims.extend(shape_info)
+            if with_shape:
+                if output in self._graph_shapes_dict:
+                    output_shape = op.output_shape.add()
+                    shape_info = self._graph_shapes_dict[output]
+                    output_shape.dims.extend(shape_info)

         data_type_arg = op.arg.add()
         data_type_arg.name = 'T'
@@ -493,91 +557,9 @@ class OnnxConverter(base_converter.ConverterInterface):
         framework_type_arg.name = MaceKeyword.mace_framework_type_str
         framework_type_arg.i = FrameworkType.ONNX.value

-        ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
+        ConverterUtil.add_data_format_arg(op, self._data_format)

         return op

-    def convert_fused_batchnorm(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.BatchNorm.name
-
-        if "epsilon" in node.attrs:
-            epsilon_value = node.attrs["epsilon"]
-        else:
-            epsilon_value = 1e-5
-
-        mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.")
-
-        gamma_value = np.array(self._consts[node.inputs[1]].float_data)
-        beta_value = np.array(self._consts[node.inputs[2]].float_data)
-        mean_value = np.array(self._consts[node.inputs[3]].float_data)
-        var_value = np.array(self._consts[node.inputs[4]].float_data)
-
-        scale_name = node.name + 'scale'
-        offset_name = node.name + 'offset'
-        scale_value = (
-            (1.0 / np.sqrt(
-                var_value + epsilon_value)) * gamma_value)
-        offset_value = (-mean_value * scale_value) + beta_value
-        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
-                        scale_value)
-        self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
-                        offset_value)
-
-        del op.input[1:]
-        op.input.extend([scale_name, offset_name])
-        del op.output[1:]
-        del op.output_shape[1:]
-
-    def convert_conv2d(self, node):
-        op = self.convert_general_op(node)
-        self.add_stride_pad_kernel_arg(node.attrs, op)
-        group_arg = op.arg.add()
-        group_arg.name = MaceKeyword.mace_group_str
-        if 'group' in node.attrs:
-            group_val = node.attrs["group"]
-        else:
-            group_val = 1
-        group_arg.i = group_val
-
-        is_depthwise = False
-        if group_val > 1:
-            filter_shape = self._graph_shapes_dict[node.inputs[1]]
-            mace_check(group_val == filter_shape[0] and
-                       filter_shape[1] == 1,
-                       "Mace does not support group convolution yet")
-            filter_tensor = self._consts[node.inputs[1]]
-            new_shape = [filter_shape[1], filter_shape[0],
-                         filter_shape[2], filter_shape[3]]
-            del filter_tensor.dims[:]
-            filter_tensor.dims.extend(new_shape)
-            is_depthwise = True
-        if is_depthwise:
-            op.type = MaceOp.DepthwiseConv2d.name
-        else:
-            op.type = MaceOp.Conv2D.name
-
-        dilation_arg = op.arg.add()
-        dilation_arg.name = MaceKeyword.mace_dilations_str
-        if 'dilations' in node.attrs:
-            dilation_val = node.attrs["dilations"]
-        else:
-            dilation_val = [1, 1]
-        dilation_arg.ints.extend(dilation_val)
-
-    def convert_biasadd(self, node):
-        self.convert_general_op(node)
-
-    def convert_concat(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Concat.name
-        mace_check('axis' in node.attrs,
-                   'Concat op should have axis attribute.')
-        axis_arg = op.arg.add()
-        axis_arg.name = MaceKeyword.mace_axis_str
-        axis_arg.i = node.attrs['axis']
-        axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
-        mace_check(axis_arg.i == 1,
-                   "only support concat at channel dimension")
-
     def convert_activation(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Activation.name
@@ -597,100 +579,12 @@ class OnnxConverter(base_converter.ConverterInterface):
             alpha_arg.name = MaceKeyword.mace_activation_max_limit_str
             alpha_arg.f = alpha_value

-    def convert_pooling(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Pooling.name
-
-        self.add_stride_pad_kernel_arg(node.attrs, op)
-        pooling_type_arg = op.arg.add()
-        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
-        pooling_type_arg.i = self.pooling_type_mode[node.op_type].value
-        round_mode_arg = op.arg.add()
-        round_mode_arg.name = MaceKeyword.mace_round_mode_str
-        round_mode_arg.i = RoundMode.FLOOR.value
-
-    def convert_reshape(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Reshape.name
-
-    def convert_flatten(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Reshape.name
-
-    def remove_node(self, node):
-        input_name = node.inputs[0]
-        output_name = node.outputs[0]
-        self._replace_tensors[output_name] = input_name
-
-    def convert_eltwise(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Eltwise.name
-        type_arg = op.arg.add()
-        type_arg.name = MaceKeyword.mace_element_type_str
-        type_arg.i = self.eltwise_type[node.op_type].value
-
-        if node.op_type == OnnxOpType.Sqrt.name:
-            value_arg = op.arg.add()
-            value_arg.name = MaceKeyword.mace_scalar_input_str
-            value_arg.f = 0.5
-        elif node.op_type == OnnxOpType.Reciprocal.name:
-            value_arg = op.arg.add()
-            value_arg.name = MaceKeyword.mace_scalar_input_str
-            value_arg.f = -1
-
-    def convert_reduce(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Reduce.name
-
-        reduce_type_arg = op.arg.add()
-        reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
-        reduce_type_arg.i = self.reduce_type[node.op_type].value
-
-        if node.op_type in [OnnxOpType.GlobalAveragePool.name,
-                            OnnxOpType.GlobalMaxPool.name]:
-            reduce_dims = [2, 3]
-            keep_dims = 1
-        else:
-            if 'axes' in node.attrs:
-                reduce_dims = node.attrs['axes']
-            else:
-                reduce_dims = []
-            if 'keepdims' in node.attrs:
-                keep_dims = node.attrs['keepdims']
-            else:
-                keep_dims = 1
-        axis_arg = op.arg.add()
-        axis_arg.name = MaceKeyword.mace_axis_str
-        axis_arg.ints.extend(reduce_dims)
-        keep_dims_arg = op.arg.add()
-        keep_dims_arg.name = MaceKeyword.mace_keepdims_str
-        keep_dims_arg.i = keep_dims
-
-    def convert_imagescaler(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.BatchNorm.name
-        scale = node.attrs['scale']
-        bias_value = np.array(node.attrs['bias'])
-        scale_value = scale * np.ones_like(bias_value)
-
-        scale_name = node.name + "_scale"
-        bias_name = node.name + "_bias"
-        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
-                        scale_value)
-        self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT,
-                        bias_value)
-        op.input.extend([scale_name, bias_name])
-
-    def convert_matmul(self, node):
+    def convert_affine(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.MatMul.name
-
-    def convert_softmax(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Softmax.name
+        transpose_b_arg = op.arg.add()
+        transpose_b_arg.name = MaceKeyword.mace_transpose_b_str
+        transpose_b_arg.i = 1

     def convert_argmax(self, node):
         op = self.convert_general_op(node)
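Kaldi affine components store their weight matrix as (out_dim, in_dim), which is why convert_affine maps Affine onto MatMul with transpose_b = 1 instead of physically transposing the weight tensor at convert time. A rough numpy illustration, assuming that weight layout (the names here are ours, not the converter's):

```python
import numpy as np

frames, in_dim, out_dim = 5, 8, 4                         # example sizes
x = np.random.rand(frames, in_dim).astype(np.float32)
w = np.random.rand(out_dim, in_dim).astype(np.float32)    # Kaldi weight layout

y = np.dot(x, w.T)            # what MatMul with transpose_b = 1 computes
assert y.shape == (frames, out_dim)
```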
@@ -717,6 +611,10 @@ class OnnxConverter(base_converter.ConverterInterface):
             min_arg.name = MaceKeyword.mace_argmin_str
             min_arg.i = 1

+    def convert_biasadd(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.BiasAdd.name
+
     def convert_cast(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Cast.name
@@ -732,41 +630,49 @@ class OnnxConverter(base_converter.ConverterInterface):
         else:
             op.output_type.extend([self._option.data_type])

-    def convert_depth_space(self, node):
+    def convert_concat(self, node):
         op = self.convert_general_op(node)
-        if op.type == OnnxOpType.DepthToSpace.name:
-            op.type = MaceOp.DepthToSpace.name
-        else:
-            op.type = MaceOp.SpaceToDepth.name
-        mace_check(('block_size' in node.attrs),
-                   "depth to space op should have block size attribute.")
-        block_size = node.attrs['block_size']
-        size_arg = op.arg.add()
-        size_arg.name = MaceKeyword.mace_space_depth_block_size_str
-        size_arg.i = block_size
+        op.type = MaceOp.Concat.name
+        axis_value = 1
+        if node.op_type == OnnxOpType.Concat.name:
+            mace_check('axis' in node.attrs,
+                       'Concat op should have axis attribute.')
+            axis_value = node.attrs['axis']
+            mace_check(axis_value == 1 or axis_value == -3,
+                       "only support concat at channel dimension")
+        elif node.op_type == OnnxOpType.Append.name:
+            axis_value = 2
+        axis_arg = op.arg.add()
+        axis_arg.name = MaceKeyword.mace_axis_str
+        axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value

-    def convert_deconv(self, node):
+    def convert_conv2d(self, node):
         op = self.convert_general_op(node)
         self.add_stride_pad_kernel_arg(node.attrs, op)
-        group_arg = op.arg.add()
-        group_arg.name = MaceKeyword.mace_group_str
         if 'group' in node.attrs:
             group_val = node.attrs["group"]
         else:
             group_val = 1
-        group_arg.i = group_val
+        is_depthwise = False
         if group_val > 1:
-            op.type = MaceOp.DepthwiseDeconv2d.name
             filter_shape = self._graph_shapes_dict[node.inputs[1]]
+            mace_check(group_val == filter_shape[0] and
+                       filter_shape[1] == 1,
+                       "Mace does not support group convolution yet")
             filter_tensor = self._consts[node.inputs[1]]
             new_shape = [filter_shape[1], filter_shape[0],
                          filter_shape[2], filter_shape[3]]
             del filter_tensor.dims[:]
             filter_tensor.dims.extend(new_shape)
+            is_depthwise = True
+        if is_depthwise:
+            op.type = MaceOp.DepthwiseConv2d.name
         else:
-            op.type = MaceOp.Deconv2D.name
+            op.type = MaceOp.Conv2D.name
+        group_arg = op.arg.add()
+        group_arg.name = MaceKeyword.mace_group_str
+        group_arg.i = group_val

         dilation_arg = op.arg.add()
         dilation_arg.name = MaceKeyword.mace_dilations_str
@@ -775,16 +681,47 @@ class OnnxConverter(base_converter.ConverterInterface):
         else:
             dilation_val = [1, 1]
         dilation_arg.ints.extend(dilation_val)
-        mace_check(dilation_val == [1, 1],
-                   "not support convtranspose with dilation != 1 yet.")
-        mace_check('output_padding' not in node.attrs,
-                   "not support convtranspose with output_padding yet.")
-        mace_check('output_shape' not in node.attrs,
-                   "not support convtranspose with output_shape yet.")
-        # TODO: if output shape specified, calculate padding value
-        # if 'output_padding' in node.attrs:
-        #     output_padding = node.attrs['output_padding']
+
+    def convert_deconv(self, node):
+        op = self.convert_general_op(node)
+
+        self.add_stride_pad_kernel_arg(node.attrs, op)
+
+        if 'group' in node.attrs:
+            group_val = node.attrs["group"]
+        else:
+            group_val = 1
+        if group_val > 1:
+            op.type = MaceOp.DepthwiseDeconv2d.name
+            filter_shape = self._graph_shapes_dict[node.inputs[1]]
+            filter_tensor = self._consts[node.inputs[1]]
+            new_shape = [filter_shape[1], filter_shape[0],
+                         filter_shape[2], filter_shape[3]]
+            del filter_tensor.dims[:]
+            filter_tensor.dims.extend(new_shape)
+        else:
+            op.type = MaceOp.Deconv2D.name
+        group_arg = op.arg.add()
+        group_arg.name = MaceKeyword.mace_group_str
+        group_arg.i = group_val
+
+        dilation_arg = op.arg.add()
+        dilation_arg.name = MaceKeyword.mace_dilations_str
+        if 'dilations' in node.attrs:
+            dilation_val = node.attrs["dilations"]
+        else:
+            dilation_val = [1, 1]
+        dilation_arg.ints.extend(dilation_val)
+        mace_check(dilation_val == [1, 1],
+                   "not support convtranspose with dilation != 1 yet.")
+        mace_check('output_padding' not in node.attrs,
+                   "not support convtranspose with output_padding yet.")
+        mace_check('output_shape' not in node.attrs,
+                   "not support convtranspose with output_shape yet.")
+        # TODO: if output shape specified, calculate padding value
+        # if 'output_padding' in node.attrs:
+        #     output_padding = node.attrs['output_padding']
         #     output_padding_arg = op.arg.add()
         #     output_padding_arg.name = MaceKeyword.mace_output_padding_str
         #     output_padding_arg.ints.extend(output_padding)
@@ -794,43 +731,98 @@ class OnnxConverter(base_converter.ConverterInterface):
         #     output_shape_arg.name = MaceKeyword.mace_output_shape_str
         #     output_shape_arg.ints.extend(output_shape)

-    def convert_nop(self, node):
-        pass
+    def convert_depth_space(self, node):
+        op = self.convert_general_op(node)
+        if op.type == OnnxOpType.DepthToSpace.name:
+            op.type = MaceOp.DepthToSpace.name
+        else:
+            op.type = MaceOp.SpaceToDepth.name
+        mace_check(('block_size' in node.attrs),
+                   "depth to space op should have block size attribute.")
+        block_size = node.attrs['block_size']
+        size_arg = op.arg.add()
+        size_arg.name = MaceKeyword.mace_space_depth_block_size_str
+        size_arg.i = block_size

-    def convert_identity(self, node):
+    def convert_dim_range(self, node):
         op = self.convert_general_op(node)
-        op.type = MaceOp.Identity.name
+        op.type = MaceOp.Slice.name
+
+        mace_check('offset' in node.attrs,
+                   "Attribute offset required!")
+        mace_check('output_dim' in node.attrs,
+                   "Attribute output_dim required!")
+        offset = node.attrs['offset']
+        starts_arg = op.arg.add()
+        starts_arg.name = 'starts'
+        starts_arg.ints.append(offset)
+        output_dim = node.attrs['output_dim']
+        ends_arg = op.arg.add()
+        ends_arg.name = 'output_dim'
+        ends_arg.ints.append(output_dim)
+        axes_arg = op.arg.add()
+        axes_arg.name = 'axes'
+        axes_arg.ints.append(-1)

-    def convert_pad(self, node):
+    def convert_eltwise(self, node):
         op = self.convert_general_op(node)
-        op.type = MaceOp.Pad.name
-
-        if 'pads' in node.attrs:
-            paddings_arg = op.arg.add()
-            paddings_arg.name = MaceKeyword.mace_paddings_str
-            paddings_value = node.attrs['pads']
-            paddings_arg.ints.extend(paddings_value)
-
-        if 'value' in node.attrs:
-            constant_value_arg = op.arg.add()
-            constant_value_arg.name = MaceKeyword.mace_constant_value_str
-            constant_value_arg.i = node.attrs['value']
+        op.type = MaceOp.Eltwise.name
+        type_arg = op.arg.add()
+        type_arg.name = MaceKeyword.mace_element_type_str
+        type_arg.i = self.eltwise_type[node.op_type].value
+
+        if node.op_type == OnnxOpType.Sqrt.name:
+            value_arg = op.arg.add()
+            value_arg.name = MaceKeyword.mace_scalar_input_str
+            value_arg.f = 0.5
+        elif node.op_type == OnnxOpType.Reciprocal.name:
+            value_arg = op.arg.add()
+            value_arg.name = MaceKeyword.mace_scalar_input_str
+            value_arg.f = -1
+        elif node.op_type == OnnxOpType.Scale.name and 'scale' in node.attrs:
+            value = node.attrs['scale']
+            value_arg = op.arg.add()
+            value_arg.name = MaceKeyword.mace_scalar_input_str
+            value_arg.f = value

-    def convert_gather(self, node):
+    def convert_flatten(self, node):
         op = self.convert_general_op(node)
-        op.type = MaceOp.Gather.name
+        op.type = MaceOp.Reshape.name

-        if 'axis' in node.attrs:
-            value = node.attrs['axis']
-        else:
-            value = 0
-        axis_arg = op.arg.add()
-        axis_arg.name = MaceKeyword.mace_axis_str
-        axis_arg.i = value
+    def convert_fused_batchnorm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.BatchNorm.name
+
+        if "epsilon" in node.attrs:
+            epsilon_value = node.attrs["epsilon"]
+        else:
+            epsilon_value = 1e-5
+
+        mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.")
+
+        gamma_value = np.array(self._consts[node.inputs[1]].float_data)
+        beta_value = np.array(self._consts[node.inputs[2]].float_data)
+        mean_value = np.array(self._consts[node.inputs[3]].float_data)
+        var_value = np.array(self._consts[node.inputs[4]].float_data)
+
+        scale_name = node.name + 'scale'
+        offset_name = node.name + 'offset'
+        scale_value = (
+            (1.0 / np.sqrt(
+                var_value + epsilon_value)) * gamma_value)
+        offset_value = (-mean_value * scale_value) + beta_value
+        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
+                        scale_value)
+        self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
+                        offset_value)
+
+        del op.input[1:]
+        op.input.extend([scale_name, offset_name])
+        del op.output[1:]
+        del op.output_shape[1:]

-    def convert_split(self, node):
+    def convert_gather(self, node):
         op = self.convert_general_op(node)
-        op.type = MaceOp.Split.name
+        op.type = MaceOp.Gather.name
         if 'axis' in node.attrs:
             value = node.attrs['axis']
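DimRange above is Kaldi's "keep a contiguous range of the feature dimension" op, expressed as a Slice over the last axis with starts = offset and output_dim components kept. A numpy sketch of the intended semantics (the reference function name is ours):

```python
import numpy as np

def dim_range_ref(x, offset, output_dim):
    # keep output_dim consecutive components of the feature (last) axis,
    # starting at offset
    return x[..., offset:offset + output_dim]

x = np.arange(10, dtype=np.float32).reshape(1, 10)
assert np.array_equal(dim_range_ref(x, 2, 3), x[:, 2:5])
```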
@@ -840,64 +832,6 @@ class OnnxConverter(base_converter.ConverterInterface):
         axis_arg.name = MaceKeyword.mace_axis_str
         axis_arg.i = value

-    def convert_transpose(self, node):
-        op = self.convert_general_op(node)
-        op.type = MaceOp.Transpose.name
-
-        if 'perm' in node.attrs:
-            perm = node.attrs['perm']
-            ordered_perm = np.sort(perm)
-            if np.array_equal(perm, ordered_perm):
-                op.type = MaceOp.Identity.name
-                del op.input[1:]
-            else:
-                dims_arg = op.arg.add()
-                dims_arg.name = MaceKeyword.mace_dims_str
-                dims_arg.ints.extend(perm)
-
-    @staticmethod
-    def squeeze_shape(shape, axis):
-        new_shape = []
-        if len(axis) > 0:
-            for i in range(len(shape)):
-                if i not in axis:
-                    new_shape.append(shape[i])
-        else:
-            new_shape = shape
-        return new_shape
-
-    def convert_squeeze(self, node):
-        axis_value = node.attrs['axes']
-        if node.inputs[0] in self._consts:
-            tensor = self._consts[node.inputs[0]]
-            shape = tensor.dims
-            new_shape = self.squeeze_shape(shape, axis_value)
-            del tensor.dims[:]
-            tensor.dims.extend(new_shape)
-            self.remove_node(node)
-        else:
-            op = self.convert_general_op(node)
-            op.type = MaceOp.Squeeze.name
-            axis_arg = op.arg.add()
-            axis_arg.name = MaceKeyword.mace_axis_str
-            if 'axis' in node.attrs:
-                axis_value = node.attrs['axis']
-            else:
-                axis_value = []
-            axis_arg.ints.extend(axis_value)
-
-    @staticmethod
-    def transpose_const(tensor):
-        shape = tensor.dims
-        mace_check(len(shape) == 2, "gemm only supports 2-dim input.")
-        tensor_data = np.array(tensor.float_data).reshape(
-            shape[0], shape[1])
-        tensor_data = tensor_data.transpose(1, 0)
-        tensor.float_data[:] = tensor_data.flat
-        tensor.dims[:] = tensor_data.shape
-
     def convert_gemm(self, node):
         # only supports FullyConnected Style Gemm for now.
         trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
@@ -915,7 +849,7 @@ class OnnxConverter(base_converter.ConverterInterface):
         elif len(shape_b) == 2:
             tensor_b = self._consts[node.inputs[1]]
             tensor_data = np.array(tensor_b.float_data).reshape(
-                    shape_b[0], shape_b[1], 1, 1)
+                shape_b[0], shape_b[1], 1, 1)
             tensor_b.float_data[:] = tensor_data.flat
             tensor_b.dims[:] = tensor_data.shape
         else:
@@ -949,4 +883,224 @@ class OnnxConverter(base_converter.ConverterInterface):
                 shape_info = [shape_info[0], shape_info[1], 1, 1]
             output_shape.dims.extend(shape_info)
         return op
+
+    def convert_identity(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Identity.name
+
+    def convert_imagescaler(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.BatchNorm.name
+
+        scale = node.attrs['scale']
+        bias_value = np.array(node.attrs['bias'])
+        scale_value = scale * np.ones_like(bias_value)
+
+        scale_name = node.name + "_scale"
+        bias_name = node.name + "_bias"
+        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
+                        scale_value)
+        self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT,
+                        bias_value)
+        op.input.extend([scale_name, bias_name])
+
+    def convert_lstm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.LSTMCell.name
+
+    def convert_lstm_nonlinear(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.LstmNonlinear.name
+
+    def convert_matmul(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.MatMul.name
+
+    def convert_nop(self, node):
+        pass
+
+    def convert_normalize(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.BatchNorm.name
+
+    def convert_pnorm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.PNorm.name
+        if 'output_dim' in node.attrs:
+            output_dim_arg = op.arg.add()
+            output_dim_arg.name = 'output_dim'
+            output_dim_arg.i = node.attrs['output_dim']
+        if 'p' in node.attrs:
+            p_value = node.attrs['p']
+            mace_check((p_value >= 0) and (p_value <= 2),
+                       "PNorm only supports p = 0, 1, 2")
+            p_arg = op.arg.add()
+            p_arg.name = 'p'
+            p_arg.i = p_value
+
+    def convert_pooling(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Pooling.name
+
+        self.add_stride_pad_kernel_arg(node.attrs, op)
+        pooling_type_arg = op.arg.add()
+        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
+        pooling_type_arg.i = self.pooling_type_mode[node.op_type].value
+        round_mode_arg = op.arg.add()
+        round_mode_arg.name = MaceKeyword.mace_round_mode_str
+        round_mode_arg.i = RoundMode.FLOOR.value
+
+    def convert_reduce(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Reduce.name
+
+        reduce_type_arg = op.arg.add()
+        reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
+        reduce_type_arg.i = self.reduce_type[node.op_type].value
+
+        if node.op_type in [OnnxOpType.GlobalAveragePool.name,
+                            OnnxOpType.GlobalMaxPool.name]:
+            reduce_dims = [2, 3]
+            keep_dims = 1
+        else:
+            if 'axes' in node.attrs:
+                reduce_dims = node.attrs['axes']
+            else:
+                reduce_dims = []
+            if 'keepdims' in node.attrs:
+                keep_dims = node.attrs['keepdims']
+            else:
+                keep_dims = 1
+        axis_arg = op.arg.add()
+        axis_arg.name = MaceKeyword.mace_axis_str
+        axis_arg.ints.extend(reduce_dims)
+        keep_dims_arg = op.arg.add()
+        keep_dims_arg.name = MaceKeyword.mace_keepdims_str
+        keep_dims_arg.i = keep_dims
+
+    def convert_reshape(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Reshape.name
+
+    def convert_slice(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Slice.name
+
+        mace_check('starts' in node.attrs, "Attribute starts required!")
+        mace_check('ends' in node.attrs, "Attribute ends required!")
+        starts = node.attrs['starts']
+        starts_arg = op.arg.add()
+        starts_arg.name = 'starts'
+        starts_arg.ints.extend(starts)
+        ends = node.attrs['ends']
+        ends_arg = op.arg.add()
+        ends_arg.name = 'ends'
+        ends_arg.ints.extend(ends)
+        if 'axes' in node.attrs:
+            axes = node.attrs['axes']
+            axes_arg = op.arg.add()
+            axes_arg.name = 'axes'
+            axes_arg.ints.extend(axes)
+
+    def convert_softmax(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Softmax.name
+        # TODO: add logsoftmax in softmax op
+        # if node.op_type == OnnxOpType.LogSoftmax.name:
+        #     use_log_arg = op.arg.add()
+        #     use_log_arg.name = 'use_log'
+        #     use_log_arg.i = 1
+
+    def convert_splice(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Splice.name
+        if 'context' in node.attrs:
+            context = node.attrs['context']
+        else:
+            context = [0]
+        context_arg = op.arg.add()
+        context_arg.name = 'context'
+        context_arg.ints.extend(context)
+        if 'const_component_dim' in node.attrs:
+            const_dim = node.attrs['const_component_dim']
+        else:
+            const_dim = 0
+        const_dim_arg = op.arg.add()
+        const_dim_arg.name = 'const_component_dim'
+        const_dim_arg.i = const_dim
+
+    def convert_split(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Split.name
+        if 'axis' in node.attrs:
+            value = node.attrs['axis']
+        else:
+            value = 0
+        axis_arg = op.arg.add()
+        axis_arg.name = MaceKeyword.mace_axis_str
+        axis_arg.i = value
+
+    def convert_squeeze(self, node):
+        axis_value = node.attrs['axes']
+        if node.inputs[0] in self._consts:
+            tensor = self._consts[node.inputs[0]]
+            shape = tensor.dims
+            new_shape = self.squeeze_shape(shape, axis_value)
+            del tensor.dims[:]
+            tensor.dims.extend(new_shape)
+            self.remove_node(node)
+        else:
+            op = self.convert_general_op(node)
+            op.type = MaceOp.Squeeze.name
+            axis_arg = op.arg.add()
+            axis_arg.name = MaceKeyword.mace_axis_str
+            if 'axis' in node.attrs:
+                axis_value = node.attrs['axis']
+            else:
+                axis_value = []
+            axis_arg.ints.extend(axis_value)
+
+    def convert_sum_group(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.SumGroup.name
+
+    def convert_target_rms_norm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.TargetRMSNorm.name
+
+        if 'target_rms' in node.attrs:
+            value = node.attrs['target_rms']
+            target_rms_arg = op.arg.add()
+            target_rms_arg.name = 'target_rms'
+            target_rms_arg.f = value
+
+    def convert_transpose(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.Transpose.name
+
+        if 'perm' in node.attrs:
+            perm = node.attrs['perm']
+            ordered_perm = np.sort(perm)
+            if np.array_equal(perm, ordered_perm):
+                op.type = MaceOp.Identity.name
+                del op.input[1:]
+            else:
+                dims_arg = op.arg.add()
+                dims_arg.name = MaceKeyword.mace_dims_str
+                dims_arg.ints.extend(perm)
+
+    def convert_timeoffset(self, node):
+        op = self.convert_general_op(node)
+        mace_check('offset' in node.attrs,
+                   'Offset attribute required in Offset Node.')
+        offset = node.attrs['offset']
+        if offset == 0:
+            op.type = MaceOp.Identity.name
+        else:
+            op.type = MaceOp.TimeOffset.name
+        offset_arg = op.arg.add()
+        offset_arg.name = 'offset'
+        offset_arg.i = offset
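Splice, converted above, is Kaldi's frame-splicing op: each output frame is the concatenation of the input frames at the given context offsets. A rough numpy reference, assuming frames without full context are clipped (the exact edge handling lives in MACE's Splice kernel, not in this converter):

```python
import numpy as np

def splice_reference(x, context):
    # each output frame concatenates the input frames at the given
    # context offsets; edge frames without full context are clipped
    num_frames, dim = x.shape
    lo, hi = min(context), max(context)
    out_frames = num_frames - (hi - lo)
    out = np.zeros((out_frames, dim * len(context)), dtype=x.dtype)
    for t in range(out_frames):
        out[t] = np.concatenate([x[t - lo + c] for c in context])
    return out

x = np.arange(12, dtype=np.float32).reshape(4, 3)    # 4 frames, dim 3
y = splice_reference(x, context=[-1, 0, 1])          # -> shape (2, 9)
assert y.shape == (2, 9)
```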
@@ -15,7 +15,9 @@
 #ifndef MACE_UTILS_UTILS_H_
 #define MACE_UTILS_UTILS_H_

+#include <algorithm>
 #include <cstdlib>
+#include <cmath>
 #include <map>
 #include <string>
 #include <vector>
@@ -69,6 +71,32 @@ Integer CeilQuotient(Integer a, Integer b) {
 std::string ObfuscateString(const std::string &src,
                             const std::string &lookup_table);

+template <typename Integer>
+inline Integer Clamp(Integer in, Integer low, Integer high) {
+  return std::max<Integer>(low, std::min<Integer>(in, high));
+}
+
+template <typename T>
+inline T ScalarSigmoid(T in) {
+  if (in > static_cast<T>(0)) {
+    return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-in));
+  } else {
+    T x = std::exp(in);
+    return x / (x + static_cast<T>(1));
+  }
+}
+
+template <typename T>
+inline T ScalarTanh(T in) {
+  if (in > static_cast<T>(0)) {
+    T inv_expa = std::exp(-in);
+    return -static_cast<T>(1) +
+        static_cast<T>(2) / (static_cast<T>(1) + inv_expa * inv_expa);
+  } else {
+    T x = std::exp(static_cast<T>(2) * in);
+    return (x - static_cast<T>(1)) / (x + static_cast<T>(1));
+  }
+}
+
 std::string ObfuscateString(const std::string &src);
......
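Both scalar helpers select the branch whose exp() argument is non-positive, so the exponential can never overflow; note the non-positive tanh branch must use e^{2x} in (e^{2x} - 1) / (e^{2x} + 1) rather than the sigmoid-style x / (x + 1) form. A quick Python check of the same two formulas against math.tanh:

```python
import math

def stable_sigmoid(x):
    # mirrors ScalarSigmoid: the exp() argument is never positive
    if x > 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (e + 1.0)

def stable_tanh(x):
    # mirrors ScalarTanh: both branches keep exp() bounded
    if x > 0:
        inv_e = math.exp(-x)
        return -1.0 + 2.0 / (1.0 + inv_e * inv_e)
    e2 = math.exp(2.0 * x)
    return (e2 - 1.0) / (e2 + 1.0)

for v in (-30.0, -1.0, 0.0, 0.5, 30.0):
    assert abs(stable_tanh(v) - math.tanh(v)) < 1e-12
    assert abs(stable_sigmoid(v) - 1.0 / (1.0 + math.exp(-v))) < 1e-12
```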
@@ -401,6 +401,7 @@ class YAMLKeyword(object):
     graph_optimize_options = 'graph_optimize_options'  # internal use for now
     cl_mem_type = 'cl_mem_type'
     backend = 'backend'
+    validation_outputs_data = 'validation_outputs_data'
     docker_image_tag = 'docker_image_tag'
     dockerfile_path = 'dockerfile_path'
     dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
......
@@ -476,6 +476,14 @@ def format_model_config(flags):
             onnx_backend = subgraph.get(
                 YAMLKeyword.backend, "tensorflow")
             subgraph[YAMLKeyword.backend] = onnx_backend
+            validation_outputs_data = subgraph.get(
+                YAMLKeyword.validation_outputs_data, [])
+            if not isinstance(validation_outputs_data, list):
+                subgraph[YAMLKeyword.validation_outputs_data] = [
+                    validation_outputs_data]
+            else:
+                subgraph[YAMLKeyword.validation_outputs_data] = \
+                    validation_outputs_data
             input_ranges = subgraph.get(
                 YAMLKeyword.input_ranges, [])
             if not isinstance(input_ranges, list):
......
@@ -660,6 +660,8 @@ class DeviceWrapper:
                     YAMLKeyword.validation_threshold][
                     validate_type],
                 backend=subgraphs[0][YAMLKeyword.backend],
+                validation_outputs_data=subgraphs[0][
+                    YAMLKeyword.validation_outputs_data],
                 log_file=log_file,
             )
         if flags.report and flags.round > 0:
......
@@ -656,9 +656,12 @@ def validate_model(abi,
                    output_file_name="model_out",
                    validation_threshold=0.9,
                    backend="tensorflow",
-                   log_file="",
-                   ):
-    six.print_("* Validate with %s" % platform)
+                   validation_outputs_data=[],
+                   log_file=""):
+    if not validation_outputs_data:
+        six.print_("* Validate with %s" % platform)
+    else:
+        six.print_("* Validate with file: %s" % validation_outputs_data)
     if abi != "host":
         for output_name in output_nodes:
             formatted_name = common.formatted_file_name(
@@ -675,6 +678,7 @@ def validate_model(abi,
                  ":".join(input_shapes), ":".join(output_shapes),
                  ",".join(input_nodes), ",".join(output_nodes),
                  validation_threshold, ",".join(input_data_types), backend,
+                 validation_outputs_data,
                  log_file)
     elif platform == "onnx":
         validate(platform, model_file_path, "",
@@ -683,6 +687,7 @@ def validate_model(abi,
                  ":".join(input_shapes), ":".join(output_shapes),
                  ",".join(input_nodes), ",".join(output_nodes),
                  validation_threshold, ",".join(input_data_types), backend,
+                 validation_outputs_data,
                  log_file)
     elif platform == "caffe":
         image_name = "mace-caffe:" + docker_image_tag
@@ -700,6 +705,7 @@ def validate_model(abi,
                      ":".join(input_shapes), ":".join(output_shapes),
                      ",".join(input_nodes), ",".join(output_nodes),
                      validation_threshold, ",".join(input_data_types),
+                     backend, validation_outputs_data,
                      log_file)
         elif caffe_env == common.CaffeEnvType.DOCKER:
             docker_image_id = sh.docker("images", "-q", image_name)
@@ -767,6 +773,8 @@ def validate_model(abi,
                      "--validation_threshold=%f" % validation_threshold,
                      "--input_data_type=%s" % ",".join(input_data_types),
                      "--backend=%s" % ",".join(backend),
+                     "--validation_outputs_data=%s" % ",".join(
+                         validation_outputs_data),
                      "--log_file=%s" % log_file,
                      _fg=True)
......
@@ -18,6 +18,7 @@ import os
 import os.path
 import numpy as np
 import re
+import six

 import common
@@ -121,6 +122,32 @@ def normalize_tf_tensor_name(name):
     return name


+def validate_with_file(platform, device_type,
+                       output_names, output_shapes,
+                       mace_out_file, validation_outputs_data,
+                       validation_threshold, log_file):
+    for i in range(len(output_names)):
+        if validation_outputs_data[i].startswith("http://") or \
+                validation_outputs_data[i].startswith("https://"):
+            validation_file_name = common.formatted_file_name(
+                mace_out_file, output_names[i] + '_validation')
+            six.moves.urllib.request.urlretrieve(validation_outputs_data[i],
+                                                 validation_file_name)
+        else:
+            validation_file_name = validation_outputs_data[i]
+        value = load_data(validation_file_name)
+        out_shape = output_shapes[i]
+        if len(out_shape) == 4:
+            out_shape[1], out_shape[2], out_shape[3] = \
+                out_shape[3], out_shape[1], out_shape[2]
+            value = value.reshape(out_shape).transpose((0, 2, 3, 1))
+        output_file_name = common.formatted_file_name(
+            mace_out_file, output_names[i])
+        mace_out_value = load_data(output_file_name)
+        compare_output(platform, device_type, output_names[i], mace_out_value,
+                       value, validation_threshold, log_file)
+
+
 def validate_tf_model(platform, device_type, model_file, input_file,
                       mace_out_file, input_names, input_shapes,
                       output_names, validation_threshold, input_data_types,
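validate_with_file assumes a 4-D reference file is stored in NCHW order while MACE's own output file is NHWC, hence the shape shuffle followed by the transpose back before comparing. A worked numpy example with an assumed (1, 2, 3, 4) NHWC output:

```python
import numpy as np

out_shape = [1, 2, 3, 4]                  # declared NHWC output shape (example)
value = np.arange(24, dtype=np.float32)   # flat reference data in NCHW order

# (N, H, W, C) -> (N, C, H, W), mirroring the tuple assignment above
out_shape[1], out_shape[2], out_shape[3] = \
    out_shape[3], out_shape[1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))  # back to NHWC
assert value.shape == (1, 2, 3, 4)
```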
@@ -275,7 +302,8 @@ def validate_onnx_model(platform, device_type, model_file, input_file,
 def validate(platform, model_file, weight_file, input_file, mace_out_file,
              device_type, input_shape, output_shape, input_node, output_node,
-             validation_threshold, input_data_type, backend, log_file):
+             validation_threshold, input_data_type, backend,
+             validation_outputs_data, log_file):
     input_names = [name for name in input_node.split(',')]
     input_shape_strs = [shape for shape in input_shape.split(':')]
     input_shapes = [[int(x) for x in shape.split(',')]
@@ -287,8 +315,21 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
         input_data_types = ['float32'] * len(input_names)
     output_names = [name for name in output_node.split(',')]
     assert len(input_names) == len(input_shapes)
+    if not isinstance(validation_outputs_data, list):
+        if os.path.isfile(validation_outputs_data):
+            validation_outputs = [validation_outputs_data]
+        else:
+            validation_outputs = []
+    else:
+        validation_outputs = validation_outputs_data

-    if platform == 'tensorflow':
+    if validation_outputs:
+        output_shape_strs = [shape for shape in output_shape.split(':')]
+        output_shapes = [[int(x) for x in shape.split(',')]
+                         for shape in output_shape_strs]
+        validate_with_file(platform, device_type, output_names, output_shapes,
+                           mace_out_file, validation_outputs,
+                           validation_threshold, log_file)
+    elif platform == 'tensorflow':
         validate_tf_model(platform, device_type, model_file, input_file,
                           mace_out_file, input_names, input_shapes,
                           output_names, validation_threshold, input_data_types,
@@ -358,10 +399,10 @@ def parse_args():
         default="tensorflow",
         help="onnx backend framework")
     parser.add_argument(
-        "--log_file",
-        type=str,
-        default="",
-        help="log file")
+        "--validation_outputs_data", type=str,
+        default="", help="validation outputs data file path.")
+    parser.add_argument(
+        "--log_file", type=str, default="", help="log file.")
     return parser.parse_known_args()
@@ -381,4 +422,5 @@ if __name__ == '__main__':
         FLAGS.validation_threshold,
         FLAGS.input_data_type,
         FLAGS.backend,
+        FLAGS.validation_outputs_data,
         FLAGS.log_file)