From 8911cffba3886b266b1606ebe374e1b6954d7c57 Mon Sep 17 00:00:00 2001 From: liutuo Date: Mon, 17 Dec 2018 13:45:06 +0800 Subject: [PATCH] add splice affine pnorm sumgroup for kaldi-model support kaldi-onnx model convert support validate output data from url --- mace/ops/arm/fp32/gemv.cc | 5 +- mace/ops/arm/q8/gemv.cc | 4 +- mace/ops/lstm_cell.cc | 2 + mace/ops/ops_registry.cc | 12 + mace/ops/pnorm.cc | 133 +++ mace/ops/pnorm_benchmark.cc | 77 ++ mace/ops/pnorm_test.cc | 70 ++ mace/ops/slice.cc | 94 +++ mace/ops/slice_test.cc | 71 ++ mace/ops/splice.cc | 120 +++ mace/ops/splice_benchmark.cc | 92 ++ mace/ops/splice_test.cc | 84 ++ mace/ops/sum_group.cc | 107 +++ mace/ops/sum_group_benchmark.cc | 75 ++ mace/ops/sum_group_test.cc | 71 ++ mace/ops/target_rms_norm.cc | 116 +++ mace/ops/target_rms_norm_benchmark.cc | 74 ++ mace/ops/target_rms_norm_test.cc | 62 ++ mace/ops/time_offset.cc | 80 ++ mace/ops/time_offset_benchmark.cc | 78 ++ mace/ops/time_offset_test.cc | 125 +++ .../tools/converter_tool/base_converter.py | 10 + .../tools/converter_tool/onnx_converter.py | 792 +++++++++++------- mace/utils/utils.h | 28 + tools/common.py | 1 + tools/converter.py | 8 + tools/device.py | 2 + tools/sh_commands.py | 14 +- tools/validate.py | 56 +- 29 files changed, 2130 insertions(+), 333 deletions(-) create mode 100644 mace/ops/pnorm.cc create mode 100644 mace/ops/pnorm_benchmark.cc create mode 100644 mace/ops/pnorm_test.cc create mode 100644 mace/ops/slice.cc create mode 100644 mace/ops/slice_test.cc create mode 100644 mace/ops/splice.cc create mode 100644 mace/ops/splice_benchmark.cc create mode 100644 mace/ops/splice_test.cc create mode 100644 mace/ops/sum_group.cc create mode 100644 mace/ops/sum_group_benchmark.cc create mode 100644 mace/ops/sum_group_test.cc create mode 100644 mace/ops/target_rms_norm.cc create mode 100644 mace/ops/target_rms_norm_benchmark.cc create mode 100644 mace/ops/target_rms_norm_test.cc create mode 100644 mace/ops/time_offset.cc create mode 100644 mace/ops/time_offset_benchmark.cc create mode 100644 mace/ops/time_offset_test.cc diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc index 703e3944..89474fbb 100644 --- a/mace/ops/arm/fp32/gemv.cc +++ b/mace/ops/arm/fp32/gemv.cc @@ -258,11 +258,12 @@ MaceStatus Gemv::Compute(const OpContext *context, ++rhs_ptr; } - float32x4_t vbias = vdupq_n_f32(0); if (bias) { + float32x4_t vbias = vdupq_n_f32(0); vbias = vld1q_f32(bias_data + h_start); + vo = vaddq_f32(vo, vbias); } - vo = vaddq_f32(vo, vbias); + vst1q_f32(ret_ptr, vo); } else { // h_block_len < 4 #endif // MACE_GEMV_UNROLL diff --git a/mace/ops/arm/q8/gemv.cc b/mace/ops/arm/q8/gemv.cc index 790a1448..f61062f4 100644 --- a/mace/ops/arm/q8/gemv.cc +++ b/mace/ops/arm/q8/gemv.cc @@ -376,11 +376,11 @@ MaceStatus Gemv::Compute(const OpContext *context, ++rhs_ptr; } - int32x4_t vbias = vdupq_n_s32(0); if (bias) { + int32x4_t vbias = vdupq_n_s32(0); vbias = vld1q_s32(bias_data + h_offset); + vo = vaddq_s32(vo, vbias); } - vo = vaddq_s32(vo, vbias); if (is_output_type_uint8) { int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier); diff --git a/mace/ops/lstm_cell.cc b/mace/ops/lstm_cell.cc index a342cef8..bc5af8f5 100644 --- a/mace/ops/lstm_cell.cc +++ b/mace/ops/lstm_cell.cc @@ -25,6 +25,7 @@ namespace ops { template class LSTMCellOp; +#ifdef MACE_ENABLE_OPENCL template class LSTMCellOp : public Operation { public: @@ -88,6 +89,7 @@ class LSTMCellOp : public Operation { MACE_OP_INPUT_TAGS(INPUT, PRE_OUTPUT, WEIGHT, BIAS, PRE_CELL); MACE_OP_OUTPUT_TAGS(CELL, OUTPUT); }; +#endif void RegisterLSTMCell(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "LSTMCell", LSTMCellOp, diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc index 5780483a..52f22e4a 100644 --- a/mace/ops/ops_registry.cc +++ b/mace/ops/ops_registry.cc @@ -43,6 +43,7 @@ extern void RegisterInferConv2dShape(OpRegistryBase *op_registry); extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry); extern void RegisterMatMul(OpRegistryBase *op_registry); extern void RegisterPad(OpRegistryBase *op_registry); +extern void RegisterPNorm(OpRegistryBase *op_registry); extern void RegisterPooling(OpRegistryBase *op_registry); extern void RegisterReduce(OpRegistryBase *op_registry); extern void RegisterPriorBox(OpRegistryBase *op_registry); @@ -53,14 +54,19 @@ extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry); extern void RegisterReverse(OpRegistryBase *op_registry); extern void RegisterScalarMath(OpRegistryBase *op_registry); extern void RegisterShape(OpRegistryBase *op_registry); +extern void RegisterSlice(OpRegistryBase *op_registry); extern void RegisterSoftmax(OpRegistryBase *op_registry); extern void RegisterSpaceToBatchND(OpRegistryBase *op_registry); extern void RegisterSpaceToDepth(OpRegistryBase *op_registry); +extern void RegisterSplice(OpRegistryBase *op_registry); extern void RegisterSplit(OpRegistryBase *op_registry); extern void RegisterSqrDiffMean(OpRegistryBase *op_registry); extern void RegisterSqueeze(OpRegistryBase *op_registry); extern void RegisterStack(OpRegistryBase *op_registry); extern void RegisterStridedSlice(OpRegistryBase *op_registry); +extern void RegisterSumGroup(OpRegistryBase *op_registry); +extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry); +extern void RegisterTimeOffset(OpRegistryBase *op_registry); extern void RegisterTranspose(OpRegistryBase *op_registry); extern void RegisterUnstack(OpRegistryBase *op_registry); @@ -103,6 +109,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterLocalResponseNorm(this); ops::RegisterMatMul(this); ops::RegisterPad(this); + ops::RegisterPNorm(this); ops::RegisterPooling(this); ops::RegisterReduce(this); ops::RegisterPriorBox(this); @@ -113,14 +120,19 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterReverse(this); ops::RegisterScalarMath(this); ops::RegisterShape(this); + ops::RegisterSlice(this); ops::RegisterSoftmax(this); ops::RegisterSpaceToBatchND(this); ops::RegisterSpaceToDepth(this); + ops::RegisterSplice(this); ops::RegisterSplit(this); ops::RegisterStack(this); ops::RegisterStridedSlice(this); ops::RegisterSqrDiffMean(this); ops::RegisterSqueeze(this); + ops::RegisterSumGroup(this); + ops::RegisterTargetRMSNorm(this); + ops::RegisterTimeOffset(this); ops::RegisterTranspose(this); ops::RegisterUnstack(this); diff --git a/mace/ops/pnorm.cc b/mace/ops/pnorm.cc new file mode 100644 index 00000000..8742a3b4 --- /dev/null +++ b/mace/ops/pnorm.cc @@ -0,0 +1,133 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for PNormComponent in Kaldi. +// The input-dim must be dividable by output-dim. +// The output will be divided to output-dim group, +// so input-dim should be dividable by output-dim. +// For each row: +// p is 0: output[i] = sum(abs(input[i*group + j]) > 0) +// p is 1: output[i] = sum(abs(input[i*group + j])) +// p is 2: output[i] = sqrt(sum(input[i * group + j] * input[i * group + j])), +// for j = (0 : group - 1) +// p's default value is 2. + +#include +#include + +#include "mace/core/operator.h" + + +namespace mace { +namespace ops { + +template +class PNormOp; + +template +class PNormOp : public Operation { + public: + explicit PNormOp(OpConstructContext *context) + : Operation(context), + p_(Operation::GetOptionalArg("p", 2)), + output_dim_(Operation::GetOptionalArg("output_dim", 0)) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + + + + const std::vector &input_shape = input->shape(); + const index_t dim_size = input_shape.size(); + MACE_CHECK(dim_size >= 1, "PNorm only supports input dim size >= 1"); + std::vector output_shape(input_shape); + const index_t input_dim = input_shape[dim_size -1]; + MACE_CHECK(output_dim_ > 0, + "Output dim should be greater than zero."); + MACE_CHECK(input_dim % output_dim_ == 0 && output_dim_ < input_dim, + "PNorm's input dim should be a multiple of output dim."); + const index_t group_size = input_dim / output_dim_; + output_shape[dim_size - 1] = output_dim_; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard guard_input(input); + Tensor::MappingGuard guard_output(output); + + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + const index_t bh = + std::accumulate(input->shape().begin(), input->shape().end() - 1, 1, + std::multiplies()); + if (p_ == 0) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < bh; ++i) { + for (index_t j = 0; j < output_dim_; ++j) { + const T *in_base = input_data + i * input_dim + j * group_size; + T *out_base = output_data + i * output_dim_; + T temp_result = 0; + for (index_t g = 0; g < group_size; ++g) { + T value = + (std::fabs(in_base[g]) + > std::numeric_limits::epsilon()) ? 1.0f : 0.0f; + temp_result += value; + } + out_base[j] = temp_result; + } + } + } else if (p_ == 1) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < bh; ++i) { + for (index_t j = 0; j < output_dim_; ++j) { + const T *in_base = input_data + i * input_dim + j * group_size; + T *out_base = output_data + i * output_dim_; + T temp_result = 0; + for (index_t g = 0; g < group_size; ++g) { + temp_result += std::abs(in_base[g]);; + } + out_base[j] = temp_result; + } + } + } else if (p_ == 2) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < bh; ++i) { + for (index_t j = 0; j < output_dim_; ++j) { + const T *in_base = input_data + i * input_dim + j * group_size; + T *out_base = output_data + i * output_dim_; + T temp_result = 0; + for (index_t g = 0; g < group_size; ++g) { + temp_result += in_base[g] * in_base[g]; + } + out_base[j] = std::sqrt(temp_result); + } + } + } else { + LOG(FATAL) << "PNorm's p should be 0, 1 or 2, here p is: " << p_; + } + return MaceStatus::MACE_SUCCESS; + } + + private: + int p_; + int output_dim_; +}; + +void RegisterPNorm(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "PNorm", PNormOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/pnorm_benchmark.cc b/mace/ops/pnorm_benchmark.cc new file mode 100644 index 00000000..e3af765c --- /dev/null +++ b/mace/ops/pnorm_benchmark.cc @@ -0,0 +1,77 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void PNormBenchmark(int iters, int n, int h, int w, int p, int ow) { + mace::testing::StopTiming(); + + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", {n, h, w}); + + OpDefBuilder("PNorm", "PNormBM") + .Input("Input") + .AddIntArg("p", p) + .AddIntArg("output_dim", ow) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_PNORM_MACRO(N, H, W, P, OW, TYPE, DEVICE) \ + static void \ + MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + PNormBenchmark(iters, N, H, W, P, OW); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE) + +#define MACE_BM_PNORM(N, H, W, P, OW) \ + MACE_BM_PNORM_MACRO(N, H, W, P, OW, float, CPU); + +MACE_BM_PNORM(1, 10, 256, 0, 128); +MACE_BM_PNORM(1, 20, 128, 1, 64); +MACE_BM_PNORM(1, 10, 128, 2, 64); +MACE_BM_PNORM(1, 16, 256, 0, 128); +MACE_BM_PNORM(1, 32, 128, 1, 64); +MACE_BM_PNORM(1, 10, 512, 2, 256); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/pnorm_test.cc b/mace/ops/pnorm_test.cc new file mode 100644 index 00000000..35108682 --- /dev/null +++ b/mace/ops/pnorm_test.cc @@ -0,0 +1,70 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class PNormOpTest : public OpsTestBase {}; + +namespace { +template +void TestPNorm(const std::vector &input_shape, + const std::vector &input, + const int p, + const int output_dim, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("PNorm", "PNormTest") + .Input("Input") + .AddIntArg("p", p) + .AddIntArg("output_dim", output_dim) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(PNormOpTest, SimpleTest) { + TestPNorm( + {1, 5, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + 2, 5, + {1, 5, 5}, + {2.236067977, 5, 7.810249676, 10.630145813, 13.453624047, + 5, 7.810249676, 10.630145813, 13.453624047, 16.278820596, + 7.810249676, 10.630145813, 13.453624047, 16.278820596, 19.104973175, + 10.630145813, 13.453624047, 16.278820596, 19.104973175, 21.931712199, + 13.453624047, 16.278820596, 19.104973175, 21.931712199, 24.758836806}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/slice.cc b/mace/ops/slice.cc new file mode 100644 index 00000000..f38a2a32 --- /dev/null +++ b/mace/ops/slice.cc @@ -0,0 +1,94 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class SliceOp; + +template +class SliceOp : public Operation { + public: + explicit SliceOp(OpConstructContext *context) + : Operation(context), + axes_(Operation::GetRepeatedArgs("axes")), + starts_(Operation::GetRepeatedArgs("starts")), + ends_(Operation::GetRepeatedArgs("ends")) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + + const index_t rank = input->dim_size(); + MACE_CHECK(rank >= 1) + << "The input dim size should >= 1"; + MACE_CHECK(starts_.size() == 1 && ends_.size() == 1 && axes_.size() == 1, + "only support slicing at one axis."); + MACE_CHECK(axes_[0] == -1 || axes_[0] == rank - 1, + "only support slicing at the last axis."); + const index_t input_dim = input->dim(rank - 1); + const index_t offset = starts_[0]; + const index_t output_dim = ends_[0] - starts_[0]; + + MACE_CHECK(output_dim >= 0, "output_dim should >= 0"); + MACE_CHECK(starts_[0] < input_dim + && output_dim <= input_dim + && ends_[0] <= input_dim) + << "The starts and ends caused over range error."; + + const index_t frames = + std::accumulate(input->shape().begin(), input->shape().end() - 1, 1, + std::multiplies()); + + std::vector output_shape = input->shape(); + output_shape[rank - 1] = output_dim; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < frames; ++i) { + const T *input_base = + input_data + i * input_dim + offset; + T *output_base = + output_data + i * output_dim; + memcpy(output_base, input_base, output_dim * sizeof(T)); + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + std::vector axes_; + std::vector starts_; + std::vector ends_; +}; + +void RegisterSlice(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "Slice", SliceOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/slice_test.cc b/mace/ops/slice_test.cc new file mode 100644 index 00000000..a5f82cc1 --- /dev/null +++ b/mace/ops/slice_test.cc @@ -0,0 +1,71 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SliceOpTest : public OpsTestBase {}; + +namespace { +template +void TestSlice(const std::vector &input_shape, + const std::vector &input, + const int offset, + const int output_dim, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("Slice", "SliceTest") + .Input("Input") + .Output("Output") + .AddIntsArg("axes", {-1}) + .AddIntsArg("starts", {offset}) + .AddIntsArg("ends", {offset + output_dim}) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(SliceOpTest, Simple2Dim) { + TestSlice( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 2, 3, {3, 3}, + {3, 4, 5, 8, 9, 10, 13, 14, 15}); +} + +TEST_F(SliceOpTest, Simple3Dim) { + TestSlice( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 1, 2, {2, 3, 2}, + {2, 3, 7, 8, 12, 13, 2, 3, 7, 8, 12, 13}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/splice.cc b/mace/ops/splice.cc new file mode 100644 index 00000000..9093b8aa --- /dev/null +++ b/mace/ops/splice.cc @@ -0,0 +1,120 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for SpliceComponent in Kaldi. +// It splices a context window of frames together [over time] +// (copy and append the frame whose time-index in in context_) +// The context_ values indicate which frame (over time) to splice. +// if context value is less than the first time-index, +// copy and append the first frame's dada, +// when context value is larger than frame's count, +// copy and append the last frame's data. +// i.e., give input data: [[1, 2, 3], [4, 5, 6]], +// with input-dim = 3, frame count = 2, context = [-1, 0, 1] +// Then, the output should be: +// [1, 2, 3, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 4, 5, 6] +// if const_component_dim_ != 0, const_dim_ will be used to determine which +// row of "in" we copy the last part of each row of "out" from (this part is +// not subject to splicing, it's assumed constant for each frame of "input". + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class SpliceOp; + +template +class SpliceOp : public Operation { + public: + explicit SpliceOp(OpConstructContext *context) + : Operation(context), + context_(Operation::GetRepeatedArgs("context")), + const_dim_( + Operation::GetOptionalArg("const_component_dim", 0)) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + MACE_CHECK(context_.size() > 0) + << "The context param should not be empty in Splice Op."; + + Tensor *output = this->Output(0); + const std::vector &input_shape = input->shape(); + + const index_t frames = + std::accumulate(input->shape().begin(), input->shape().end() - 1, 1, + std::multiplies()); + + const index_t rank = input->dim_size(); + const index_t input_dim = input_shape[rank - 1]; + + const index_t num_splice = static_cast(context_.size()); + const index_t dim = input_dim - const_dim_; + MACE_CHECK(input_dim > const_dim_, + "input dim should be greater than const dim."); + const index_t output_dim = dim * num_splice + const_dim_; + + std::vector output_shape = input->shape(); + output_shape[rank - 1] = output_dim; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < frames; ++i) { + for (index_t c = 0; c < num_splice; ++c) { + const index_t offset = + Clamp(context_[c] + i, 0, frames - 1); + T *output_base = output_data + i * output_dim + c * dim; + const T *input_base = input_data + offset * input_dim; + memcpy(output_base, input_base, dim * sizeof(T)); + } + } + + if (const_dim_ > 0) { + const index_t output_offset = output_dim - const_dim_; + const index_t input_offset = dim; +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < frames; ++i) { + index_t offset = i + context_[0] >= 0 ? i + context_[0] : 0; + T *output_base = output_data + i * output_dim; + const T *input_base = input_data + offset * input_dim; + memcpy(output_base + output_offset, + input_base + input_offset, + const_dim_ * sizeof(T)); + } + } + return MaceStatus::MACE_SUCCESS; + } + + private: + std::vector context_; + int const_dim_; +}; + +void RegisterSplice(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "Splice", SpliceOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/splice_benchmark.cc b/mace/ops/splice_benchmark.cc new file mode 100644 index 00000000..253808b8 --- /dev/null +++ b/mace/ops/splice_benchmark.cc @@ -0,0 +1,92 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void BMSpliceHelper(int iters, + const std::vector &input_shape, + const index_t left_context, + const index_t right_context, + const int const_component_dim) { + mace::testing::StopTiming(); + + // Construct graph + OpsTestNet net; + + const int num_splice = left_context + right_context + 1; + std::vector contexts(num_splice); + for (int i = 0; i < num_splice; ++i) { + contexts[i] = left_context + i; + } + const index_t input_size = std::accumulate(input_shape.begin(), + input_shape.end(), + 1, + std::multiplies()); + std::vector input_data(input_size); + GenerateRandomRealTypeData(input_shape, &input_data); + net.AddInputFromArray("Input", input_shape, input_data); + + OpDefBuilder("Splice", "SpliceTest") + .Input("Input") + .Output("Output") + .AddIntsArg("context", contexts) + .AddIntArg("const_component_dim", const_component_dim) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + net.Sync(); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, TYPE, DEVICE) \ + static void \ + MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::MacsProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + BMSpliceHelper(iters, {N, H, W}, L, R, C); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE) + +#define MACE_BM_SPLICE(N, H, W, L, R, C) \ + MACE_BM_SPLICE_MACRO(N, H, W, L, R, C, float, CPU); + +MACE_BM_SPLICE(1, 32, 32, 5, 5, 10); +MACE_BM_SPLICE(1, 32, 32, 7, 7, 5); +MACE_BM_SPLICE(1, 32, 32, 3, 3, 20); +MACE_BM_SPLICE(1, 128, 128, 9, 9, 100); +MACE_BM_SPLICE(1, 128, 128, 7, 7, 100); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/splice_test.cc b/mace/ops/splice_test.cc new file mode 100644 index 00000000..60e1652a --- /dev/null +++ b/mace/ops/splice_test.cc @@ -0,0 +1,84 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SpliceOpTest : public OpsTestBase {}; + +namespace { +template +void TestSplice(const std::vector &input_shape, + const std::vector &input, + const std::vector &context, + const int const_dim, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("Splice", "SpliceTest") + .Input("Input") + .Output("Output") + .AddIntsArg("context", context) + .AddIntArg("const_component_dim", const_dim) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(SpliceOpTest, WithoutConstDim) { + TestSplice( + {1, 7, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, + {-2, -1, 0, 1, 2}, 0, + {1, 7, 10}, + {1, 2, 1, 2, 1, 2, 3, 4, 5, 6, + 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 7, 8, 9, 10, 11, 12, 13, 14, 13, 14, + 9, 10, 11, 12, 13, 14, 13, 14, 13, 14}); +} + +TEST_F(SpliceOpTest, WithConstDim) { + TestSplice( + {1, 5, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, + {-2, -1, 0, 1, 2}, 7, + {1, 5, 22}, + {1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, + 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11, + 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12}); +} +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/sum_group.cc b/mace/ops/sum_group.cc new file mode 100644 index 00000000..21c83b68 --- /dev/null +++ b/mace/ops/sum_group.cc @@ -0,0 +1,107 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for SumGroupComponent in Kaldi. +// It's used to sum up groups of posteriors, +// and to introduce a kind of Gaussian-mixture-model-like +// idea into neural nets. + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class SumGroupOp; + +template +class SumGroupOp : public Operation { + public: + explicit SumGroupOp(OpConstructContext *context) + : Operation(context) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + MACE_CHECK(this->InputSize() >= 2, + "SumGroup should have at least 2 inputs."); + const Tensor *input = this->Input(0); + // Sizes-input gets a vector saying, for + // each output-dim, how many + // inputs data were summed over. + const Tensor *sizes = this->Input(1); + Tensor *output = this->Output(0); + MACE_CHECK(input->dim_size() >= 1, + "SumGroup's input's rank should be >= 1."); + MACE_CHECK(sizes->dim_size() == 1, + "SumGroup's sizes input should be a vector."); + + const std::vector &input_shape = input->shape(); + const index_t bh = + std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, + std::multiplies()); + std::vector output_shape(input_shape); + const index_t output_dim = sizes->dim(0); + const index_t dim_size = input->dim_size(); + const index_t input_dim = input_shape[dim_size -1]; + output_shape[dim_size - 1] = output_dim; + + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + Tensor::MappingGuard guard_input(input); + Tensor::MappingGuard guard_sizes(sizes); + Tensor::MappingGuard guard_output(output); + const T *input_data = input->data(); + const int *sizes_data = sizes->data(); + T *output_data = output->mutable_data(); + + std::vector> + sum_indexes(static_cast(output_dim)); + + int cur_index = 0; + for (index_t i = 0; i < output_dim; ++i) { + int size_value = sizes_data[i]; + MACE_CHECK(size_value > 0, "size value should be > 0"); + sum_indexes[i].first = cur_index; + cur_index += size_value; + sum_indexes[i].second = cur_index; + MACE_CHECK(cur_index <= input_dim) + << "size value over-ranged:" << cur_index << "<=" << input_dim; + } + +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < bh; ++i) { + for (index_t j = 0; j < output_dim; ++j) { + int start_col = sum_indexes[j].first; + int end_col = sum_indexes[j].second; + T sum = 0; + for (int src_col = start_col; src_col < end_col; ++src_col) { + sum += input_data[i * input_dim + src_col]; + } + output_data[i * output_dim + j] = sum; + } + } + + return MaceStatus::MACE_SUCCESS; + } +}; + +void RegisterSumGroup(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "SumGroup", SumGroupOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/sum_group_benchmark.cc b/mace/ops/sum_group_benchmark.cc new file mode 100644 index 00000000..bb3b20e8 --- /dev/null +++ b/mace/ops/sum_group_benchmark.cc @@ -0,0 +1,75 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void SumGroupBenchmark(int iters, int n, int h, int w) { + mace::testing::StopTiming(); + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", {n, h, w}); + net.AddRepeatedInput("Sizes", + {w / 2}, + 2); + OpDefBuilder("SumGroup", "SumGroupBM") + .Input("Input") + .Input("Sizes") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_SUMGROUP_MACRO(N, H, W, TYPE, DEVICE) \ + static void \ + MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SumGroupBenchmark(iters, N, H, W); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_SUMGROUP_##N##_##H##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_SUMGROUP(N, H, W) \ + MACE_BM_SUMGROUP_MACRO(N, H, W, float, CPU); + +MACE_BM_SUMGROUP(1, 10, 256); +MACE_BM_SUMGROUP(1, 20, 128); +MACE_BM_SUMGROUP(1, 10, 128); +MACE_BM_SUMGROUP(1, 20, 512); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/sum_group_test.cc b/mace/ops/sum_group_test.cc new file mode 100644 index 00000000..e5a4ef90 --- /dev/null +++ b/mace/ops/sum_group_test.cc @@ -0,0 +1,71 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SumGroupOpTest : public OpsTestBase {}; + +namespace { +template +void TestSumGroup(const std::vector &input_shape, + const std::vector &input, + const std::vector &sizes, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + const index_t output_dim = sizes.size(); + net.AddInputFromArray(MakeString("Sizes"), + {output_dim}, + sizes); + + OpDefBuilder("SumGroup", "SumGroupTest") + .Input("Input") + .Input("Sizes") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(SumGroupOpTest, SimpleTest) { + TestSumGroup( + {1, 5, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, + {2, 1, 2, 3, 2}, + {1, 5, 5}, + {3, 3, 9, 21, 19, + 5, 4, 11, 24, 21, + 7, 5, 13, 27, 23, + 9, 6, 15, 30, 25, + 11, 7, 17, 33, 27}); +} +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/target_rms_norm.cc b/mace/ops/target_rms_norm.cc new file mode 100644 index 00000000..7b769fe7 --- /dev/null +++ b/mace/ops/target_rms_norm.cc @@ -0,0 +1,116 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This op is implemented for kaldi's NormalizeComponent. +// The output y_i = scale * x_i, +// and we want the RMS value of the y_i equals to target_rms, +// so y^t y = Dim * target_rms^2 (if y is one row of the input). +// Dim is the length of a row. +// we need the scale = 1.0 / sqrt(x^t x / (Dim * target_rms^2)). + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class TargetRMSNormOp; + +template +class TargetRMSNormOp : public Operation { + public: + explicit TargetRMSNormOp(OpConstructContext *context) + : Operation(context), + target_rms_(Operation::GetOptionalArg("target_rms", 1.0)) {} + + // Calculate the square sum of an array + float SquareSum(const float *data, const index_t data_len) { + const int num_parts = 4; + float result = 0.0f; + if (data_len <= 2 * num_parts) { + for (index_t i = 0; i < data_len; ++i) { + result += data[i] * data[i]; + } + } else { + const index_t part_len = data_len / num_parts; + const index_t left_len = data_len % num_parts; + float results[4] = {0.f, 0.f, 0.f, 0.f}; + for (index_t i = 0; i < num_parts; ++i) { + for (index_t j = 0; j < part_len; ++j) { + results[i] += data[i * part_len + j] * data[i * part_len + j]; + } + } + for (index_t k = 0; k < left_len; ++k) { + float d = data[num_parts * part_len + k]; + results[3] += d * d; + } + + for (index_t i = 0; i < num_parts; ++i) { + result += results[i]; + } + } + + return result; + } + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + const std::vector &input_shape = input->shape(); + const index_t dim_size = input->dim_size(); + MACE_CHECK(dim_size >= 1, + "TargetRMSNorm's input dim size should be >= 1."); + const index_t dim = input_shape[dim_size -1]; + MACE_CHECK(dim > 0 && target_rms_ > 0, + "Both input dim and target rms should be greater than zero."); + const index_t bh = + std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, + std::multiplies()); + const float d_scale = dim * target_rms_ * target_rms_; + + MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + + Tensor::MappingGuard guard_input(input); + Tensor::MappingGuard guard_output(output); + + const float *input_data = input->data(); + float *output_data = output->mutable_data(); + +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < bh; ++i) { + float scale = SquareSum(input_data + i * dim, dim); + scale = static_cast(1.0 / std::sqrt(scale / d_scale)); + for (index_t j = 0; j < dim; ++j) { + output_data[i * dim + j] = input_data[i * dim + j] * scale; + } + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + float target_rms_; +}; + +void RegisterTargetRMSNorm(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "TargetRMSNorm", TargetRMSNormOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/target_rms_norm_benchmark.cc b/mace/ops/target_rms_norm_benchmark.cc new file mode 100644 index 00000000..d496bb8b --- /dev/null +++ b/mace/ops/target_rms_norm_benchmark.cc @@ -0,0 +1,74 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void TargetRMSNormBenchmark(int iters, int n, int h, int w, float target_rms) { + mace::testing::StopTiming(); + + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", {n, h, w}); + + OpDefBuilder("TargetRMSNorm", "TargetRMSNormBM") + .Input("Input") + .AddFloatArg("target_rms", target_rms) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, TYPE, DEVICE) \ + static void \ + MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TargetRMSNormBenchmark(iters, N, H, W, RMS); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_TARGETRMSNORM_##N##_##H##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_TARGETRMSNORM(N, H, W, RMS) \ + MACE_BM_TARGETRMSNORM_MACRO(N, H, W, RMS, float, CPU); + +MACE_BM_TARGETRMSNORM(1, 10, 256, 1.0); +MACE_BM_TARGETRMSNORM(1, 20, 128, 2.0); +MACE_BM_TARGETRMSNORM(1, 10, 128, 0.5); +MACE_BM_TARGETRMSNORM(1, 20, 512, 1.0); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/target_rms_norm_test.cc b/mace/ops/target_rms_norm_test.cc new file mode 100644 index 00000000..95082447 --- /dev/null +++ b/mace/ops/target_rms_norm_test.cc @@ -0,0 +1,62 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class TargetRMSNormOpTest : public OpsTestBase {}; + +namespace { +template +void TestTargetRMSNorm(const std::vector &input_shape, + const std::vector &input, + const float target_rms, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("TargetRMSNorm", "TargetRMSNormTest") + .Input("Input") + .AddFloatArg("target_rms", target_rms) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", input_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(TargetRMSNormOpTest, SimpleTest) { + TestTargetRMSNorm( + {1, 3, 3}, + {1, 2, 3, + 2, 3, 4, + 3, 4, 5}, + 1.0, + {0.46291, 0.92582, 1.38873, + 0.64327, 0.9649, 1.28654, + 0.734847, 0.979796, 1.224745}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/time_offset.cc b/mace/ops/time_offset.cc new file mode 100644 index 00000000..2fff53c1 --- /dev/null +++ b/mace/ops/time_offset.cc @@ -0,0 +1,80 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for offset descriptor in Kaldi. +// It defines time offset. + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class TimeOffsetOp; + +template +class TimeOffsetOp : public Operation { + public: + explicit TimeOffsetOp(OpConstructContext *context) + : Operation(context), + offset_(Operation::GetOptionalArg("offset", 0)) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + + index_t rank = input->dim_size(); + MACE_CHECK(rank >= 2, "input's rank should >= 2."); + const std::vector &input_shape = input->shape(); + const index_t batch = + std::accumulate(input_shape.begin(), input_shape.end() - 2, 1, + std::multiplies()); + const index_t frames = input_shape[rank - 2]; + const index_t input_dim = input_shape[rank - 1]; + MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t i = 0; i < batch; ++i) { + for (index_t j = 0; j < frames; ++j) { + index_t time_index = offset_ + j; + index_t index = Clamp(time_index, 0, frames - 1); + T *output_base = output_data + (i * frames + j) * input_dim; + const T *input_base = input_data + (i * frames + index) * input_dim; + memcpy(output_base, input_base, input_dim * sizeof(T)); + } + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + int offset_; +}; + +void RegisterTimeOffset(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "TimeOffset", TimeOffsetOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/time_offset_benchmark.cc b/mace/ops/time_offset_benchmark.cc new file mode 100644 index 00000000..82ea9967 --- /dev/null +++ b/mace/ops/time_offset_benchmark.cc @@ -0,0 +1,78 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void TimeOffsetBenchmark(int iters, + std::vector shape, + int offset) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", shape); + + OpDefBuilder("TimeOffset", "TimeOffsetBM") + .Input("Input") + .Output("Output") + .AddIntArg("offset", offset) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_TIMEOFFSET2D_MACRO(H, W, TYPE, DEVICE) \ + static void MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE(\ + int iters) { \ + const int64_t tot = static_cast(iters) * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TimeOffsetBenchmark(iters, {H, W}, 1); \ + } \ + MACE_BENCHMARK(MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE) \ + +#define MACE_BM_TIMEOFFSET2D(H, W) \ + MACE_BM_TIMEOFFSET2D_MACRO(H, W, float, CPU); + + +MACE_BM_TIMEOFFSET2D(20, 128); +MACE_BM_TIMEOFFSET2D(40, 512); +MACE_BM_TIMEOFFSET2D(1, 1024); +MACE_BM_TIMEOFFSET2D(20, 2048); +MACE_BM_TIMEOFFSET2D(20, 512); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/time_offset_test.cc b/mace/ops/time_offset_test.cc new file mode 100644 index 00000000..b32b8c52 --- /dev/null +++ b/mace/ops/time_offset_test.cc @@ -0,0 +1,125 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class TimeOffsetOpTest : public OpsTestBase {}; + +namespace { +template +void TestTimeOffset(const std::vector &input_shape, + const std::vector &input, + const int offset, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("TimeOffset", "TimeOffsetTest") + .Input("Input") + .Output("Output") + .AddIntArg("offset", offset) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", input_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(TimeOffsetOpTest, Simple2Dim) { + TestTimeOffset( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + -2, + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); + + TestTimeOffset( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + -1, + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + TestTimeOffset( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 0, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + TestTimeOffset( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 1, + {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); + + TestTimeOffset( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 2, + {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); +} + + +TEST_F(TimeOffsetOpTest, Simple3Dim) { + TestTimeOffset( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + -2, + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, + 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); + + TestTimeOffset( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + -1, + {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + TestTimeOffset( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 0, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + TestTimeOffset( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 1, + {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); + + TestTimeOffset( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 2, + {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index d4d326ef..c6552100 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -98,6 +98,7 @@ class FrameworkType(Enum): MaceSupportedOps = [ 'Activation', 'AddN', + 'Affine', 'ArgMax', 'BatchNorm', 'BatchToSpaceND', @@ -121,8 +122,10 @@ MaceSupportedOps = [ 'InferConv2dShape', 'LocalResponseNorm', 'LSTMCell', + # 'LstmNonlinear', 'MatMul', 'Pad', + 'PNorm', 'Pooling', 'PriorBox', 'Proposal', @@ -134,6 +137,8 @@ MaceSupportedOps = [ 'ResizeNearestNeighbor', 'Reverse', 'ScalarMath', + 'Slice', + 'Splice', 'Split', 'Shape', 'Squeeze', @@ -144,6 +149,9 @@ MaceSupportedOps = [ 'SpaceToBatchND', 'SpaceToDepth', 'SqrDiffMean', + 'SumGroup', + 'TargetRMSNorm', + 'TimeOffset', 'Transpose', 'WinogradInverseTransform', 'WinogradTransform', @@ -159,6 +167,7 @@ class MaceKeyword(object): mace_buffer_type = 'buffer_type' # arg related str mace_padding_str = 'padding' + mace_padding_type_str = 'padding' mace_padding_values_str = 'padding_values' mace_strides_str = 'strides' mace_dilations_str = 'dilations' @@ -473,6 +482,7 @@ class ConverterOption(object): # Model data format related transformation TransformerRule.TRANSPOSE_FILTERS, TransformerRule.TRANSPOSE_DATA_FORMAT, + TransformerRule.TRANSPOSE_MATMUL_WEIGHT, # Add winograd argument TransformerRule.ADD_WINOGRAD_ARG, # Mace model structure related transformation diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 2f3570d5..6befa478 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -33,21 +33,23 @@ from mace.python.tools.converter_tool.base_converter import MaceKeyword from mace.python.tools.converter_tool.base_converter import ConverterUtil from mace.python.tools.convert_util import mace_check +import numpy as np + import onnx import onnx.utils -from onnx import helper, shape_inference, numpy_helper, optimizer -import numpy as np -from onnx import mapping -from onnx import TensorProto +from onnx import mapping, numpy_helper, TensorProto from numbers import Number +IS_PYTHON3 = sys.version_info > (3,) OnnxSupportedOps = [ 'Abs', # 'Acos', # 'Acosh', 'Add', + 'Affine', # 'And', + 'Append', 'ArgMax', 'ArgMin', # 'Asin', @@ -68,6 +70,7 @@ OnnxSupportedOps = [ # 'Cos', # 'Cosh', 'DepthToSpace', + 'DimRange', 'Div', 'Dropout', 'Elu', @@ -88,10 +91,12 @@ OnnxSupportedOps = [ # 'Hardmax', 'Identity', # 'If', + 'IfDefined', 'ImageScaler', # 'InstanceNormalization', # 'LRN', - # 'LSTM', + 'LSTM', + # 'LstmNonlinear', 'LeakyRelu', # 'Less', # 'Log', @@ -109,11 +114,15 @@ OnnxSupportedOps = [ 'Mul', # 'Multinomial', 'Neg', + 'Normalize', # 'Not', + 'Offset', # 'OneHot', # 'Or', 'PRelu', - 'Pad', + # 'Pad', + 'Padding', + 'PNorm', 'Pow', # 'RNN', # 'RandomNormal', @@ -133,6 +142,7 @@ OnnxSupportedOps = [ # 'ReduceSumSquare', 'Relu', 'Reshape', + 'Scale', # 'Scan', # 'Selu', 'Shape', @@ -140,18 +150,21 @@ OnnxSupportedOps = [ # 'Sin', # 'Sinh', # 'Size', - # 'Slice', + 'Slice', 'Softmax', # 'Softplus', # 'Softsign', 'SpaceToDepth', + 'Splice', 'Split', 'Sqrt', 'Squeeze', 'Sub', 'Sum', + 'SumGroup', # 'Tan', 'Tanh', + 'TargetRMSNorm', # 'Tile', # 'TopK', 'Transpose', @@ -188,7 +201,7 @@ def convert_onnx_attribute_proto(attr_proto): return attr_proto.i elif attr_proto.HasField('s'): return str(attr_proto.s, 'utf-8')\ - if sys.version_info.major == 3 else attr_proto.s + if IS_PYTHON3 else attr_proto.s elif attr_proto.HasField('t'): return attr_proto.t # this is a proto! elif attr_proto.floats: @@ -273,6 +286,7 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.Equal.name: EltwiseType.EQUAL, OnnxOpType.Sqrt.name: EltwiseType.POW, OnnxOpType.Reciprocal.name: EltwiseType.POW, + OnnxOpType.Scale.name: EltwiseType.PROD, } reduce_type = { @@ -296,6 +310,8 @@ class OnnxConverter(base_converter.ConverterInterface): self._op_converters = { OnnxOpType.Abs.name: self.convert_eltwise, OnnxOpType.Add.name: self.convert_eltwise, + OnnxOpType.Affine.name: self.convert_affine, + OnnxOpType.Append.name: self.convert_concat, OnnxOpType.ArgMax.name: self.convert_argmax, OnnxOpType.ArgMin.name: self.convert_argmax, OnnxOpType.AveragePool.name: self.convert_pooling, @@ -306,6 +322,7 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.ConvTranspose.name: self.convert_deconv, OnnxOpType.DepthToSpace.name: self.convert_depth_space, OnnxOpType.Dropout.name: self.convert_identity, + OnnxOpType.DimRange.name: self.convert_dim_range, OnnxOpType.Div.name: self.convert_eltwise, OnnxOpType.Equal.name: self.convert_eltwise, OnnxOpType.Gather.name: self.convert_gather, @@ -313,47 +330,71 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.GlobalAveragePool.name: self.convert_reduce, OnnxOpType.GlobalMaxPool.name: self.convert_reduce, OnnxOpType.Identity.name: self.convert_identity, + OnnxOpType.IfDefined.name: self.convert_identity, OnnxOpType.ImageScaler.name: self.convert_imagescaler, OnnxOpType.LeakyRelu.name: self.convert_activation, + # OnnxOpType.LogSoftmax.name: self.convert_softmax, + OnnxOpType.LSTM.name: self.convert_lstm, + # OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear, OnnxOpType.Max.name: self.convert_eltwise, OnnxOpType.MaxPool.name: self.convert_pooling, OnnxOpType.MatMul.name: self.convert_matmul, OnnxOpType.Min.name: self.convert_eltwise, OnnxOpType.Mul.name: self.convert_eltwise, OnnxOpType.Neg.name: self.convert_eltwise, - OnnxOpType.Pad.name: self.convert_pad, + OnnxOpType.Normalize: self.convert_normalize, + OnnxOpType.Offset.name: self.convert_timeoffset, + OnnxOpType.Padding.name: self.convert_identity, + OnnxOpType.PNorm.name: self.convert_pnorm, OnnxOpType.Pow.name: self.convert_eltwise, OnnxOpType.PRelu.name: self.convert_activation, OnnxOpType.Relu.name: self.convert_activation, OnnxOpType.Reshape.name: self.convert_reshape, OnnxOpType.Reciprocal.name: self.convert_eltwise, + OnnxOpType.Scale.name: self.convert_eltwise, OnnxOpType.Sigmoid.name: self.convert_activation, + OnnxOpType.Slice.name: self.convert_slice, OnnxOpType.Softmax.name: self.convert_softmax, OnnxOpType.SpaceToDepth.name: self.convert_depth_space, + OnnxOpType.Splice.name: self.convert_splice, OnnxOpType.Split.name: self.convert_split, OnnxOpType.Sqrt.name: self.convert_eltwise, OnnxOpType.Squeeze.name: self.convert_squeeze, OnnxOpType.Sub.name: self.convert_eltwise, OnnxOpType.Sum.name: self.convert_eltwise, + OnnxOpType.SumGroup.name: self.convert_sum_group, OnnxOpType.Tanh.name: self.convert_activation, + OnnxOpType.TargetRMSNorm: self.convert_target_rms_norm, OnnxOpType.Transpose.name: self.convert_transpose, } self._option = option self._mace_net_def = mace_pb2.NetDef() + self._data_format = DataFormat.NCHW ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW) onnx_model = onnx.load(src_model_file) - polished_model = onnx.utils.polish_model(onnx_model) - - print "onnx model IR version: ", onnx_model.ir_version - print "onnx model opset import: ", onnx_model.opset_import - - self._onnx_model = shape_inference.infer_shapes(polished_model) + ir_version = onnx_model.ir_version + opset_imp = onnx_model.opset_import + + polish_available = True + print "onnx model IR version: ", ir_version + for imp in opset_imp: + domain = imp.domain + version = imp.version + print "constains ops domain: ", domain, "version:", version + if 'kaldi2onnx' in domain: + polish_available = False + self._data_format = DataFormat.DF_NONE + if polish_available: + onnx_model = onnx.utils.polish_model(onnx_model) + + self._onnx_model = onnx_model self._graph_shapes_dict = {} self._consts = {} self._replace_tensors = {} - def print_graph_info(self, graph): + @staticmethod + def print_graph_info(graph): for value_info in graph.value_info: print "value info:", value_info for value_info in graph.input: @@ -368,12 +409,12 @@ class OnnxConverter(base_converter.ConverterInterface): if t: shape_dict[value_info.name] = t - for value_info in graph.value_info: - extract_value_info(self._graph_shapes_dict, value_info) - for value_info in graph.input: - extract_value_info(self._graph_shapes_dict, value_info) - for value_info in graph.output: - extract_value_info(self._graph_shapes_dict, value_info) + for vi in graph.value_info: + extract_value_info(self._graph_shapes_dict, vi) + for vi in graph.input: + extract_value_info(self._graph_shapes_dict, vi) + for vi in graph.output: + extract_value_info(self._graph_shapes_dict, vi) def add_tensor(self, name, shape, data_type, value): tensor = self._mace_net_def.tensors.add() @@ -387,11 +428,6 @@ class OnnxConverter(base_converter.ConverterInterface): self.extract_shape_info(graph_def) self.convert_tensors(graph_def) self.convert_ops(graph_def) - # self.print_graph_info(graph_def) - # shape_inferer = mace_shape_inference.ShapeInference( - # self._mace_net_def, - # self._option.input_nodes.values()) - # shape_inferer.run() return self._mace_net_def def add_stride_pad_kernel_arg(self, attrs, op_def): @@ -435,6 +471,32 @@ class OnnxConverter(base_converter.ConverterInterface): padding_arg.name = MaceKeyword.mace_padding_values_str padding_arg.ints.extend(pad) + def remove_node(self, node): + input_name = node.inputs[0] + output_name = node.outputs[0] + self._replace_tensors[output_name] = input_name + + @staticmethod + def squeeze_shape(shape, axis): + new_shape = [] + if len(axis) > 0: + for i in range(len(shape)): + if i not in axis: + new_shape.append(shape[i]) + else: + new_shape = shape + return new_shape + + @staticmethod + def transpose_const(tensor): + shape = tensor.dims + mace_check(len(shape) == 2, "gemm only supports 2-dim input.") + tensor_data = np.array(tensor.float_data).reshape( + shape[0], shape[1]) + tensor_data = tensor_data.transpose(1, 0) + tensor.float_data[:] = tensor_data.flat + tensor.dims[:] = tensor_data.shape + def convert_ops(self, graph_def): for n in graph_def.node: node = OnnxNode(n) @@ -471,7 +533,7 @@ class OnnxConverter(base_converter.ConverterInterface): "Not supported tensor type: %s" % data_type) self._consts[tensor.name] = tensor - def convert_general_op(self, node): + def convert_general_op(self, node, with_shape=True): op = self._mace_net_def.op.add() op.name = node.name @@ -481,9 +543,11 @@ class OnnxConverter(base_converter.ConverterInterface): op.input.append(input) for output in node.outputs: op.output.append(output) - output_shape = op.output_shape.add() - shape_info = self._graph_shapes_dict[output] - output_shape.dims.extend(shape_info) + if with_shape: + if output in self._graph_shapes_dict: + output_shape = op.output_shape.add() + shape_info = self._graph_shapes_dict[output] + output_shape.dims.extend(shape_info) data_type_arg = op.arg.add() data_type_arg.name = 'T' @@ -493,91 +557,9 @@ class OnnxConverter(base_converter.ConverterInterface): framework_type_arg.name = MaceKeyword.mace_framework_type_str framework_type_arg.i = FrameworkType.ONNX.value - ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) + ConverterUtil.add_data_format_arg(op, self._data_format) return op - def convert_fused_batchnorm(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.BatchNorm.name - - if "epsilon" in node.attrs: - epsilon_value = node.attrs["epsilon"] - else: - epsilon_value = 1e-5 - - mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.") - - gamma_value = np.array(self._consts[node.inputs[1]].float_data) - beta_value = np.array(self._consts[node.inputs[2]].float_data) - mean_value = np.array(self._consts[node.inputs[3]].float_data) - var_value = np.array(self._consts[node.inputs[4]].float_data) - - scale_name = node.name + 'scale' - offset_name = node.name + 'offset' - scale_value = ( - (1.0 / np.sqrt( - var_value + epsilon_value)) * gamma_value) - offset_value = (-mean_value * scale_value) + beta_value - self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, - scale_value) - self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT, - offset_value) - del op.input[1:] - op.input.extend([scale_name, offset_name]) - del op.output[1:] - del op.output_shape[1:] - - def convert_conv2d(self, node): - op = self.convert_general_op(node) - self.add_stride_pad_kernel_arg(node.attrs, op) - group_arg = op.arg.add() - group_arg.name = MaceKeyword.mace_group_str - if 'group' in node.attrs: - group_val = node.attrs["group"] - else: - group_val = 1 - group_arg.i = group_val - - is_depthwise = False - if group_val > 1: - filter_shape = self._graph_shapes_dict[node.inputs[1]] - mace_check(group_val == filter_shape[0] and - filter_shape[1] == 1, - "Mace does not support group convolution yet") - filter_tensor = self._consts[node.inputs[1]] - new_shape = [filter_shape[1], filter_shape[0], - filter_shape[2], filter_shape[3]] - del filter_tensor.dims[:] - filter_tensor.dims.extend(new_shape) - is_depthwise = True - if is_depthwise: - op.type = MaceOp.DepthwiseConv2d.name - else: - op.type = MaceOp.Conv2D.name - - dilation_arg = op.arg.add() - dilation_arg.name = MaceKeyword.mace_dilations_str - if 'dilations' in node.attrs: - dilation_val = node.attrs["dilations"] - else: - dilation_val = [1, 1] - dilation_arg.ints.extend(dilation_val) - - def convert_biasadd(self, node): - self.convert_general_op(node) - - def convert_concat(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Concat.name - mace_check('axis' in node.attrs, - 'Concat op should have axis attribute.') - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - axis_arg.i = node.attrs['axis'] - axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i - mace_check(axis_arg.i == 1, - "only support concat at channel dimension") - def convert_activation(self, node): op = self.convert_general_op(node) op.type = MaceOp.Activation.name @@ -597,100 +579,12 @@ class OnnxConverter(base_converter.ConverterInterface): alpha_arg.name = MaceKeyword.mace_activation_max_limit_str alpha_arg.f = alpha_value - def convert_pooling(self, node): - op = self.convert_general_op(node) - - op.type = MaceOp.Pooling.name - self.add_stride_pad_kernel_arg(node.attrs, op) - pooling_type_arg = op.arg.add() - pooling_type_arg.name = MaceKeyword.mace_pooling_type_str - pooling_type_arg.i = self.pooling_type_mode[node.op_type].value - - round_mode_arg = op.arg.add() - round_mode_arg.name = MaceKeyword.mace_round_mode_str - round_mode_arg.i = RoundMode.FLOOR.value - - def convert_reshape(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Reshape.name - - def convert_flatten(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Reshape.name - - def remove_node(self, node): - input_name = node.inputs[0] - output_name = node.outputs[0] - self._replace_tensors[output_name] = input_name - - def convert_eltwise(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Eltwise.name - type_arg = op.arg.add() - type_arg.name = MaceKeyword.mace_element_type_str - type_arg.i = self.eltwise_type[node.op_type].value - - if node.op_type == OnnxOpType.Sqrt.name: - value_arg = op.arg.add() - value_arg.name = MaceKeyword.mace_scalar_input_str - value_arg.f = 0.5 - elif node.op_type == OnnxOpType.Reciprocal.name: - value_arg = op.arg.add() - value_arg.name = MaceKeyword.mace_scalar_input_str - value_arg.f = -1 - - def convert_reduce(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Reduce.name - - reduce_type_arg = op.arg.add() - reduce_type_arg.name = MaceKeyword.mace_reduce_type_str - reduce_type_arg.i = self.reduce_type[node.op_type].value - - if node.op_type in [OnnxOpType.GlobalAveragePool.name, - OnnxOpType.GlobalMaxPool.name]: - reduce_dims = [2, 3] - keep_dims = 1 - else: - if 'axes' in node.attrs: - reduce_dims = node.attrs['axes'] - else: - reduce_dims = [] - if 'keepdims' in node.attrs: - keep_dims = node.attrs['keepdims'] - else: - keep_dims = 1 - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - axis_arg.ints.extend(reduce_dims) - - keep_dims_arg = op.arg.add() - keep_dims_arg.name = MaceKeyword.mace_keepdims_str - keep_dims_arg.i = keep_dims - - def convert_imagescaler(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.BatchNorm.name - - scale = node.attrs['scale'] - bias_value = np.array(node.attrs['bias']) - scale_value = scale * np.ones_like(bias_value) - - scale_name = node.name + "_scale" - bias_name = node.name + "_bias" - self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, - scale_value) - self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT, - bias_value) - op.input.extend([scale_name, bias_name]) - - def convert_matmul(self, node): + def convert_affine(self, node): op = self.convert_general_op(node) op.type = MaceOp.MatMul.name - - def convert_softmax(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Softmax.name + transpose_b_arg = op.arg.add() + transpose_b_arg.name = MaceKeyword.mace_transpose_b_str + transpose_b_arg.i = 1 def convert_argmax(self, node): op = self.convert_general_op(node) @@ -717,6 +611,10 @@ class OnnxConverter(base_converter.ConverterInterface): min_arg.name = MaceKeyword.mace_argmin_str min_arg.i = 1 + def convert_biasadd(self, node): + self.convert_general_op(node) + op.type = MaceOp.BiasAdd.name + def convert_cast(self, node): op = self.convert_general_op(node) op.type = MaceOp.Cast.name @@ -732,41 +630,49 @@ class OnnxConverter(base_converter.ConverterInterface): else: op.output_type.extend([self._option.data_type]) - def convert_depth_space(self, node): + def convert_concat(self, node): op = self.convert_general_op(node) - if op.type == OnnxOpType.DepthToSpace.name: - op.type = MaceOp.DepthToSpace.name - else: - op.type = MaceOp.SpaceToDepth.name - mace_check(('block_size' in node.attrs), - "depth to space op should have block size attribute.") - block_size = node.attrs['block_size'] - size_arg = op.arg.add() - size_arg.name = MaceKeyword.mace_space_depth_block_size_str - size_arg.i = block_size + op.type = MaceOp.Concat.name + axis_value = 1 + if node.op_type == OnnxOpType.Concat.name: + mace_check('axis' in node.attrs, + 'Concat op should have axis attribute.') + axis_value = node.attrs['axis'] + mace_check(axis_value == 1 or axis_value == -3, + "only support concat at channel dimension") + elif node.op_type == OnnxOpType.Append.name: + axis_value = 2 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value - def convert_deconv(self, node): + def convert_conv2d(self, node): op = self.convert_general_op(node) - self.add_stride_pad_kernel_arg(node.attrs, op) - + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str if 'group' in node.attrs: group_val = node.attrs["group"] else: group_val = 1 + group_arg.i = group_val + + is_depthwise = False if group_val > 1: - op.type = MaceOp.DepthwiseDeconv2d.name filter_shape = self._graph_shapes_dict[node.inputs[1]] + mace_check(group_val == filter_shape[0] and + filter_shape[1] == 1, + "Mace does not support group convolution yet") filter_tensor = self._consts[node.inputs[1]] new_shape = [filter_shape[1], filter_shape[0], filter_shape[2], filter_shape[3]] del filter_tensor.dims[:] filter_tensor.dims.extend(new_shape) + is_depthwise = True + if is_depthwise: + op.type = MaceOp.DepthwiseConv2d.name else: - op.type = MaceOp.Deconv2D.name - group_arg = op.arg.add() - group_arg.name = MaceKeyword.mace_group_str - group_arg.i = group_val + op.type = MaceOp.Conv2D.name dilation_arg = op.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str @@ -775,16 +681,47 @@ class OnnxConverter(base_converter.ConverterInterface): else: dilation_val = [1, 1] dilation_arg.ints.extend(dilation_val) - mace_check(dilation_val == [1, 1], - "not support convtranspose with dilation != 1 yet.") - mace_check('output_padding' not in node.attrs, - "not support convtranspose with output_padding yet.") - mace_check('output_shape' not in node.attrs, - "not support convtranspose with output_shape yet.") - # TODO: if output shape specified, calculate padding value - # if 'output_padding' in node.attrs: - # output_padding = node.attrs['output_padding'] + def convert_deconv(self, node): + op = self.convert_general_op(node) + + self.add_stride_pad_kernel_arg(node.attrs, op) + + if 'group' in node.attrs: + group_val = node.attrs["group"] + else: + group_val = 1 + if group_val > 1: + op.type = MaceOp.DepthwiseDeconv2d.name + filter_shape = self._graph_shapes_dict[node.inputs[1]] + filter_tensor = self._consts[node.inputs[1]] + new_shape = [filter_shape[1], filter_shape[0], + filter_shape[2], filter_shape[3]] + del filter_tensor.dims[:] + filter_tensor.dims.extend(new_shape) + else: + op.type = MaceOp.Deconv2D.name + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str + group_arg.i = group_val + + dilation_arg = op.arg.add() + dilation_arg.name = MaceKeyword.mace_dilations_str + if 'dilations' in node.attrs: + dilation_val = node.attrs["dilations"] + else: + dilation_val = [1, 1] + dilation_arg.ints.extend(dilation_val) + mace_check(dilation_val == [1, 1], + "not support convtranspose with dilation != 1 yet.") + + mace_check('output_padding' not in node.attrs, + "not support convtranspose with output_padding yet.") + mace_check('output_shape' not in node.attrs, + "not support convtranspose with output_shape yet.") + # TODO: if output shape specified, calculate padding value + # if 'output_padding' in node.attrs: + # output_padding = node.attrs['output_padding'] # output_padding_arg = op.arg.add() # output_padding_arg.name = MaceKeyword.mace_output_padding_str # output_padding_arg.ints.extend(output_padding) @@ -794,43 +731,98 @@ class OnnxConverter(base_converter.ConverterInterface): # output_shape_arg.name = MaceKeyword.mace_output_shape_str # output_shape_arg.ints.extend(output_shape) - def convert_nop(self, node): - pass + def convert_depth_space(self, node): + op = self.convert_general_op(node) + if op.type == OnnxOpType.DepthToSpace.name: + op.type = MaceOp.DepthToSpace.name + else: + op.type = MaceOp.SpaceToDepth.name + mace_check(('block_size' in node.attrs), + "depth to space op should have block size attribute.") + block_size = node.attrs['block_size'] + size_arg = op.arg.add() + size_arg.name = MaceKeyword.mace_space_depth_block_size_str + size_arg.i = block_size - def convert_identity(self, node): + def convert_dim_range(self, node): op = self.convert_general_op(node) - op.type = MaceOp.Identity.name + op.type = MaceOp.Slice.name + + mace_check('offset' in node.attrs, + "Attribute dim required!") + mace_check('output_dim' in node.attrs, + "Attribute output_dim required!") + offset = node.attrs['offset'] + starts_arg = op.arg.add() + starts_arg.name = 'starts' + starts_arg.ints.append(offset) + output_dim = node.attrs['output_dim'] + ends_arg = op.arg.add() + ends_arg.name = 'output_dim' + ends_arg.ints.append(output_dim) + axes_arg = op.arg.add() + axes_arg.name = 'axes' + axes_arg.ints.append(-1) - def convert_pad(self, node): + def convert_eltwise(self, node): op = self.convert_general_op(node) - op.type = MaceOp.Pad.name + op.type = MaceOp.Eltwise.name + type_arg = op.arg.add() + type_arg.name = MaceKeyword.mace_element_type_str + type_arg.i = self.eltwise_type[node.op_type].value - if 'pads' in node.attrs: - paddings_arg = op.arg.add() - paddings_arg.name = MaceKeyword.mace_paddings_str - paddings_value = node.attrs['pads'] - paddings_arg.ints.extend(paddings_value) + if node.op_type == OnnxOpType.Sqrt.name: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = 0.5 + elif node.op_type == OnnxOpType.Reciprocal.name: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = -1 + elif node.op_type == OnnxOpType.Scale.name and 'scale' in node.attrs: + value = node.attrs['scale'] + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = value - if 'value' in node.attrs: - constant_value_arg = op.arg.add() - constant_value_arg.name = MaceKeyword.mace_constant_value_str - constant_value_arg.i = node.attrs['value'] + def convert_flatten(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reshape.name - def convert_gather(self, node): + def convert_fused_batchnorm(self, node): op = self.convert_general_op(node) - op.type = MaceOp.Gather.name + op.type = MaceOp.BatchNorm.name - if 'axis' in node.attrs: - value = node.attrs['axis'] + if "epsilon" in node.attrs: + epsilon_value = node.attrs["epsilon"] else: - value = 0 - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - axis_arg.i = value + epsilon_value = 1e-5 - def convert_split(self, node): + mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.") + + gamma_value = np.array(self._consts[node.inputs[1]].float_data) + beta_value = np.array(self._consts[node.inputs[2]].float_data) + mean_value = np.array(self._consts[node.inputs[3]].float_data) + var_value = np.array(self._consts[node.inputs[4]].float_data) + + scale_name = node.name + 'scale' + offset_name = node.name + 'offset' + scale_value = ( + (1.0 / np.sqrt( + var_value + epsilon_value)) * gamma_value) + offset_value = (-mean_value * scale_value) + beta_value + self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, + scale_value) + self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT, + offset_value) + del op.input[1:] + op.input.extend([scale_name, offset_name]) + del op.output[1:] + del op.output_shape[1:] + + def convert_gather(self, node): op = self.convert_general_op(node) - op.type = MaceOp.Split.name + op.type = MaceOp.Gather.name if 'axis' in node.attrs: value = node.attrs['axis'] @@ -840,64 +832,6 @@ class OnnxConverter(base_converter.ConverterInterface): axis_arg.name = MaceKeyword.mace_axis_str axis_arg.i = value - def convert_transpose(self, node): - op = self.convert_general_op(node) - op.type = MaceOp.Transpose.name - - if np.array_equal(perm, ordered_perm): - op.type = MaceOp.Identity.name - del op.input[1:] - if 'perm' in node.attrs: - perm = node.attrs['perm'] - ordered_perm = np.sort(perm) - if np.array_equal(perm, ordered_perm): - op.type = MaceOp.Identity.name - else: - dims_arg = op.arg.add() - dims_arg.name = MaceKeyword.mace_dims_str - dims_arg.ints.extend(perm) - - @staticmethod - def squeeze_shape(shape, axis): - new_shape = [] - if len(axis) > 0: - for i in range(len(shape)): - if i not in axis: - new_shape.append(shape[i]) - else: - new_shape = shape - return new_shape - - def convert_squeeze(self, node): - axis_value = node.attrs['axes'] - if node.inputs[0] in self._consts: - tensor = self._consts[node.inputs[0]] - shape = tensor.dims - new_shape = self.squeeze_shape(shape, axis_value) - del tensor.dims[:] - tensor.dims.extend(new_shape) - self.remove_node(node) - else: - op = self.convert_general_op(node) - op.type = MaceOp.Squeeze.name - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - if 'axis' in node.attrs: - axis_value = node.attrs['axis'] - else: - axis_value = [] - axis_arg.ints.extend(axis_value) - - @staticmethod - def transpose_const(tensor): - shape = tensor.dims - mace_check(len(shape) == 2, "gemm only supports 2-dim input.") - tensor_data = np.array(tensor.float_data).reshape( - shape[0], shape[1]) - tensor_data = tensor_data.transpose(1, 0) - tensor.float_data[:] = tensor_data.flat - tensor.dims[:] = tensor_data.shape - def convert_gemm(self, node): # only supports FullyConnected Style Gemm for now. trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0 @@ -915,7 +849,7 @@ class OnnxConverter(base_converter.ConverterInterface): elif len(shape_b) == 2: tensor_b = self._consts[node.inputs[1]] tensor_data = np.array(tensor_b.float_data).reshape( - shape_b[0], shape_b[1], 1, 1) + shape_b[0], shape_b[1], 1, 1) tensor_b.float_data[:] = tensor_data.flat tensor_b.dims[:] = tensor_data.shape else: @@ -949,4 +883,224 @@ class OnnxConverter(base_converter.ConverterInterface): shape_info = [shape_info[0], shape_info[1], 1, 1] output_shape.dims.extend(shape_info) - return op + def convert_identity(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Identity.name + + def convert_imagescaler(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.BatchNorm.name + + scale = node.attrs['scale'] + bias_value = np.array(node.attrs['bias']) + scale_value = scale * np.ones_like(bias_value) + + scale_name = node.name + "_scale" + bias_name = node.name + "_bias" + self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, + scale_value) + self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT, + bias_value) + op.input.extend([scale_name, bias_name]) + + def convert_lstm(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.LSTMCell.name + + def convert_lstm_nonlinear(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.LstmNonlinear.name + + def convert_matmul(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.MatMul.name + + def convert_nop(self, node): + pass + + def convert_normalize(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.BatchNorm.name + + def convert_pnorm(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.PNorm.name + if 'output_dim' in node.attrs: + output_dim_arg = op.arg.add() + output_dim_arg.name = 'output_dim' + output_dim_arg.i = node.attrs['output_dim'] + if 'p' in node.attrs: + p_value = node.attrs['p'] + mace_check((p_value >= 0) and (p_value <= 2), + "PNorm only supports p = 0, 1, 2") + p_arg = op.arg.add() + p_arg.name = 'p' + p_arg.i = p_value + + def convert_pooling(self, node): + op = self.convert_general_op(node) + + op.type = MaceOp.Pooling.name + self.add_stride_pad_kernel_arg(node.attrs, op) + pooling_type_arg = op.arg.add() + pooling_type_arg.name = MaceKeyword.mace_pooling_type_str + pooling_type_arg.i = self.pooling_type_mode[node.op_type].value + + round_mode_arg = op.arg.add() + round_mode_arg.name = MaceKeyword.mace_round_mode_str + round_mode_arg.i = RoundMode.FLOOR.value + + def convert_reduce(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reduce.name + + reduce_type_arg = op.arg.add() + reduce_type_arg.name = MaceKeyword.mace_reduce_type_str + reduce_type_arg.i = self.reduce_type[node.op_type].value + + if node.op_type in [OnnxOpType.GlobalAveragePool.name, + OnnxOpType.GlobalMaxPool.name]: + reduce_dims = [2, 3] + keep_dims = 1 + else: + if 'axes' in node.attrs: + reduce_dims = node.attrs['axes'] + else: + reduce_dims = [] + if 'keepdims' in node.attrs: + keep_dims = node.attrs['keepdims'] + else: + keep_dims = 1 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.ints.extend(reduce_dims) + + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str + keep_dims_arg.i = keep_dims + + def convert_reshape(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reshape.name + + def convert_slice(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Slice.name + + mace_check('starts' in node.attrs, "Attribute starts required!") + mace_check('ends' in node.attrs, "Attribute ends required!") + starts = node.attrs['starts'] + starts_arg = op.arg.add() + starts_arg.name = 'starts' + starts_arg.ints.extend(starts) + ends = node.attrs['ends'] + ends_arg = op.arg.add() + ends_arg.name = 'ends' + ends_arg.ints.extend(ends) + if 'axes' in node.attrs: + axes = node.attrs['axes'] + axes_arg = op.arg.add() + axes_arg.name = 'axes' + axes_arg.ints.extend(axes) + + def convert_softmax(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Softmax.name + # TODO: add logsoftmax in softmax op + # if node.op_type == OnnxOpType.LogSoftmax.name: + # use_log_arg = op.arg.add() + # use_log_arg.name = 'use_log' + # use_log_arg.i = 1 + + def convert_splice(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Splice.name + if 'context' in node.attrs: + context = node.attrs['context'] + else: + context = [0] + context_arg = op.arg.add() + context_arg.name = 'context' + context_arg.ints.extend(context) + if 'const_component_dim' in node.attrs: + const_dim = node.attrs['const_component_dim'] + else: + const_dim = 0 + const_dim_arg = op.arg.add() + const_dim_arg.name = 'const_component_dim' + const_dim_arg.i = const_dim + + def convert_split(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Split.name + + if 'axis' in node.attrs: + value = node.attrs['axis'] + else: + value = 0 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = value + + def convert_squeeze(self, node): + axis_value = node.attrs['axes'] + if node.inputs[0] in self._consts: + tensor = self._consts[node.inputs[0]] + shape = tensor.dims + new_shape = self.squeeze_shape(shape, axis_value) + del tensor.dims[:] + tensor.dims.extend(new_shape) + self.remove_node(node) + else: + op = self.convert_general_op(node) + op.type = MaceOp.Squeeze.name + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + if 'axis' in node.attrs: + axis_value = node.attrs['axis'] + else: + axis_value = [] + axis_arg.ints.extend(axis_value) + + def convert_sum_group(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.SumGroup.name + + def convert_target_rms_norm(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.TargetRMSNorm.name + + if 'target_rms' in node.attrs: + value = node.attrs['target_rms'] + target_rms_arg = op.arg.add() + target_rms_arg.name = 'target_rms' + target_rms_arg.f = value + + def convert_transpose(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Transpose.name + + if 'perm' in node.attrs: + perm = node.attrs['perm'] + ordered_perm = np.sort(perm) + if np.array_equal(perm, ordered_perm): + op.type = MaceOp.Identity.name + del op.input[1:] + else: + dims_arg = op.arg.add() + dims_arg.name = MaceKeyword.mace_dims_str + dims_arg.ints.extend(perm) + + def convert_timeoffset(self, node): + op = self.convert_general_op(node) + mace_check('offset' in node.attrs, + 'Offset attribute required in Offset Node.') + offset = node.attrs['offset'] + if offset == 0: + op.type = MaceOp.Identity.name + else: + op.type = MaceOp.TimeOffset.name + + offset_arg = op.arg.add() + offset_arg.name = 'offset' + offset_arg.i = offset diff --git a/mace/utils/utils.h b/mace/utils/utils.h index 0b1a6992..7ebdd221 100644 --- a/mace/utils/utils.h +++ b/mace/utils/utils.h @@ -15,7 +15,9 @@ #ifndef MACE_UTILS_UTILS_H_ #define MACE_UTILS_UTILS_H_ +#include #include +#include #include #include #include @@ -69,6 +71,32 @@ Integer CeilQuotient(Integer a, Integer b) { std::string ObfuscateString(const std::string &src, const std::string &lookup_table); +template +inline Integer Clamp(Integer in, Integer low, Integer high) { + return std::max(low, std::min(in, high)); +} + +template +inline T ScalarSigmoid(T in) { + if (in > static_cast(0)) { + return static_cast(1) / (static_cast(1) + std::exp(-in)); + } else { + T x = std::exp(in); + return x / (x + static_cast(1)); + } +} + +template +inline T ScalarTanh(T in) { + if (in > static_cast(0)) { + T inv_expa = std::exp(-in); + return -static_cast(1) + + static_cast(2) / (static_cast(1) + inv_expa * inv_expa); + } else { + T x = std::exp(in); + return x / (x + static_cast(1)); + } +} std::string ObfuscateString(const std::string &src); diff --git a/tools/common.py b/tools/common.py index 8e69ed8e..6aa7d632 100644 --- a/tools/common.py +++ b/tools/common.py @@ -401,6 +401,7 @@ class YAMLKeyword(object): graph_optimize_options = 'graph_optimize_options' # internal use for now cl_mem_type = 'cl_mem_type' backend = 'backend' + validation_outputs_data = 'validation_outputs_data' docker_image_tag = 'docker_image_tag' dockerfile_path = 'dockerfile_path' dockerfile_sha256_checksum = 'dockerfile_sha256_checksum' diff --git a/tools/converter.py b/tools/converter.py index b3a65696..36e90c81 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -476,6 +476,14 @@ def format_model_config(flags): onnx_backend = subgraph.get( YAMLKeyword.backend, "tensorflow") subgraph[YAMLKeyword.backend] = onnx_backend + validation_outputs_data = subgraph.get( + YAMLKeyword.validation_outputs_data, []) + if not isinstance(validation_outputs_data, list): + subgraph[YAMLKeyword.validation_outputs_data] = [ + validation_outputs_data] + else: + subgraph[YAMLKeyword.validation_outputs_data] = \ + validation_outputs_data input_ranges = subgraph.get( YAMLKeyword.input_ranges, []) if not isinstance(input_ranges, list): diff --git a/tools/device.py b/tools/device.py index 07e92878..5bc788f5 100644 --- a/tools/device.py +++ b/tools/device.py @@ -660,6 +660,8 @@ class DeviceWrapper: YAMLKeyword.validation_threshold][ validate_type], backend=subgraphs[0][YAMLKeyword.backend], + validation_outputs_data=subgraphs[0][ + YAMLKeyword.validation_outputs_data], log_file=log_file, ) if flags.report and flags.round > 0: diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 035348ff..399bb3d4 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -656,9 +656,12 @@ def validate_model(abi, output_file_name="model_out", validation_threshold=0.9, backend="tensorflow", - log_file="", - ): - six.print_("* Validate with %s" % platform) + validation_outputs_data=[], + log_file=""): + if not validation_outputs_data: + six.print_("* Validate with %s" % platform) + else: + six.print_("* Validate with file: %s" % validation_outputs_data) if abi != "host": for output_name in output_nodes: formatted_name = common.formatted_file_name( @@ -675,6 +678,7 @@ def validate_model(abi, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), validation_threshold, ",".join(input_data_types), backend, + validation_outputs_data, log_file) elif platform == "onnx": validate(platform, model_file_path, "", @@ -683,6 +687,7 @@ def validate_model(abi, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), validation_threshold, ",".join(input_data_types), backend, + validation_outputs_data, log_file) elif platform == "caffe": image_name = "mace-caffe:" + docker_image_tag @@ -700,6 +705,7 @@ def validate_model(abi, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), validation_threshold, ",".join(input_data_types), backend, + validation_outputs_data, log_file) elif caffe_env == common.CaffeEnvType.DOCKER: docker_image_id = sh.docker("images", "-q", image_name) @@ -767,6 +773,8 @@ def validate_model(abi, "--validation_threshold=%f" % validation_threshold, "--input_data_type=%s" % ",".join(input_data_types), "--backend=%s" % ",".join(backend), + "--validation_outputs_data=%s" % ",".join( + validation_outputs_data), "--log_file=%s" % log_file, _fg=True) diff --git a/tools/validate.py b/tools/validate.py index 2ea8fed2..7b2703c4 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -18,6 +18,7 @@ import os import os.path import numpy as np import re +import six import common @@ -121,6 +122,32 @@ def normalize_tf_tensor_name(name): return name +def validate_with_file(platform, device_type, + output_names, output_shapes, + mace_out_file, validation_outputs_data, + validation_threshold, log_file): + for i in range(len(output_names)): + if validation_outputs_data[i].startswith("http://") or \ + validation_outputs_data[i].startswith("https://"): + validation_file_name = common.formatted_file_name( + mace_out_file, output_names[i] + '_validation') + six.moves.urllib.request.urlretrieve(validation_outputs_data[i], + validation_file_name) + else: + validation_file_name = validation_outputs_data[i] + value = load_data(validation_file_name) + out_shape = output_shapes[i] + if len(out_shape) == 4: + out_shape[1], out_shape[2], out_shape[3] = \ + out_shape[3], out_shape[1], out_shape[2] + value = value.reshape(out_shape).transpose((0, 2, 3, 1)) + output_file_name = common.formatted_file_name( + mace_out_file, output_names[i]) + mace_out_value = load_data(output_file_name) + compare_output(platform, device_type, output_names[i], mace_out_value, + value, validation_threshold, log_file) + + def validate_tf_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, output_names, validation_threshold, input_data_types, @@ -275,7 +302,8 @@ def validate_onnx_model(platform, device_type, model_file, input_file, def validate(platform, model_file, weight_file, input_file, mace_out_file, device_type, input_shape, output_shape, input_node, output_node, - validation_threshold, input_data_type, backend, log_file): + validation_threshold, input_data_type, backend, + validation_outputs_data, log_file): input_names = [name for name in input_node.split(',')] input_shape_strs = [shape for shape in input_shape.split(':')] input_shapes = [[int(x) for x in shape.split(',')] @@ -287,8 +315,21 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file, input_data_types = ['float32'] * len(input_names) output_names = [name for name in output_node.split(',')] assert len(input_names) == len(input_shapes) - - if platform == 'tensorflow': + if not isinstance(validation_outputs_data, list): + if os.path.isfile(validation_outputs_data): + validation_outputs = [validation_outputs_data] + else: + validation_outputs = [] + else: + validation_outputs = validation_outputs_data + if validation_outputs: + output_shape_strs = [shape for shape in output_shape.split(':')] + output_shapes = [[int(x) for x in shape.split(',')] + for shape in output_shape_strs] + validate_with_file(platform, device_type, output_names, output_shapes, + mace_out_file, validation_outputs, + validation_threshold, log_file) + elif platform == 'tensorflow': validate_tf_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, output_names, validation_threshold, input_data_types, @@ -358,10 +399,10 @@ def parse_args(): default="tensorflow", help="onnx backend framwork") parser.add_argument( - "--log_file", - type=str, - default="", - help="log file") + "--validation_outputs_data", type=str, + default="", help="validation outputs data file path.") + parser.add_argument( + "--log_file", type=str, default="", help="log file.") return parser.parse_known_args() @@ -381,4 +422,5 @@ if __name__ == '__main__': FLAGS.validation_threshold, FLAGS.input_data_type, FLAGS.backend, + FLAGS.validation_outputs_data, FLAGS.log_file) -- GitLab