From a7afda0ed7eb408e8d2585403924515d3e3605b0 Mon Sep 17 00:00:00 2001
From: liutuo
Date: Thu, 21 Mar 2019 16:01:04 +0800
Subject: [PATCH] add kaldi op IfDefined

add extract pooling and kaldi batchnorm
add time index info in extract pooling
add benchmark & test for kaldi components
add batch support for kaldi
change ifdefined to delay
---
 mace/ops/concat.cc                          |   3 +-
 mace/ops/delay.cc                           |  87 +++++
 mace/ops/delay_benchmark.cc                 |  75 +++++
 mace/ops/dynamic_lstm.cc                    |  10 +-
 mace/ops/extract_pooling.cc                 | 215 +++++++++++
 mace/ops/extract_pooling_benchmark.cc       | 103 ++++++
 mace/ops/extract_pooling_test.cc            | 187 +++++++++++
 mace/ops/kaldi_batch_norm.cc                | 176 ++++++++++
 mace/ops/kaldi_batch_norm_benchmark.cc      |  94 ++++++
 mace/ops/kaldi_batch_norm_test.cc           | 136 ++++++++
 mace/ops/ops_registry.cc                    |   6 +
 mace/ops/pad_context.cc                     |   2 +-
 mace/ops/reduce_test.cc                     | 300 ------------------
 mace/ops/splice.cc                          |   5 +-
 mace/ops/sum_group.cc                       |  24 +-
 mace/ops/target_rms_norm.cc                 |  67 +++-
 .../tools/converter_tool/base_converter.py  |   3 +
 .../tools/converter_tool/onnx_converter.py  | 214 +++++++++++--
 .../tools/converter_tool/transformer.py     |   5 +-
 19 files changed, 1359 insertions(+), 353 deletions(-)
 create mode 100644 mace/ops/delay.cc
 create mode 100644 mace/ops/delay_benchmark.cc
 create mode 100644 mace/ops/extract_pooling.cc
 create mode 100644 mace/ops/extract_pooling_benchmark.cc
 create mode 100644 mace/ops/extract_pooling_test.cc
 create mode 100644 mace/ops/kaldi_batch_norm.cc
 create mode 100644 mace/ops/kaldi_batch_norm_benchmark.cc
 create mode 100644 mace/ops/kaldi_batch_norm_test.cc

diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc
index 1254c643..9fa45feb 100644
--- a/mace/ops/concat.cc
+++ b/mace/ops/concat.cc
@@ -86,7 +86,8 @@ class ConcatOp : public ConcatOpBase {
         continue;
       }
       MACE_CHECK(input->dim(j) == input0->dim(j),
-                 "Dimensions of inputs should equal except axis.");
+                 "Dimensions of inputs should equal except axis: ",
+                 input->dim(j), "!=", input0->dim(j));
     }
     outer_sizes[i] = input->size() / inner_size;
     output_shape[axis] += input->dim(axis);
diff --git a/mace/ops/delay.cc b/mace/ops/delay.cc
new file mode 100644
index 00000000..db99723d
--- /dev/null
+++ b/mace/ops/delay.cc
@@ -0,0 +1,87 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This Op is for the IfDefined descriptor in Kaldi.
+// It defines a time offset.
+// If time index <= offset, zeros are used as output.
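+//
+// A worked example of the semantics (the frame values are illustrative):
+// with offset = -2 and an input chunk of frames [x0, x1, x2, x3], the
+// output chunk is [0, 0, x0, x1]: each frame is delayed by two time
+// steps and the leading frames are zero-filled.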
+
+#include <cstring>
+#include <functional>
+#include <numeric>
+
+#include "mace/core/operator.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class DelayOp;
+
+template <typename T>
+class DelayOp<DeviceType::CPU, T> : public Operation {
+ public:
+  explicit DelayOp(OpConstructContext *context)
+      : Operation(context),
+        offset_(Operation::GetOptionalArg<int>("offset", 0)) {}
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = this->Input(0);
+    Tensor *output = this->Output(0);
+    MACE_CHECK(offset_ < 0, "offset param should be negative.");
+
+    index_t rank = input->dim_size();
+    MACE_CHECK(rank >= 2, "input's rank should be >= 2.");
+    const std::vector<index_t> &input_shape = input->shape();
+    // Fold all leading dimensions into a single batch dimension.
+    const index_t batch =
+        std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
+                        std::multiplies<index_t>());
+    const index_t chunk = input_shape[rank - 2];
+    const index_t dim = input_shape[rank - 1];
+    MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+    output->Clear();
+
+    // The whole chunk falls before the offset: output stays all-zero.
+    if (chunk <= -offset_)
+      return MaceStatus::MACE_SUCCESS;
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    const T *input_data = input->data<T>();
+    T *output_data = output->mutable_data<T>();
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t i = start0; i < end0; i += step0) {
+        for (index_t j = start1; j < end1; j += step1) {
+          // Copy frame j to frame j - offset_ (offset_ is negative).
+          memcpy(output_data + (i * chunk + j - offset_) * dim,
+                 input_data + (i * chunk + j) * dim,
+                 dim * sizeof(T));
+        }
+      }
+    }, 0, batch, 1, 0, chunk + offset_, 1);
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  int offset_;
+};
+
+void RegisterDelay(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "Delay", DelayOp,
+                   DeviceType::CPU, float);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/delay_benchmark.cc b/mace/ops/delay_benchmark.cc
new file mode 100644
index 00000000..89218a3e
--- /dev/null
+++ b/mace/ops/delay_benchmark.cc
@@ -0,0 +1,75 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
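+
+// The benchmarks below pass a positive OFFSET for readability; it is
+// negated when the op is built, since DelayOp requires offset < 0.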
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template <DeviceType D, typename T>
+static void Delay(int iters,
+                  int batch,
+                  int chunk,
+                  int dim,
+                  int offset) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
+
+  OpDefBuilder("Delay", "DelayTest")
+      .Input("Input")
+      .Output("Output")
+      .AddIntArg("offset", -offset)
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.Run();
+  }
+  net.Sync();
+}
+
+#define MACE_BM_DELAY_MACRO(N, C, D, OFFSET, TYPE, DEVICE)                 \
+  static void MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE##_##DEVICE(\
+      int iters) {                                                         \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * D;           \
+    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));                   \
+    Delay<DEVICE, TYPE>(iters, N, C, D, OFFSET);                           \
+  }                                                                        \
+  MACE_BENCHMARK(MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE\
+##_##DEVICE)
+
+#define MACE_BM_DELAY(N, C, D, OFFSET) \
+  MACE_BM_DELAY_MACRO(N, C, D, OFFSET, float, CPU);
+
+MACE_BM_DELAY(8, 40, 512, 2);
+MACE_BM_DELAY(16, 80, 100, 3);
+MACE_BM_DELAY(32, 60, 200, 5);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/dynamic_lstm.cc b/mace/ops/dynamic_lstm.cc
index 7d7014d5..36d24ada 100644
--- a/mace/ops/dynamic_lstm.cc
+++ b/mace/ops/dynamic_lstm.cc
@@ -214,7 +214,7 @@ class DynamicLSTMOp : public Operation {
     Tensor *output = this->Output(OUTPUT);
     std::vector<index_t> output_shape = input->shape();
-    output_shape[1] = output_dim;
+    output_shape[input_rank - 1] = output_dim;
     MACE_RETURN_IF_ERROR(output->Resize(output_shape));
@@ -235,8 +235,10 @@ class DynamicLSTMOp : public Operation {
       affine_b_in.Clear();
       affine_b_out.Clear();
       for (int i = 0; i < chunk; ++i) {
+        const float *input_ptr = input_data + (b * chunk + i) * input_dim;
+        float *output_ptr = output_data + (b * chunk + i) * output_dim;
         // Append
-        memcpy(affine_a_in_data, input_data, input_dim * sizeof(float));
+        memcpy(affine_a_in_data, input_ptr, input_dim * sizeof(float));
         if (prev_out_idx >= 0) {
           memcpy(affine_a_in_data + input_dim,
                  prev_out_data + prev_out_idx % out_buf_chunk * prev_out_dim_,
@@ -283,7 +285,7 @@ class DynamicLSTMOp : public Operation {
                     false,
                     &affine_b_out);
         // Output
-        memcpy(output_data,
+        memcpy(output_ptr,
                affine_b_out_data,
                output_dim * sizeof(float));
         // Update
@@ -292,8 +294,6 @@ class DynamicLSTMOp : public Operation {
                      prev_out_dim_,
                      scale_,
                      curr_out_ptr);
-        input_data += input_dim;
-        output_data += output_dim;
         prev_out_idx++;
         prev_cell_idx++;
       }
diff --git a/mace/ops/extract_pooling.cc b/mace/ops/extract_pooling.cc
new file mode 100644
index 00000000..3908ceaf
--- /dev/null
+++ b/mace/ops/extract_pooling.cc
@@ -0,0 +1,215 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
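+
+// A sketch of the statistics computed here (the full description
+// follows): for an output frame with precomputed range [start, end) and
+// each input dimension d,
+//   mean[d]   = (1 / count) * sum_{t in [start, end)} x[t][d]
+//   stddev[d] = sqrt(E_t[x[t][d]^2] - mean[d]^2), floored at
+//               sqrt(variance_floor).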
+
+// This Op fuses the StatisticsExtraction, StatisticsPooling and Round
+// components in Kaldi.
+// It extracts moving-average mean and standard-deviation statistics of
+// the input data.
+// 'input_indexes' indicates which frames are used to extract statistics.
+// 'output_indexes' indicates which output frames hold the statistics
+// results.
+// 'modulus' is used to extend the results to all frames.
+// 'input_time_range' and 'output_time_range' give the start and end time
+// indexes of the input and output frames.
+// 'forward_indexes' and 'counts' come from the indexes precomputed by
+// Kaldi.
+// Reference to
+// http://kaldi-asr.org/doc/nnet-general-component_8h_source.html#l00158
+
+#include <cmath>
+#include <functional>
+#include <numeric>
+
+#include "mace/core/operator.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class ExtractPoolingOp;
+
+template <typename T>
+class ExtractPoolingOp<DeviceType::CPU, T> : public Operation {
+ public:
+  explicit ExtractPoolingOp(OpConstructContext *context)
+      : Operation(context),
+        modulus_(Operation::GetOptionalArg<int>("modulus", 1)),
+        include_variance_(
+            static_cast<bool>(
+                Operation::GetOptionalArg<int>("include_variance", 0))),
+        num_log_count_(
+            Operation::GetOptionalArg<int>("num_log_count", 0)),
+        variance_floor_(
+            Operation::GetOptionalArg<float>("variance_floor", 1.0e-10f)),
+        input_indexes_(Operation::GetRepeatedArgs<int>("input_indexes")),
+        output_indexes_(Operation::GetRepeatedArgs<int>("output_indexes")),
+        forward_indexes_(Operation::GetRepeatedArgs<int>("forward_indexes")),
+        counts_(Operation::GetRepeatedArgs<float>("counts")),
+        input_time_range_(
+            Operation::GetRepeatedArgs<int>("input_time_range")),
+        output_time_range_(
+            Operation::GetRepeatedArgs<int>("output_time_range")) {}
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = this->Input(0);
+    Tensor *output = this->Output(0);
+
+    const std::vector<index_t> &input_shape = input->shape();
+    const index_t dim_size = input_shape.size();
+    MACE_CHECK(dim_size >= 2,
+               "ExtractPooling only supports input dim size >= 2");
+    MACE_CHECK(modulus_ >= 1,
+               "ExtractPooling's pooling size should be greater than zero.");
+    MACE_CHECK(input_time_range_.size() == 2 && output_time_range_.size() == 2
+                   && counts_.size() * 2 == forward_indexes_.size()
+                   && counts_.size() == output_indexes_.size());
+    int in_start_index = input_time_range_[0];
+    int out_start_index = output_time_range_[0];
+    int out_end_index = output_time_range_[1];
+    MACE_CHECK(out_end_index >= out_start_index
+                   && input_time_range_[1] >= input_time_range_[0],
+               "end index should be greater than start index.");
+    const index_t output_chunk = out_end_index - out_start_index + 1;
+    const index_t input_dim = input_shape[dim_size - 1];
+    const index_t chunk = input_shape[dim_size - 2];
+    MACE_CHECK(chunk == input_time_range_[1] - input_time_range_[0] + 1,
+               "input chunk should be equal to end - start + 1.");
+    const index_t batch =
+        std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
+                        std::multiplies<index_t>());
+
+    index_t output_dim = include_variance_ ?
+                         2 * input_dim : input_dim;
+    output_dim += num_log_count_;
+    std::vector<index_t> output_shape(input_shape);
+    output_shape[dim_size - 1] = output_dim;
+    output_shape[dim_size - 2] = output_chunk;
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+
+    const index_t num_input_indexes = input_indexes_.size();
+    const index_t num_output_indexes = output_indexes_.size();
+    MACE_CHECK(num_input_indexes > 0 && num_output_indexes > 0,
+               "ExtractPooling's input_indexes or output_indexes is empty.");
+    const index_t extract_out_size = PadAlignSize(output_dim * sizeof(float));
+    ScratchBuffer *scratch = context->device()->scratch_buffer();
+    scratch->Rewind();
+    scratch->GrowSize(extract_out_size);
+
+    Tensor extract_out(scratch->Scratch(extract_out_size), DT_FLOAT);
+    extract_out.Reshape({1, output_dim});
+    extract_out.Clear();
+    float *extract_out_data = extract_out.mutable_data<float>();
+
+    Tensor::MappingGuard guard_input(input);
+    Tensor::MappingGuard guard_output(output);
+    const T *input_data = input->data<T>();
+    T *output_data = output->mutable_data<T>();
+
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+
+    for (index_t b = 0; b < batch; ++b) {
+      for (index_t i = 0; i < num_output_indexes; ++i) {
+        int start = forward_indexes_[2 * i];
+        int end = forward_indexes_[2 * i + 1];
+        float count = counts_[i];
+        float mean_scale = 1.f / count;
+        float log_count = std::log(count);
+        thread_pool.Compute1D([=](index_t start0,
+                                  index_t end0,
+                                  index_t step0) {
+          for (index_t n = start0; n < end0; n += step0) {
+            extract_out_data[n] = log_count;
+          }
+        }, 0, num_log_count_, 1);
+        if (include_variance_) {
+          thread_pool.Compute1D([=](index_t start0,
+                                    index_t end0,
+                                    index_t step0) {
+            for (index_t d = start0; d < end0; d += step0) {
+              float mean = 0.f;
+              float variance = 0.f;
+              for (int t = start; t < end; ++t) {
+                index_t input_index =
+                    (b * chunk + input_indexes_[t] - in_start_index)
+                        * input_dim;
+                float x = input_data[input_index + d];
+                mean += x;
+                variance += x * x;
+              }
+              mean *= mean_scale;
+              variance *= mean_scale;
+              extract_out_data[d + num_log_count_] = mean;
+              variance = variance - mean * mean;
+              extract_out_data[d + input_dim + num_log_count_] =
+                  variance < variance_floor_ ?
+                  std::sqrt(variance_floor_) :
+                  std::sqrt(variance);
+            }
+          }, 0, input_dim, 1);
+        } else {
+          thread_pool.Compute1D([=](index_t start0,
+                                    index_t end0,
+                                    index_t step0) {
+            for (index_t d = start0; d < end0; d += step0) {
+              float mean = 0.f;
+              for (int t = start; t < end; ++t) {
+                index_t input_index =
+                    (b * chunk + input_indexes_[t] - in_start_index)
+                        * input_dim;
+                mean += input_data[input_index + d];
+              }
+              extract_out_data[d + num_log_count_] = mean * mean_scale;
+            }
+          }, 0, input_dim, 1);
+        }
+
+        // Replicate the statistics to output frames
+        // [output_start, output_end), clipped to the output time range.
+        int output_start = output_indexes_[i] < out_start_index ?
+                           out_start_index : output_indexes_[i];
+        int output_end = output_indexes_[i] + modulus_;
+        output_end = output_end > out_end_index ?
+                     out_end_index + 1 :
+                     output_end;
+        thread_pool.Compute1D([=](index_t start0,
+                                  index_t end0,
+                                  index_t step0) {
+          for (index_t idx = start0; idx < end0; idx += step0) {
+            memcpy(output_data + (b * output_chunk + idx - out_start_index)
+                       * output_dim,
+                   extract_out_data, output_dim * sizeof(float));
+          }
+        }, output_start, output_end, 1);
+      }
+    }
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  int modulus_;
+  bool include_variance_;
+  int num_log_count_;
+  float variance_floor_;
+  std::vector<int> input_indexes_;
+  std::vector<int> output_indexes_;
+  std::vector<int> forward_indexes_;
+  std::vector<float> counts_;
+  std::vector<int> input_time_range_;
+  std::vector<int> output_time_range_;
+};
+
+void RegisterExtractPooling(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "ExtractPooling", ExtractPoolingOp,
+                   DeviceType::CPU, float);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/extract_pooling_benchmark.cc b/mace/ops/extract_pooling_benchmark.cc
new file mode 100644
index 00000000..69477727
--- /dev/null
+++ b/mace/ops/extract_pooling_benchmark.cc
@@ -0,0 +1,103 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template <DeviceType D, typename T>
+static void ExtractPooling(int iters,
+                           int batch,
+                           int chunk,
+                           int dim,
+                           int input_period,
+                           int modulus) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  size_t num_input_indexes = static_cast<size_t>(chunk / input_period);
+  std::vector<int> input_indexes(num_input_indexes, 0);
+
+  for (size_t i = 0; i < num_input_indexes; ++i) {
+    input_indexes[i] = static_cast<int>(i * input_period);
+  }
+
+  size_t num_output_indexes = static_cast<size_t>(chunk / modulus);
+  std::vector<int> output_indexes(num_output_indexes, 0);
+  std::vector<int> forward_indexes(num_output_indexes * 2, 0);
+  std::vector<float> counts(num_output_indexes, 0.f);
+  for (size_t i = 0; i < num_output_indexes; ++i) {
+    output_indexes[i] = static_cast<int>(i * modulus);
+    forward_indexes[2 * i] = 0;
+    forward_indexes[2 * i + 1] = static_cast<int>(num_input_indexes - 1);
+    counts[i] = static_cast<float>(num_input_indexes);
+  }
+
+  // Add input data
+  net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
+
+  OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
+      .Input("Input")
+      .AddIntArg("modulus", modulus)
+      .AddIntArg("include_variance", 1)
+      .AddIntArg("num_log_count", 1)
+      .AddIntsArg("input_indexes", input_indexes)
+      .AddIntsArg("output_indexes", output_indexes)
+      .AddIntsArg("forward_indexes", forward_indexes)
+      .AddFloatsArg("counts", counts)
+      .AddIntsArg("input_time_range", {0, chunk - 1})
+      .AddIntsArg("output_time_range", {0, chunk - 1})
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.Run();
+  }
+  net.Sync();
+}
+
+#define MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, TYPE, DEVICE)          \
+  static void MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE##\
+_##DEVICE(                                                                   \
+      int iters) {                                                           \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * D;             \
+    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));                     \
+    ExtractPooling<DEVICE, TYPE>(iters, N, C, D, INP, M);                    \
+  }                                                                          \
+  MACE_BENCHMARK(MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE\
+##_##DEVICE)
+
+#define MACE_BM_EXTRACTPOOLING(N, C, D, INP, M) \
+  MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, float, CPU);
+
+MACE_BM_EXTRACTPOOLING(8, 40, 512, 2, 4);
+MACE_BM_EXTRACTPOOLING(16, 80, 100, 3, 9);
+MACE_BM_EXTRACTPOOLING(32, 60, 200, 6, 18);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/extract_pooling_test.cc b/mace/ops/extract_pooling_test.cc
new file mode 100644
index 00000000..c36e38d2
--- /dev/null
+++ b/mace/ops/extract_pooling_test.cc
@@ -0,0 +1,187 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class ExtractPoolingTest : public OpsTestBase {};
+
+namespace {
+template <DeviceType D = DeviceType::CPU>
+void TestExtractPooling(const std::vector<index_t> &input_shape,
+                        const std::vector<float> &input_value,
+                        const int modulus,
+                        const int num_log_count,
+                        const int include_variance,
+                        const std::vector<int> &input_time_range,
+                        const std::vector<int> &input_indexes,
+                        const std::vector<int> &forward_indexes,
+                        const std::vector<float> &counts,
+                        const std::vector<int> &output_indexes,
+                        const std::vector<int> &output_time_range,
+                        const std::vector<index_t> &output_shape,
+                        const std::vector<float> &output_value) {
+  // Construct graph
+  OpsTestNet net;
+  net.AddInputFromArray<D, float>("Input", input_shape, input_value);
+  OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
+      .Input("Input")
+      .AddIntArg("modulus", modulus)
+      .AddIntArg("include_variance", include_variance)
+      .AddIntArg("num_log_count", num_log_count)
+      .AddIntsArg("input_indexes", input_indexes)
+      .AddIntsArg("output_indexes", output_indexes)
+      .AddIntsArg("forward_indexes", forward_indexes)
+      .AddFloatsArg("counts", counts)
+      .AddIntsArg("input_time_range", input_time_range)
+      .AddIntsArg("output_time_range", output_time_range)
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+  // Run
+  net.RunOp();
+  // Check
+  auto expected = net.CreateTensor<float>(output_shape, output_value);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+}  // namespace
+
+TEST_F(ExtractPoolingTest, SimpleCPU) {
+  TestExtractPooling(
+      {3, 20, 3},
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+       31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+       46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+       31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+       46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}, + 9, 0, 0, + {-2, 17}, + {0, 3, 6, 9, 12, 15}, + {0, 6, 2, 6}, + {6, 4}, + {0, 9}, + {0, 17}, + {3, 18, 3}, + {29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, + 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5}); +} + +TEST_F(ExtractPoolingTest, SimpleCPUWithVariance) { +TestExtractPooling( + {3, 20, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}, + 9, 1, 1, + {-2, 17}, + {0, 3, 6, 9, 12, 15}, + {0, 6, 2, 6}, + {6, 4}, + {0, 9}, + {0, 17}, + {3, 18, 7}, + {1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 
15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, + 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/kaldi_batch_norm.cc b/mace/ops/kaldi_batch_norm.cc new file mode 100644 index 00000000..61c0340c --- /dev/null +++ b/mace/ops/kaldi_batch_norm.cc @@ -0,0 +1,176 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
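+
+// A sketch of the training-mode (non-test-mode) math implemented below,
+// per block dimension d over num_rows rows:
+//   mean[d] = (1 / num_rows) * sum_i x[i][d]
+//   var[d]  = E_i[x[i][d]^2] - mean[d]^2 + epsilon
+//   y[i][d] = (x[i][d] - mean[d]) * target_rms / sqrt(var[d])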
+
+// This Op is for Kaldi's BatchNormComponent.
+// More details about the forward computation are here:
+// http://kaldi-asr.org/doc/nnet-normalize-component_8cc_source.html#l00320
+#include <cmath>
+#include <functional>
+#include <numeric>
+
+#include "mace/core/operator.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class KaldiBatchNormOp;
+
+template <>
+class KaldiBatchNormOp<DeviceType::CPU, float> : public Operation {
+ public:
+  explicit KaldiBatchNormOp(OpConstructContext *context)
+      : Operation(context),
+        epsilon_(Operation::GetOptionalArg<float>(
+            "epsilon", static_cast<float>(1e-3))),
+        target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0f)),
+        block_dim_(Operation::GetOptionalArg<int>("block_dim", -1)),
+        test_mode_(static_cast<bool>(
+            Operation::GetOptionalArg<int>("test_mode", 0))) {}
+
+  void CalculateMeanVar(const float *input_data,
+                        index_t length,
+                        index_t stride,
+                        float mean_scale,
+                        float var_scale,
+                        float *mean_data,
+                        float *var_data) {
+    float mean_value = 0.f;
+    float var_value = 0.f;
+    for (index_t i = 0; i < length; ++i) {
+      float x = input_data[i * stride];
+      mean_value += x;
+      var_value += x * x;
+    }
+    mean_value = mean_value * mean_scale;
+    var_value = var_value * mean_scale;
+    float mean_sqr = mean_value * mean_value;
+    var_value = (var_value > mean_sqr) ?
+                var_scale * (var_value - mean_sqr + epsilon_) :
+                var_scale * epsilon_;
+    // Store the inverse standard deviation, not the raw variance.
+    var_data[0] = std::pow(var_value, -0.5f);
+    mean_data[0] = mean_value;
+  }
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = this->Input(INPUT);
+    const std::vector<index_t> &input_shape = input->shape();
+    const index_t rank = input->dim_size();
+    const index_t dim = input_shape[rank - 1];
+    if (block_dim_ == -1) block_dim_ = static_cast<int>(dim);
+    MACE_CHECK(target_rms_ > 0 && dim > 0 && dim % block_dim_ == 0);
+    MACE_CHECK(rank >= 2, "KaldiBatchNorm's input's rank must be >= 2.");
+    index_t num_rows =
+        std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
+                        std::multiplies<index_t>());
+
+    const index_t blocks = dim / block_dim_;
+    if (blocks > 1) num_rows *= blocks;
+    Tensor *output = this->Output(OUTPUT);
+    MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    const float *input_data = input->data<float>();
+    float *output_data = output->mutable_data<float>();
+
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+
+    if (test_mode_) {
+      MACE_CHECK(this->InputSize() == 3,
+                 "KaldiBatchNorm should have 3 inputs");
+      const Tensor *scale = this->Input(SCALE);
+      const Tensor *offset = this->Input(OFFSET);
+      MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
+                 scale->dim_size());
+      MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
+                 offset->dim_size());
+      MACE_CHECK(scale->size() == offset->size()
+                     && scale->size() == block_dim_);
+      Tensor::MappingGuard scale_guard(scale);
+      Tensor::MappingGuard offset_guard(offset);
+      const float *scale_data = scale->data<float>();
+      const float *offset_data = offset->data<float>();
+
+      thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                                index_t start1, index_t end1, index_t step1) {
+        for (index_t i = start0; i < end0; i += step0) {
+          for (index_t j = start1; j < end1; j += step1) {
+            index_t idx = i * block_dim_ + j;
+            output_data[idx] =
+                input_data[idx] * scale_data[j] + offset_data[j];
+          }
+        }
+      }, 0, num_rows, 1, 0, block_dim_, 1);
+    } else {
+      const index_t buf_size =
+          PadAlignSize(block_dim_ * sizeof(float));
+      ScratchBuffer *scratch = context->device()->scratch_buffer();
+      scratch->Rewind();
+      scratch->GrowSize(2 * buf_size);
+
+      Tensor mean(scratch->Scratch(buf_size), DT_FLOAT);
+      mean.Reshape({block_dim_});
+      float *mean_data = mean.mutable_data<float>();
+
+      Tensor var(scratch->Scratch(buf_size), DT_FLOAT);
+      var.Reshape({block_dim_});
+      float *var_data = var.mutable_data<float>();
+
+      float var_scale = 1.0f / (target_rms_ * target_rms_);
+      float mean_scale = 1.0f / num_rows;
+
+      thread_pool.Compute1D([=](index_t start0, index_t end0, index_t step0) {
+        for (index_t i = start0; i < end0; i += step0) {
+          CalculateMeanVar(input_data + i,
+                           num_rows,
+                           block_dim_,
+                           mean_scale,
+                           var_scale,
+                           mean_data + i,
+                           var_data + i);
+        }
+      }, 0, block_dim_, 1);
+      thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                                index_t start1, index_t end1, index_t step1) {
+        for (index_t i = start0; i < end0; i += step0) {
+          for (index_t j = start1; j < end1; j += step1) {
+            index_t idx = i * block_dim_ + j;
+            output_data[idx] = (input_data[idx] - mean_data[j]) * var_data[j];
+          }
+        }
+      }, 0, num_rows, 1, 0, block_dim_, 1);
+    }
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  const float epsilon_;
+  const float target_rms_;
+  int block_dim_;
+  const bool test_mode_;
+
+ protected:
+  MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+void RegisterKaldiBatchNorm(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "KaldiBatchNorm", KaldiBatchNormOp,
+                   DeviceType::CPU, float);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/kaldi_batch_norm_benchmark.cc b/mace/ops/kaldi_batch_norm_benchmark.cc
new file mode 100644
index 00000000..ac5e117b
--- /dev/null
+++ b/mace/ops/kaldi_batch_norm_benchmark.cc
@@ -0,0 +1,94 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template <DeviceType D, typename T>
+void KaldiBatchNorm(
+    int iters, int batch, int chunk, int dim, int block_dim) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+  net.AddRandomInput<D, T>("Scale", {block_dim}, true);
+  net.AddRandomInput<D, T>("Offset", {block_dim}, true);
+
+  OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormBM")
+      .Input("Input")
+      .Input("Scale")
+      .Input("Offset")
+      .AddIntArg("block_dim", block_dim)
+      .AddIntArg("test_mode", 1)
+      .Output("Output")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+}  // namespace
+
+#define MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, TYPE, DEVICE)   \
+  static void MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
+##_##DEVICE(                                                        \
+      int iters) {                                                  \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * D;    \
+    mace::testing::MacsProcessed(tot);                              \
+    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));            \
+    KaldiBatchNorm<DEVICE, TYPE>(iters, N, C, D, BD);               \
+  }                                                                 \
+  MACE_BENCHMARK(MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
+##_##DEVICE)
+
+#define MACE_BM_KALDI_BATCH_NORM(N, C, D, BD) \
+  MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, float, CPU);
+
+MACE_BM_KALDI_BATCH_NORM(1, 1, 512, 512);
+MACE_BM_KALDI_BATCH_NORM(1, 3, 128, 128);
+MACE_BM_KALDI_BATCH_NORM(1, 3, 512, 128);
+MACE_BM_KALDI_BATCH_NORM(1, 32, 112, 112);
+MACE_BM_KALDI_BATCH_NORM(1, 64, 256, 256);
+MACE_BM_KALDI_BATCH_NORM(1, 64, 512, 256);
+MACE_BM_KALDI_BATCH_NORM(1, 128, 56, 56);
+MACE_BM_KALDI_BATCH_NORM(1, 128, 256, 256);
+MACE_BM_KALDI_BATCH_NORM(1, 256, 14, 14);
+MACE_BM_KALDI_BATCH_NORM(1, 512, 14, 14);
+MACE_BM_KALDI_BATCH_NORM(1, 1024, 7, 7);
+MACE_BM_KALDI_BATCH_NORM(32, 1, 256, 128);
+MACE_BM_KALDI_BATCH_NORM(32, 3, 256, 256);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/kaldi_batch_norm_test.cc b/mace/ops/kaldi_batch_norm_test.cc
new file mode 100644
index 00000000..711d1b09
--- /dev/null
+++ b/mace/ops/kaldi_batch_norm_test.cc
@@ -0,0 +1,136 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
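+
+// A worked check for the non-test-mode expectations below: each column
+// of the two-block input holds {5, 7, 9, 11, 13, 15}, so
+//   mean = 10, var = E[x^2] - mean^2 + epsilon = 670/6 - 100 + 1e-3,
+// and the first row maps to (5 - 10) / sqrt(11.6677) ~= -1.46379.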
+
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class KaldiBatchNormOpTest : public OpsTestBase {};
+
+namespace {
+template <DeviceType D>
+void Simple(const std::vector<index_t> &input_shape,
+            const std::vector<float> &input_value,
+            const int block_dim,
+            const int dim,
+            const std::vector<float> &scale,
+            const std::vector<float> &offset,
+            const std::vector<index_t> &output_shape,
+            const std::vector<float> &output_value) {
+  OpsTestNet net;
+  int scale_dim = block_dim;
+  if (scale_dim == -1) scale_dim = dim;
+  // Add input data
+  net.AddInputFromArray<D, float>("Input", input_shape,
+                                  input_value);
+  net.AddInputFromArray<D, float>("Scale", {scale_dim}, scale, true);
+  net.AddInputFromArray<D, float>("Offset", {scale_dim}, offset, true);
+
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
+        .Input("Input")
+        .Input("Scale")
+        .Input("Offset")
+        .AddIntArg("block_dim", block_dim)
+        .AddIntArg("test_mode", 1)
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+  } else if (D == DeviceType::GPU) {
+    MACE_NOT_IMPLEMENTED;
+  }
+
+  // Check
+  auto expected = net.CreateTensor<float>(output_shape, output_value);
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
+}
+
+template <DeviceType D>
+void SimpleNotTestMode(const std::vector<index_t> &input_shape,
+                       const std::vector<float> &input_value,
+                       const int block_dim,
+                       const std::vector<index_t> &output_shape,
+                       const std::vector<float> &output_value) {
+  OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<D, float>("Input", input_shape,
+                                  input_value);
+
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
+        .Input("Input")
+        .AddIntArg("block_dim", block_dim)
+        .AddIntArg("test_mode", 0)
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+  } else if (D == DeviceType::GPU) {
+    MACE_NOT_IMPLEMENTED;
+  }
+
+  // Check
+  auto expected = net.CreateTensor<float>(output_shape, output_value);
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
+}
+}  // namespace
+
+TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUOneBlock) {
+  Simple<DeviceType::CPU>(
+      {1, 6, 2},
+      {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15},
+      -1, 2,
+      {4.f, 6.f},
+      {2.f, 1.f},
+      {1, 6, 2},
+      {22, 31, 30, 43, 38, 55, 46, 67, 54, 79, 62, 91});
+}
+
+TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUTwoBlock) {
+  Simple<DeviceType::CPU>(
+      {1, 6, 4},
+      {5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
+       11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
+      2, 4,
+      {4.f, 6.f},
+      {2.f, 1.f},
+      {1, 6, 4},
+      {22, 31, 22, 31, 30, 43, 30, 43, 38, 55, 38, 55,
+       46, 67, 46, 67, 54, 79, 54, 79, 62, 91, 62, 91});
+}
+
+TEST_F(KaldiBatchNormOpTest, SimpleNotTestModeCPUTwoBlock) {
+  SimpleNotTestMode<DeviceType::CPU>(
+      {1, 6, 4},
+      {5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
+       11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
+      2,
+      {1, 6, 4},
+      {-1.46379, -1.46379, -1.46379, -1.46379,
+       -0.8783, -0.8783, -0.8783, -0.8783,
+       -0.29276, -0.29276, -0.29276, -0.29276,
+       0.29276, 0.29276, 0.29276, 0.29276,
+       0.8783, 0.8783, 0.8783, 0.8783,
+       1.46379, 1.46379, 1.46379, 1.46379});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc
index 26bf0463..e975b265 100644
--- a/mace/ops/ops_registry.cc
+++ b/mace/ops/ops_registry.cc
@@ -37,11 +37,14 @@ extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry);
 extern void RegisterDynamicLSTM(OpRegistryBase *op_registry);
 extern void RegisterEltwise(OpRegistryBase *op_registry);
 extern void RegisterExpandDims(OpRegistryBase *op_registry);
+extern void RegisterExtractPooling(OpRegistryBase
*op_registry); extern void RegisterFill(OpRegistryBase *op_registry); extern void RegisterFullyConnected(OpRegistryBase *op_registry); extern void RegisterGather(OpRegistryBase *op_registry); extern void RegisterIdentity(OpRegistryBase *op_registry); +extern void RegisterDelay(OpRegistryBase *op_registry); extern void RegisterInferConv2dShape(OpRegistryBase *op_registry); +extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry); extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry); extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry); extern void RegisterMatMul(OpRegistryBase *op_registry); @@ -107,11 +110,14 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterDynamicLSTM(this); ops::RegisterEltwise(this); ops::RegisterExpandDims(this); + ops::RegisterExtractPooling(this); ops::RegisterFill(this); ops::RegisterFullyConnected(this); ops::RegisterGather(this); ops::RegisterIdentity(this); + ops::RegisterDelay(this); ops::RegisterInferConv2dShape(this); + ops::RegisterKaldiBatchNorm(this); ops::RegisterLocalResponseNorm(this); ops::RegisterLSTMNonlinear(this); ops::RegisterMatMul(this); diff --git a/mace/ops/pad_context.cc b/mace/ops/pad_context.cc index 8370f9f5..25117df2 100644 --- a/mace/ops/pad_context.cc +++ b/mace/ops/pad_context.cc @@ -42,7 +42,7 @@ class PadContextOp : public Operation { index_t rank = input->dim_size(); MACE_CHECK(rank >= 2, "input's rank should >= 2."); - MACE_CHECK(left_context_ > 0 && right_context_ > 0, + MACE_CHECK(left_context_ >= 0 && right_context_ >= 0, "left context and right context should be greater than zero"); const std::vector &input_shape = input->shape(); const index_t batch = diff --git a/mace/ops/reduce_test.cc b/mace/ops/reduce_test.cc index fc284084..ccf38fea 100644 --- a/mace/ops/reduce_test.cc +++ b/mace/ops/reduce_test.cc @@ -106,19 +106,6 @@ void SimpleMean12Test() { 10, 11, 12, 13}, ReduceType::MEAN); } -// template -// void SimpleSum12Test() { -// Simple({2, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, -// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, -// {1, 2}, -// {2, 1, 1, 4}, -// {60, 66, 72, 78, -// 60, 66, 72, 78}, ReduceType::SUM); -//} - template void SimpleMin12Test() { Simple({2, 2, 3, 4}, @@ -145,20 +132,6 @@ void SimpleMax12Test() { 20, 21, 22, 23}, ReduceType::MAX); } -// template -// void SimpleSumSqr12Test() { -// Simple({2, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, -// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, -// {1, 2}, -// {2, 1, 1, 4}, -// {880, 1006, 1144, 1294, -// 880, 1006, 1144, 1294}, ReduceType::SUM_SQR); -//} - - template void SimpleMean1Axis() { Simple({2, 2, 3, 4}, @@ -170,102 +143,8 @@ void SimpleMean1Axis() { {2, 1, 3, 4}, {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, -// {-3}, -// {1, 1, 3, 4}, -// {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, -// {2}, -// {1, 2, 1, 4}, -// {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 
23}, -// {-1}, -// {1, 2, 3, 1}, -// {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1}, -// {1, 1, 3, 3}, -// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {-2}, -// {1, 3, 1, 3}, -// {3, 4, 5, 12, 13, 14, 21, 22, 23}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {3}, -// {1, 3, 3, 1}, -// {1, 4, 7, 10, 13, 16, 19, 22, 25}, ReduceType::MEAN); } -// template -// void SimpleSum1Axis() { -// Simple({2, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23, -// 0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {1}, -// {2, 1, 3, 4}, -// {12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, -// 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34}); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {2}, -// {1, 2, 1, 4}, -// {12, 15, 18, 21, 48, 51, 54, 57}, ReduceType::SUM); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {-1}, -// {1, 2, 3, 1}, -// {6, 22, 38, 54, 70, 86}, ReduceType::SUM); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1}, -// {1, 1, 3, 3}, -// {27, 30, 33, 36, 39, 42, 45, 48, 51}, ReduceType::SUM); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {3}, -// {1, 3, 3, 1}, -// {3, 12, 21, 30, 39, 48, 57, 66, 75}, ReduceType::SUM); -//} - template void SimpleMin1Axis() { Simple({2, 2, 3, 4}, @@ -285,33 +164,6 @@ void SimpleMin1Axis() { {2, 1, 3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, ReduceType::MIN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {2}, -// {1, 2, 1, 4}, -// {0, 1, 2, 3, 12, 13, 14, 15}, ReduceType::MIN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {-1}, -// {1, 2, 3, 1}, -// {0, 4, 8, 12, 16, 20}, ReduceType::MIN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1}, -// {1, 1, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8}, ReduceType::MIN); } template @@ -337,53 +189,8 @@ void SimpleMax1Axis() { 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, ReduceType::MAX); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {2}, -// {1, 2, 1, 4}, -// {8, 9, 10, 11, 20, 21, 22, 23}, ReduceType::MAX); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {-1}, -// {1, 2, 3, 1}, -// {3, 7, 11, 15, 19, 23}, ReduceType::MAX); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 
14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1}, -// {1, 1, 3, 3}, -// {18, 19, 20, 21, 22, 23, 24, 25, 26}, ReduceType::MAX); } -// template -// void SimpleSumSqr1Axis() { -// Simple({2, 2, 3, 4}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, -// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, -// {1}, -// {2, 1, 3, 4}, -// {144, 170, 200, 234, -// 272, 314, 360, 410, -// 464, 522, 584, 650, -// 144, 170, 200, 234, -// 272, 314, 360, 410, -// 464, 522, 584, 650}, ReduceType::SUM_SQR); -//} - - template void Simple2Axis() { Simple({1, 2, 3, 4}, @@ -396,16 +203,6 @@ void Simple2Axis() { {0, 1}, {1, 1, 3, 4}, {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); -// Simple3D({2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {0, 1}, -// {1, 1, 4}, -// {10, 11, 12, 13}, ReduceType::MEAN); Simple3D({2, 3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, @@ -426,37 +223,6 @@ void Simple2Axis() { {0, 2}, {1, 2, 1, 4}, {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {1, 3}, -// {1, 1, 3, 1}, -// {7.5, 11.5, 15.5}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1, 2}, -// {1, 1, 1, 3}, -// {12, 13, 14}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {0, 1}, -// {1, 1, 3, 3}, -// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {2, 3}, -// {1, 3, 1, 1}, -// {4, 13, 22}, ReduceType::MEAN); } template @@ -471,64 +237,6 @@ void Simple3Axis() { {1, 2, 3}, {1, 1, 1, 1}, {11.5}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {0, 2, 3}, -// {1, 2, 1, 1}, -// {5.5, 17.5}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {0, 1, 3}, -// {1, 1, 3, 1}, -// {7.5, 11.5, 15.5}, ReduceType::MEAN); -// Simple({1, 2, 3, 4}, -// {0, 1, 2, 3, -// 4, 5, 6, 7, -// 8, 9, 10, 11, -// 12, 13, 14, 15, -// 16, 17, 18, 19, -// 20, 21, 22, 23}, -// {0, 1, 2}, -// {1, 1, 1, 4}, -// {10, 11, 12, 13}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {1, 2, 3}, -// {1, 1, 1, 1}, -// {13}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {0, 2, 3}, -// {1, 3, 1, 1}, -// {4, 13, 22}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {0, 1, 3}, -// {1, 1, 3, 1}, -// {10, 13, 16}, ReduceType::MEAN); -// Simple({1, 3, 3, 3}, -// {0, 1, 2, 3, 4, 5, 6, 7, 8, -// 9, 10, 11, 12, 13, 14, 15, 16, 17, -// 18, 19, 20, 21, 22, 23, 24, 25, 26}, -// {0, 1, 2}, -// {1, 1, 1, 3}, -// {12, 13, 14}, ReduceType::MEAN); } } // 
namespace @@ -622,25 +330,17 @@ void RandomTest(const std::vector &input_shape, TEST_F(ReduceOpTest, GPURandomFloat) { RandomTest({4, 64, 64, 3}, {1, 2}); -// RandomTest({2, 64, 64, 4}, {1, 2}); RandomTest({8, 128, 128, 64}, {1, 2}); -// RandomTest({1, 640, 480, 64}, {1, 2}); RandomTest({1, 480, 640, 32}, {1, 2}); -// RandomTest({1, 512, 512, 16}, {1, 2}); RandomTest({8, 117, 87, 33}, {1, 2}); -// RandomTest({1, 619, 450, 61}, {1, 2}); RandomTest({1, 511, 561, 11}, {1, 2}); } TEST_F(ReduceOpTest, GPURandomHalf) { RandomTest({4, 64, 64, 3}, {1, 2}); -// RandomTest({2, 64, 64, 4}, {1, 2}); RandomTest({8, 128, 128, 64}, {1, 2}); -// RandomTest({1, 640, 480, 64}, {1, 2}); RandomTest({1, 480, 640, 32}, {1, 2}); -// RandomTest({1, 512, 512, 16}, {1, 2}); RandomTest({8, 117, 87, 33}, {1, 2}); -// RandomTest({1, 619, 450, 61}, {1, 2}); RandomTest({1, 511, 561, 11}, {1, 2}); } diff --git a/mace/ops/splice.cc b/mace/ops/splice.cc index 6d477329..f63e9e5e 100644 --- a/mace/ops/splice.cc +++ b/mace/ops/splice.cc @@ -71,7 +71,8 @@ class SpliceOp : public Operation { const index_t out_chunk = chunk - (right_context - left_context); MACE_CHECK(input_dim > const_dim_, - "input dim should be greater than const dim."); + "input dim:", input_dim, + "should be greater than const dim:", const_dim_); const index_t output_dim = dim * num_splice + const_dim_; const index_t output_stride = out_chunk * output_dim; @@ -103,7 +104,7 @@ class SpliceOp : public Operation { const index_t input_offset = dim; for (int b = 0; b < batch; ++b) { for (index_t i = 0; i < out_chunk; ++i) { - T *output_base = output_data + + b * output_stride + i * output_dim; + T *output_base = output_data + b * output_stride + i * output_dim; const T *input_base = input_data + b * input_stride + i * input_dim; memcpy(output_base + output_offset, input_base + input_offset, diff --git a/mace/ops/sum_group.cc b/mace/ops/sum_group.cc index 0efdfe2a..1b62af7e 100644 --- a/mace/ops/sum_group.cc +++ b/mace/ops/sum_group.cc @@ -80,18 +80,22 @@ class SumGroupOp : public Operation { MACE_CHECK(cur_index <= input_dim) << "size value over-ranged:" << cur_index << "<=" << input_dim; } - - for (index_t i = 0; i < bh; ++i) { - for (index_t j = 0; j < output_dim; ++j) { - int start_col = sum_indexes[j].first; - int end_col = sum_indexes[j].second; - T sum = 0; - for (int src_col = start_col; src_col < end_col; ++src_col) { - sum += input_data[i * input_dim + src_col]; + utils::ThreadPool + &thread_pool = context->device()->cpu_runtime()->thread_pool(); + thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, + index_t start1, index_t end1, index_t step1) { + for (index_t i = start0; i < end0; i += step0) { + for (index_t j = start1; j < end1; j += step1) { + int start_col = sum_indexes[j].first; + int end_col = sum_indexes[j].second; + T sum = 0; + for (int src_col = start_col; src_col < end_col; ++src_col) { + sum += input_data[i * input_dim + src_col]; + } + output_data[i * output_dim + j] = sum; } - output_data[i * output_dim + j] = sum; } - } + }, 0, bh, 1, 0, output_dim, 1); return MaceStatus::MACE_SUCCESS; } diff --git a/mace/ops/target_rms_norm.cc b/mace/ops/target_rms_norm.cc index 80d42a1d..6caf1ce5 100644 --- a/mace/ops/target_rms_norm.cc +++ b/mace/ops/target_rms_norm.cc @@ -35,7 +35,11 @@ class TargetRMSNormOp : public Operation { public: explicit TargetRMSNormOp(OpConstructContext *context) : Operation(context), - target_rms_(Operation::GetOptionalArg("target_rms", 1.0)) {} + target_rms_(Operation::GetOptionalArg("target_rms", 
1.0)), + add_log_stddev_( + static_cast( + Operation::GetOptionalArg("add_log_stddev", 0))), + block_dim_(Operation::GetOptionalArg("block_dim", 0)) {} // Calculate the square sum of an array float SquareSum(const float *data, const index_t data_len) { @@ -67,6 +71,24 @@ class TargetRMSNormOp : public Operation { return result; } + + void NormalizePerRow(const float *data, + const index_t data_len, + float d_scale, + bool add_log_stddev, + float *out_data) { + float scale = SquareSum(data, data_len); + scale = scale / d_scale; + scale = scale < 1.0e-6f ? 1.0e-6f : scale; + scale = static_cast(1.0 / std::sqrt(scale)); + for (index_t j = 0; j < data_len; ++j) { + out_data[j] = data[j] * scale; + } + if (add_log_stddev) { + out_data[data_len] = std::log(target_rms_) - std::log(scale); + } + } + MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(0); @@ -75,15 +97,18 @@ class TargetRMSNormOp : public Operation { const index_t dim_size = input->dim_size(); MACE_CHECK(dim_size >= 1, "TargetRMSNorm's input dim size should be >= 1."); - const index_t dim = input_shape[dim_size -1]; - MACE_CHECK(dim > 0 && target_rms_ > 0, + const index_t input_dim = input_shape[dim_size -1]; + MACE_CHECK(input_dim > 0 && target_rms_ > 0, "Both input dim and target rms should be greater than zero."); const index_t bh = std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::multiplies()); - const float d_scale = dim * target_rms_ * target_rms_; - - MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + if (block_dim_ == 0) block_dim_ = static_cast(input_dim); + const index_t output_dim = add_log_stddev_ ? + input_dim + (input_dim / block_dim_) : input_dim; + std::vector output_shape = input->shape(); + output_shape[dim_size - 1] = output_dim; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard guard_input(input); Tensor::MappingGuard guard_output(output); @@ -91,19 +116,37 @@ class TargetRMSNormOp : public Operation { const float *input_data = input->data(); float *output_data = output->mutable_data(); - for (index_t i = 0; i < bh; ++i) { - float scale = SquareSum(input_data + i * dim, dim); - scale = static_cast(1.0 / std::sqrt(scale / d_scale)); - for (index_t j = 0; j < dim; ++j) { - output_data[i * dim + j] = input_data[i * dim + j] * scale; - } + index_t num_rows = bh; + index_t output_block_dim = add_log_stddev_ ? 
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 9e2c8ae6..80da9b1d 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -112,17 +112,20 @@ MaceSupportedOps = [
     'Conv2D',
     'Crop',
     'Deconv2D',
+    'Delay',
     'DepthToSpace',
     'DepthwiseConv2d',
     'DepthwiseDeconv2d',
     'Dequantize',
     'Eltwise',
     'ExpandDims',
+    'ExtractPooling',
     'Fill',
     'FullyConnected',
     'Gather',
     'Identity',
     'InferConv2dShape',
+    'KaldiBatchNorm',
     'LocalResponseNorm',
     'LSTMCell',
     'LstmNonlinear',
diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py
index 99ae2a79..54d53db0 100644
--- a/mace/python/tools/converter_tool/onnx_converter.py
+++ b/mace/python/tools/converter_tool/onnx_converter.py
@@ -41,6 +41,15 @@ from numbers import Number
 
 IS_PYTHON3 = sys.version_info > (3,)
 
+
+class AttributeType(Enum):
+    INT = 100
+    FLOAT = 101
+    INTS = 102
+    FLOATS = 103
+    BOOL = 104
+
+
 OnnxSupportedOps = [
     'Abs',
     # 'Acos',
@@ -57,6 +66,7 @@ OnnxSupportedOps = [
     # 'Atanh',
     'AveragePool',
     'BatchNormalization',
+    'BatchNorm',
     'Cast',
     # 'Ceil',
     # 'Clip',
@@ -72,11 +82,12 @@ OnnxSupportedOps = [
     'DimRange',
     'Div',
     'Dropout',
-    'DynamicLstmCell',
+    'DynamicLSTM',
     'Elu',
     'Equal',
     # 'Exp',
     # 'Expand',
+    'ExtractPooling',
     # 'EyeLike',
     # 'Flatten',
     # 'Floor',
@@ -91,7 +102,7 @@ OnnxSupportedOps = [
     # 'Hardmax',
     'Identity',
     # 'If',
-    # 'IfDefined',
+    'IfDefined',
     'ImageScaler',
     # 'InstanceNormalization',
     # 'LRN',
@@ -318,6 +329,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.ArgMin.name: self.convert_argmax,
             OnnxOpType.AveragePool.name: self.convert_pooling,
             OnnxOpType.BatchNormalization.name: self.convert_fused_batchnorm,
+            OnnxOpType.BatchNorm.name: self.convert_fused_batchnorm,
             OnnxOpType.Cast.name: self.convert_cast,
             OnnxOpType.Concat.name: self.convert_concat,
             OnnxOpType.Conv.name: self.convert_conv2d,
@@ -327,16 +339,18 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.DimRange.name: self.convert_dim_range,
             OnnxOpType.Div.name: self.convert_eltwise,
             OnnxOpType.Equal.name: self.convert_eltwise,
+            OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
             OnnxOpType.Gather.name: self.convert_gather,
             OnnxOpType.Gemm.name: self.convert_gemm,
             OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
             OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
             OnnxOpType.Identity.name: self.convert_identity,
+            OnnxOpType.IfDefined.name: self.convert_ifdefined,
             OnnxOpType.ImageScaler.name: self.convert_imagescaler,
             OnnxOpType.LeakyRelu.name: self.convert_activation,
             OnnxOpType.LogSoftmax.name: self.convert_softmax,
             OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
-            OnnxOpType.DynamicLstmCell.name: self.convert_dynamic_lstm,
+            OnnxOpType.DynamicLSTM.name: self.convert_dynamic_lstm,
             OnnxOpType.Max.name: self.convert_eltwise,
             OnnxOpType.MaxPool.name: self.convert_pooling,
             OnnxOpType.MatMul.name: self.convert_matmul,
@@ -378,6 +392,8 @@ class OnnxConverter(base_converter.ConverterInterface):
         ir_version = onnx_model.ir_version
         opset_imp = onnx_model.opset_import
 
+        self._isKaldi = False
+
         polish_available = True
         print("onnx model IR version: ", ir_version)
         for imp in opset_imp:
@@ -387,6 +403,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             if 'kaldi2onnx' in domain:
                 polish_available = False
                 self._data_format = DataFormat.DF_NONE
+                self._isKaldi = True
         if polish_available:
             onnx_model = onnx.utils.polish_model(onnx_model)
 
@@ -643,10 +660,10 @@ class OnnxConverter(base_converter.ConverterInterface):
             mace_check(axis_value == 1 or axis_value == -3,
                        "only support concat at channel dimension")
         elif node.op_type == OnnxOpType.Append.name:
-            axis_value = 1
+            axis_value = -1
         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
-        axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value
+        axis_arg.i = axis_value
 
     def convert_conv2d(self, node):
         op = self.convert_general_op(node)
@@ -772,15 +789,15 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.DynamicLSTM.name
 
-        if 'delay_a' in node.attrs:
-            prev_out_delay = node.attrs['delay_a']
+        if 'prev_out_delay' in node.attrs:
+            prev_out_delay = node.attrs['prev_out_delay']
             mace_check(prev_out_delay < 0,
                        "dynamic's prev_out_delay should <= 0.")
             prev_out_delay_arg = op.arg.add()
             prev_out_delay_arg.name = 'prev_out_delay'
             prev_out_delay_arg.i = prev_out_delay
-        if 'delay_b' in node.attrs:
-            prev_cell_delay = node.attrs['delay_b']
+        if 'prev_cell_delay' in node.attrs:
+            prev_cell_delay = node.attrs['prev_cell_delay']
             mace_check(prev_cell_delay < 0,
                        "dynamic's prev_cell_delay should < 0.")
             prev_cell_delay_arg = op.arg.add()
@@ -788,20 +805,20 @@ class OnnxConverter(base_converter.ConverterInterface):
             prev_cell_delay_arg.i = prev_cell_delay
         if 'prev_out_offset' in node.attrs:
             prev_out_offset = node.attrs['prev_out_offset']
-            mace_check(pre_out_offset >= 0,
+            mace_check(prev_out_offset >= 0,
                        "dynamic's prev_out_offset should >= 0.")
             prev_out_offset_arg = op.arg.add()
             prev_out_offset_arg.name = 'prev_out_offset'
             prev_out_offset_arg.i = prev_out_offset
-        if 'prev_a_dim' in node.attrs:
-            prev_out_dim = node.attrs['prev_a_dim']
+        if 'prev_out_dim' in node.attrs:
+            prev_out_dim = node.attrs['prev_out_dim']
             mace_check(prev_out_dim > 0,
                        "dynamic's prev_out_dim should > 0.")
             prev_out_dim_arg = op.arg.add()
             prev_out_dim_arg.name = 'prev_out_dim'
             prev_out_dim_arg.i = prev_out_dim
-        if 'prev_b_dim' in node.attrs:
-            prev_cell_dim = node.attrs['prev_b_dim']
+        if 'prev_cell_dim' in node.attrs:
+            prev_cell_dim = node.attrs['prev_cell_dim']
             mace_check(prev_cell_dim > 0,
                        "dynamic's prev_cell_dim should > 0.")
             prev_cell_dim_arg = op.arg.add()
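
The hunk above replaces the opaque kaldi2onnx attribute names with descriptive ones. A DynamicLSTM node is now expected to carry attributes shaped like the following; the values are made-up examples, and the constraints are the ones enforced by the mace_check calls above:

    # Hypothetical node.attrs for a kaldi2onnx DynamicLSTM node:
    dynamic_lstm_attrs = {
        'prev_out_delay': -3,   # was 'delay_a'; must be negative
        'prev_cell_delay': -3,  # was 'delay_b'; must be negative
        'prev_out_offset': 0,   # must be >= 0
        'prev_out_dim': 256,    # was 'prev_a_dim'; must be > 0
        'prev_cell_dim': 256,   # was 'prev_b_dim'; must be > 0
    }
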
@@ -844,11 +861,154 @@ class OnnxConverter(base_converter.ConverterInterface):
             value_arg.name = MaceKeyword.mace_scalar_input_str
             value_arg.f = value
 
+    @staticmethod
+    def copy_node_attr(op, node, attr_name, dtype=AttributeType.INT,
+                       default=None):
+        if attr_name in node.attrs or default is not None:
+            if attr_name in node.attrs:
+                value = node.attrs[attr_name]
+            else:
+                value = default
+            new_arg = op.arg.add()
+            new_arg.name = attr_name
+            if dtype == AttributeType.INT:
+                new_arg.i = int(value)
+            elif dtype == AttributeType.FLOAT:
+                new_arg.f = float(value)
+            elif dtype == AttributeType.INTS:
+                new_arg.ints.extend(value)
+            elif dtype == AttributeType.FLOATS:
+                new_arg.floats.extend(value)
+            return value
+        else:
+            return default
+
+    def convert_extract_pooling(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.ExtractPooling.name
+
+        self.copy_node_attr(op, node, 'include_variance', AttributeType.INT)
+        self.copy_node_attr(op, node, 'num_log_count', AttributeType.INT)
+        self.copy_node_attr(op, node, 'variance_floor', AttributeType.FLOAT)
+        self.copy_node_attr(op, node, 'input_time_range', AttributeType.INTS)
+        self.copy_node_attr(op, node, 'input_indexes', AttributeType.INTS)
+
+        if 'output_time_range' in node.attrs:
+            output_time_range = node.attrs['output_time_range']
+            mace_check(len(output_time_range) == 2,
+                       "output time range should have two values.")
+            out_start_index = output_time_range[0]
+            out_end_index = output_time_range[1]
+        else:
+            mace_check('start_index' in node.attrs and
+                       'end_index' in node.attrs,
+                       "'start_index' and 'end_index'"
+                       " are required in ExtractPooling.")
+            out_start_index = node.attrs['start_index']
+            out_end_index = node.attrs['end_index']
+            output_time_range = [out_start_index, out_end_index]
+
+        output_time_range_arg = op.arg.add()
+        output_time_range_arg.name = 'output_time_range'
+        output_time_range_arg.ints.extend(output_time_range)
+
+        mace_check('modulus' in node.attrs,
+                   "'modulus' is required in ExtractPooling.")
+        mace_check('output_indexes' in node.attrs,
+                   "'output_indexes' is required in ExtractPooling.")
+        mace_check('counts' in node.attrs,
+                   "'counts' is required in ExtractPooling.")
+        mace_check('forward_indexes' in node.attrs,
+                   "'forward_indexes' is required in ExtractPooling.")
+        modulus = node.attrs['modulus']
+        output_indexes = node.attrs['output_indexes']
+        counts = node.attrs['counts']
+        forward_indexes = node.attrs['forward_indexes']
+
+        mace_check(len(counts) == len(output_indexes) and
+                   len(forward_indexes) == 2 * len(output_indexes),
+                   "output_indexes length:%s "
+                   "counts length:%s "
+                   "forward_indexes length:%s"
+                   % (len(output_indexes), len(counts),
+                      len(forward_indexes)))
+
+        new_output_indexes = []
+        new_forward_indexes = []
+        new_counts = []
+        for i in range(len(output_indexes)):
+            if output_indexes[i] + modulus > out_start_index and \
+                    output_indexes[i] <= out_end_index:
+                new_output_indexes.append(output_indexes[i])
+                new_counts.append(counts[i])
+                new_forward_indexes.append(forward_indexes[2 * i])
+                new_forward_indexes.append(forward_indexes[2 * i + 1])
+
+        modulus_arg = op.arg.add()
+        modulus_arg.name = 'modulus'
+        modulus_arg.i = modulus
+
+        counts_arg = op.arg.add()
+        counts_arg.name = 'counts'
+        counts_arg.floats.extend(new_counts)
+
+        forward_indexes_arg = op.arg.add()
+        forward_indexes_arg.name = 'forward_indexes'
+        forward_indexes_arg.ints.extend(new_forward_indexes)
+
+        output_indexes_arg = op.arg.add()
+        output_indexes_arg.name = 'output_indexes'
+        output_indexes_arg.ints.extend(new_output_indexes)
+
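
The pruning loop in convert_extract_pooling keeps only the pooling outputs that can contribute to the requested output range: an index t survives if t + modulus > out_start_index and t <= out_end_index, and its count plus its pair of forward_indexes are kept with it. A condensed sketch of that filtering with invented sample values:

    modulus, out_start, out_end = 9, 0, 17
    output_indexes = [-9, 0, 9, 18]
    counts = [10.0, 11.0, 12.0, 13.0]
    forward_indexes = [0, 1, 1, 2, 2, 3, 3, 4]

    kept = [i for i, t in enumerate(output_indexes)
            if t + modulus > out_start and t <= out_end]
    new_output_indexes = [output_indexes[i] for i in kept]
    new_counts = [counts[i] for i in kept]
    new_forward_indexes = [v for i in kept
                           for v in forward_indexes[2 * i:2 * i + 2]]
    print(new_output_indexes, new_counts, new_forward_indexes)
    # [0, 9] [11.0, 12.0] [1, 2, 2, 3]
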
     def convert_flatten(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Reshape.name
 
+    def convert_kaldi_batchnorm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.KaldiBatchNorm.name
+        dim = self.copy_node_attr(op, node,
+                                  'dim', AttributeType.INT, -1)
+        block_dim = self.copy_node_attr(op, node,
+                                        'block_dim',
+                                        AttributeType.INT, -1)
+        epsilon = self.copy_node_attr(op, node,
+                                      'epsilon',
+                                      AttributeType.FLOAT, 1e-3)
+        target_rms = self.copy_node_attr(op, node,
+                                         'target_rms',
+                                         AttributeType.FLOAT, 1.0)
+        test_mode = self.copy_node_attr(op, node,
+                                        'test_mode',
+                                        AttributeType.INT, 0)
+        mace_check(block_dim > 0 and
+                   dim % block_dim == 0 and
+                   epsilon > 0 and
+                   target_rms > 0, "attributes invalid.")
+
+        if test_mode > 0:
+            mace_check(len(node.inputs) == 3,
+                       "Kaldi's BatchNorm should have 3 inputs.")
+            stats_mean = np.array(self._consts[node.inputs[1]].float_data)
+            stats_var = np.array(self._consts[node.inputs[2]].float_data)
+            offset_value = -1.0 * stats_mean
+            scale_value = stats_var
+            scale_value[scale_value < 0] = 0
+            scale_value = np.power(scale_value + epsilon, -0.5) * target_rms
+            offset_value = offset_value * scale_value
+            scale_name = node.name + '_scale'
+            offset_name = node.name + '_offset'
+            self.add_tensor(scale_name, scale_value.shape,
+                            mace_pb2.DT_FLOAT, scale_value)
+            self.add_tensor(offset_name, offset_value.shape,
+                            mace_pb2.DT_FLOAT, offset_value)
+            del op.input[1:]
+            op.input.extend([scale_name, offset_name])
+            del op.output[1:]
+            del op.output_shape[1:]
+
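
In test mode the converter folds the accumulated statistics into a scale/offset pair, exactly as the arithmetic above: scale = target_rms * (var + epsilon)^(-1/2) with negative variances clamped to zero, and offset = -mean * scale. A standalone check of that algebra (the sample statistics are arbitrary):

    import numpy as np

    stats_mean = np.array([0.5, -1.0], dtype=np.float32)
    stats_var = np.array([4.0, 0.25], dtype=np.float32)
    epsilon, target_rms = 1e-3, 1.0

    scale = np.power(np.maximum(stats_var, 0) + epsilon, -0.5) * target_rms
    offset = -stats_mean * scale

    # scale * x + offset == target_rms * (x - mean) / sqrt(var + epsilon)
    x = np.array([2.0, 3.0], dtype=np.float32)
    assert np.allclose(scale * x + offset,
                       target_rms * (x - stats_mean) /
                       np.sqrt(stats_var + epsilon))
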
     def convert_fused_batchnorm(self, node):
+        if self._isKaldi:
+            self.convert_kaldi_batchnorm(node)
+            return
         op = self.convert_general_op(node)
         op.type = MaceOp.BatchNorm.name
 
@@ -946,6 +1106,21 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.Identity.name
 
+    def convert_ifdefined(self, node):
+        op = self.convert_general_op(node)
+        if 'offset' in node.attrs:
+            offset = node.attrs['offset']
+        else:
+            offset = 0
+        mace_check(offset <= 0, "IfDefined's offset should be <= 0.")
+        if offset == 0:
+            op.type = MaceOp.Identity.name
+        else:
+            op.type = MaceOp.Delay.name
+            offset_arg = op.arg.add()
+            offset_arg.name = 'offset'
+            offset_arg.i = offset
+
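
convert_ifdefined maps a zero offset to Identity and a negative offset to the new Delay op. As a reading aid, the intended behavior on a single (chunk, dim) sequence, assuming Delay shifts frames toward later time steps by |offset| and zero-fills the head (this reference function and its shapes are illustrative, with the batch dimension folded away):

    import numpy as np

    def delay_reference(x, offset):
        # x: (chunk, dim); offset < 0.
        # Output frame t is input frame t + offset; zeros where t + offset < 0.
        assert offset < 0
        out = np.zeros_like(x)
        if x.shape[0] > -offset:
            out[-offset:] = x[:offset]
        return out

    x = np.arange(8, dtype=np.float32).reshape(4, 2)
    print(delay_reference(x, -1))
    # [[0. 0.] [0. 1.] [2. 3.] [4. 5.]]
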
     def convert_imagescaler(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.BatchNorm.name
@@ -1100,7 +1275,6 @@ class OnnxConverter(base_converter.ConverterInterface):
     def convert_softmax(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Softmax.name
-        # TODO: add logsoftmax in softmax op
         if node.op_type == OnnxOpType.LogSoftmax.name:
             use_log_arg = op.arg.add()
             use_log_arg.name = 'use_log'
@@ -1164,11 +1338,11 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.TargetRMSNorm.name
 
-        if 'target_rms' in node.attrs:
-            value = node.attrs['target_rms']
-            target_rms_arg = op.arg.add()
-            target_rms_arg.name = 'target_rms'
-            target_rms_arg.f = value
+        self.copy_node_attr(op, node, 'target_rms', AttributeType.FLOAT)
+        self.copy_node_attr(op, node, 'add_log_stddev', AttributeType.INT,
+                            default=0)
+        self.copy_node_attr(op, node, 'block_dim', AttributeType.INT,
+                            default=0)
 
     def convert_transpose(self, node):
         op = self.convert_general_op(node)
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index e654a8fb..faf33034 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -1130,7 +1130,7 @@ class Transformer(base_converter.ConverterInterface):
                 rhs = op.input[1]
                 if rhs in self._consts and len(self._consts[rhs].dims) == 2:
                     arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str)  # noqa
-                    six.print_("Transpose matmul weight %s" % rhs)
+                    # six.print_("Transpose matmul weight %s" % rhs)
                     if arg is None:
                         arg = op.arg.add()
                         arg.name = MaceKeyword.mace_transpose_b_str
@@ -1143,7 +1143,8 @@ class Transformer(base_converter.ConverterInterface):
                         filter.float_data[:] = filter_data.flat
                         filter.dims[:] = filter_data.shape
                         arg.i = 1
-                        six.print_('transpose matmul weight')
+                        six.print_('Transpose matmul weight to shape:',
+                                   filter.dims)
 
     def transpose_filters(self):
         net = self._model
-- 
GitLab