提交 a7afda0e 编写于 作者: L liutuo

add kaldi op IfDefined

add extract pooling and kaldi batchnorm

add time index info in extract pooling

add benchmar& test for kaldi components

add batch support for kaldi

change ifdefined to delay
上级 77df54f2
......@@ -86,7 +86,8 @@ class ConcatOp<DeviceType::CPU, T> : public ConcatOpBase {
continue;
}
MACE_CHECK(input->dim(j) == input0->dim(j),
"Dimensions of inputs should equal except axis.");
"Dimensions of inputs should equal except axis: ",
input->dim(j), "!=", input0->dim(j));
}
outer_sizes[i] = input->size() / inner_size;
output_shape[axis] += input->dim(axis);
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for IfDefined descriptor in Kaldi.
// It defines time offset.
// If time index <= offset, using zeros as output.
#include <functional>
#include <memory>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class DelayOp;
template <typename T>
class DelayOp<DeviceType::CPU, T> : public Operation {
public:
explicit DelayOp(OpConstructContext *context)
: Operation(context),
offset_(Operation::GetOptionalArg<int>("offset", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(offset_ < 0, "offset param should be negative.");
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should >= 2.");
const std::vector<index_t> &input_shape = input->shape();
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
const index_t chunk = input_shape[rank - 2];
const index_t dim = input_shape[rank - 1];
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
output->Clear();
if (chunk <= -offset_)
return MaceStatus::MACE_SUCCESS;
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
memcpy(output_data + (i * chunk + j - offset_) * dim,
input_data + (i * chunk + j) * dim,
dim * sizeof(T));
}
}
}, 0, batch, 1, 0, chunk + offset_, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
int offset_;
};
void RegisterDelay(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Delay", DelayOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void Delay(int iters,
int batch,
int chunk,
int dim,
int offset) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
OpDefBuilder("Delay", "DelayTest")
.Input("Input")
.Output("Output")
.AddIntArg("offset", -offset)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.Run();
}
net.Sync();
}
#define MACE_BM_DELAY_MACRO(N, C, D, OFFSET, TYPE, DEVICE) \
static void MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Delay<DEVICE, TYPE>(iters, N, C, D, OFFSET); \
} \
MACE_BENCHMARK(MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE\
##_##DEVICE)
#define MACE_BM_DELAY(N, C, D, OFFSET) \
MACE_BM_DELAY_MACRO(N, C, D, OFFSET, float, CPU);
MACE_BM_DELAY(8, 40, 512, 2);
MACE_BM_DELAY(16, 80, 100, 3);
MACE_BM_DELAY(32, 60, 200, 5);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -214,7 +214,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape = input->shape();
output_shape[1] = output_dim;
output_shape[input_rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
......@@ -235,8 +235,10 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
affine_b_in.Clear();
affine_b_out.Clear();
for (int i = 0; i < chunk; ++i) {
const float *input_ptr = input_data + (b * chunk + i) * input_dim;
float *output_ptr = output_data + (b * chunk + i) * output_dim;
// Append
memcpy(affine_a_in_data, input_data, input_dim * sizeof(float));
memcpy(affine_a_in_data, input_ptr, input_dim * sizeof(float));
if (prev_out_idx >= 0) {
memcpy(affine_a_in_data + input_dim,
prev_out_data + prev_out_idx % out_buf_chunk * prev_out_dim_,
......@@ -283,7 +285,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
false,
&affine_b_out);
// Output
memcpy(output_data,
memcpy(output_ptr,
affine_b_out_data,
output_dim * sizeof(float));
// Update
......@@ -292,8 +294,6 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
prev_out_dim_,
scale_,
curr_out_ptr);
input_data += input_dim;
output_data += output_dim;
prev_out_idx++;
prev_cell_idx++;
}
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for fused StatisticsExtraction, StatisticsPooling and
// Round Components in Kaldi.
// This op is used to extract moving-average mean and standard-deviation
// statistics of input data.
// 'input_indexes' indicates which frames will be used for extract statistics.
// 'output_indexes' indicates which frames of outputs will be used to
// save statistics results.
// 'modulus' will be used for extent results to all frames.
// 'start_index' and 'end_index' indicate time indexes of output frames.
// 'forward_indexes' and 'count' were from precomputed index in kaldi.
// Reference to
// http://kaldi-asr.org/doc/nnet-general-component_8h_source.html#l00158
#include <functional>
#include <memory>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ExtractPoolingOp;
template <typename T>
class ExtractPoolingOp<DeviceType::CPU, T> : public Operation {
public:
explicit ExtractPoolingOp(OpConstructContext *context)
: Operation(context),
modulus_(Operation::GetOptionalArg<int>("modulus", 1)),
include_variance_(
static_cast<bool>(
Operation::GetOptionalArg<int>("include_variance", 0))),
num_log_count_(
Operation::GetOptionalArg<int>("num_log_count", 0)),
variance_floor_(
Operation::GetOptionalArg<float>("variance_floor", 1.0e-10)),
input_indexes_(Operation::GetRepeatedArgs<int>("input_indexes")),
output_indexes_(Operation::GetRepeatedArgs<int>("output_indexes")),
forward_indexes_(Operation::GetRepeatedArgs<int>("forward_indexes")),
counts_(Operation::GetRepeatedArgs<float>("counts")),
input_time_range_(Operation::GetRepeatedArgs<int>("input_time_range")),
output_time_range_(
Operation::GetRepeatedArgs<int>("output_time_range")) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input_shape.size();
MACE_CHECK(dim_size >= 2,
"ExtractPooling only supports input dim size >= 2");
MACE_CHECK(modulus_ >= 1,
"ExtractPooling's pooling size should be greater than zero.");
MACE_CHECK(input_time_range_.size() == 2 && output_time_range_.size() == 2
&& counts_.size() * 2 == forward_indexes_.size()
&& counts_.size() == output_indexes_.size());
int in_start_index = input_time_range_[0];
int out_start_index = output_time_range_[0];
int out_end_index = output_time_range_[1];
MACE_CHECK(out_end_index >= out_start_index
&& input_time_range_[1] >= input_time_range_[0],
"end index should be greater than start index.");
const index_t output_chunk = out_end_index - out_start_index + 1;
const index_t input_dim = input_shape[dim_size - 1];
const index_t chunk = input_shape[dim_size - 2];
MACE_CHECK(chunk == input_time_range_[1] - input_time_range_[0] + 1,
"input chunk should be equal to end - start + 1.");
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
index_t output_dim = include_variance_ ? 2 * input_dim : input_dim;
output_dim += num_log_count_;
std::vector<index_t> output_shape(input_shape);
output_shape[dim_size - 1] = output_dim;
output_shape[dim_size - 2] = output_chunk;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
const index_t num_input_indexes = input_indexes_.size();
const index_t num_output_indexes = output_indexes_.size();
MACE_CHECK(num_input_indexes > 0 && num_output_indexes > 0,
"ExtractPooling's input_indexes or output_indexes is empty.");
const index_t extract_out_size = PadAlignSize(output_dim * sizeof(float));
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(extract_out_size);
Tensor extract_out(scratch->Scratch(extract_out_size), DT_FLOAT);
extract_out.Reshape({1, output_dim});
extract_out.Clear();
float *extract_out_data = extract_out.mutable_data<float>();
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
for (index_t b = 0; b < batch; ++b) {
for (index_t i = 0; i < num_output_indexes; ++i) {
int start = forward_indexes_[2 * i];
int end = forward_indexes_[2 * i + 1];
float count = counts_[i];
float mean_scale = 1.f / count;
float log_count = std::log(count);
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t n = start0; n < end0; n += step0) {
extract_out_data[n] = log_count;
}
}, 0, num_log_count_, 1);
if (include_variance_) {
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t d = start0; d < end0; d += step0) {
float mean = 0.f;
float variance = 0.f;
for (int t = start; t < end; ++t) {
index_t input_index =
(b * chunk + input_indexes_[t] - in_start_index)
* input_dim;
float x = input_data[input_index + d];
mean += x;
variance += x * x;
}
mean *= mean_scale;
variance *= mean_scale;
extract_out_data[d + num_log_count_] = mean;
variance = variance - mean * mean;
extract_out_data[d + input_dim + num_log_count_] =
variance < variance_floor_ ?
std::sqrt(variance_floor_) :
std::sqrt(variance);
}
}, 0, input_dim, 1);
} else {
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t d = start0; d < end0; d += step0) {
float mean = 0.f;
for (int t = start; t < end; ++t) {
index_t input_index =
(b * chunk + input_indexes_[t] - in_start_index)
* input_dim;
mean += input_data[input_index + d];
}
extract_out_data[d + num_log_count_] = mean * mean_scale;
}
}, 0, input_dim, 1);
}
int output_start = output_indexes_[i] < out_start_index ?
out_start_index : output_indexes_[i];
int output_end = output_indexes_[i] + modulus_;
output_end = output_end > out_end_index ?
out_end_index + 1 :
output_end;
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t idx = start0; idx < end0; idx += step0) {
memcpy(output_data + (b * output_chunk + idx - out_start_index)
* output_dim,
extract_out_data, output_dim * sizeof(float));
}
}, output_start, output_end, 1);
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int modulus_;
bool include_variance_;
int num_log_count_;
float variance_floor_;
std::vector<int> input_indexes_;
std::vector<int> output_indexes_;
std::vector<int> forward_indexes_;
std::vector<float> counts_;
std::vector<int> input_time_range_;
std::vector<int> output_time_range_;
};
void RegisterExtractPooling(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "ExtractPooling", ExtractPoolingOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void ExtractPooling(int iters,
int batch,
int chunk,
int dim,
int input_period,
int modulus) {
mace::testing::StopTiming();
OpsTestNet net;
size_t num_input_indexes = static_cast<size_t>(chunk / input_period);
std::vector<int> input_indexes(num_input_indexes, 0);
for (size_t i = 0; i < num_input_indexes; ++i) {
input_indexes[i] = static_cast<int>(i * input_period);
}
size_t num_output_indexes = static_cast<size_t>(chunk / modulus);
std::vector<int> output_indexes(num_output_indexes, 0);
std::vector<int> forward_indexes(num_output_indexes * 2, 0);
std::vector<float> counts(num_output_indexes, 0.f);
for (size_t i = 0; i < num_output_indexes; ++i) {
output_indexes[i] = static_cast<int>(i * modulus);
forward_indexes[2 * i] = 0;
forward_indexes[2 * i + 1] = static_cast<int>(num_input_indexes - 1);
counts[i] = static_cast<float>(num_input_indexes);
}
// Add input data
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
.Input("Input")
.AddIntArg("modulus", modulus)
.AddIntArg("include_variance", 1)
.AddIntArg("num_log_counts", 1)
.AddIntsArg("input_indexes", input_indexes)
.AddIntsArg("output_indexes", output_indexes)
.AddIntsArg("forward_indexes", forward_indexes)
.AddFloatsArg("counts", counts)
.AddIntsArg("input_time_range", {0, chunk - 1})
.AddIntsArg("output_time_range", {0, chunk - 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.Run();
}
net.Sync();
}
#define MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, TYPE, DEVICE) \
static void MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE##\
_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ExtractPooling<DEVICE, TYPE>(iters, N, C, D, INP, M); \
} \
MACE_BENCHMARK(MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE\
##_##DEVICE)
#define MACE_BM_EXTRACTPOOLING(N, C, D, INP, M) \
MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, float, CPU);
MACE_BM_EXTRACTPOOLING(8, 40, 512, 2, 4);
MACE_BM_EXTRACTPOOLING(16, 80, 100, 3, 9);
MACE_BM_EXTRACTPOOLING(32, 60, 200, 6, 18);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ExtractPoolingTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestExtractPooling(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int modulus,
const int num_log_count,
const int include_variance,
const std::vector<int> &input_time_range,
const std::vector<int> &input_indexes,
const std::vector<int> &forward_indexes,
const std::vector<float> &counts,
const std::vector<int> &output_indexes,
const std::vector<int> &output_time_range,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
// Construct graph
OpsTestNet net;
net.AddInputFromArray<D, float>("Input", input_shape, input_value);
OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
.Input("Input")
.AddIntArg("modulus", modulus)
.AddIntArg("include_variance", include_variance)
.AddIntArg("num_log_count", num_log_count)
.AddIntsArg("input_indexes", input_indexes)
.AddIntsArg("output_indexes", output_indexes)
.AddIntsArg("forward_indexes", forward_indexes)
.AddFloatsArg("counts", counts)
.AddIntsArg("input_time_range", input_time_range)
.AddIntsArg("output_time_range", output_time_range)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(ExtractPoolingTest, SimpleCPU) {
TestExtractPooling<DeviceType::CPU, float>(
{3, 20, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60},
9, 0, 0,
{-2, 17},
{0, 3, 6, 9, 12, 15},
{0, 6, 2, 6},
{6, 4},
{0, 9},
{0, 17},
{3, 18, 3},
{29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5});
}
TEST_F(ExtractPoolingTest, SimpleCPUWithVariance) {
TestExtractPooling<DeviceType::CPU, float>(
{3, 20, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60},
9, 1, 1,
{-2, 17},
{0, 3, 6, 9, 12, 15},
{0, 6, 2, 6},
{6, 4},
{0, 9},
{0, 17},
{3, 18, 7},
{1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for Kaldi's BatchNormComponent
// More details about forward computation are here:
// http://kaldi-asr.org/doc/nnet-normalize-component_8cc_source.html#l00320
#include <memory>
#include <string>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class KaldiBatchNormOp;
template <>
class KaldiBatchNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit KaldiBatchNormOp(OpConstructContext *context)
: Operation(context),
epsilon_(Operation::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-3))),
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0f)),
block_dim_(Operation::GetOptionalArg<int>("block_dim", -1)),
test_mode_(static_cast<bool>(
Operation::GetOptionalArg<int>("test_mode", 0))) {}
void CalculateMeanVar(const float *input_data,
index_t length,
index_t stride,
float mean_scale,
float var_scale,
float *mean_data,
float *var_data) {
float mean_value = 0.f;
float var_value = 0.f;
for (index_t i = 0; i < length; ++i) {
float x = input_data[i * stride];
mean_value += x;
var_value += x * x;
}
mean_value = mean_value * mean_scale;
var_value = var_value * mean_scale;
float mean_sqr = mean_value * mean_value;
var_value = (var_value > mean_sqr) ?
var_scale * (var_value - mean_sqr + epsilon_) :
var_scale * epsilon_;
var_data[0] = std::pow(var_value, -0.5f);
mean_data[0] = mean_value;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
const index_t rank = input->dim_size();
const index_t dim = input_shape[rank - 1];
if (block_dim_ == -1) block_dim_ = static_cast<int>(dim);
MACE_CHECK(target_rms_ > 0 && dim > 0 && dim % block_dim_ == 0);
MACE_CHECK(rank >= 2, "KaldiBatchNorm's input's rank must >= 2.");
index_t num_rows =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
const index_t blocks = dim / block_dim_;
if (blocks > 1) num_rows *= blocks;
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
if (test_mode_) {
MACE_CHECK(this->InputSize() == 3, "KaldiBatchNorm should have 3 inputs");
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
scale->dim_size());
MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
offset->dim_size());
MACE_CHECK(scale->size() == offset->size()
&& scale->size() == block_dim_);
Tensor::MappingGuard scale_guard(scale);
Tensor::MappingGuard offset_guard(offset);
const float *scale_data = scale->data<float>();
const float *offset_data = offset->data<float>();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
index_t idx = i * block_dim_ + j;
output_data[idx] = input_data[idx] * scale_data[j] + offset_data[j];
}
}
}, 0, num_rows, 1, 0, block_dim_, 1);
} else {
const index_t buf_size =
PadAlignSize(block_dim_ * sizeof(float));
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(2 * buf_size);
Tensor mean(scratch->Scratch(buf_size), DT_FLOAT);
mean.Reshape({block_dim_});
float *mean_data = mean.mutable_data<float>();
Tensor var(scratch->Scratch(buf_size), DT_FLOAT);
var.Reshape({block_dim_});
float *var_data = var.mutable_data<float>();
float var_scale = 1.0f / (target_rms_ * target_rms_);
float mean_scale = 1.0f / num_rows;
thread_pool.Compute1D([=](index_t start0, index_t end0, index_t step0) {
for (index_t i = start0; i < end0; i += step0) {
CalculateMeanVar(input_data + i,
num_rows,
block_dim_,
mean_scale,
var_scale,
mean_data + i,
var_data + i);
}
}, 0, block_dim_, 1);
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
index_t idx = i * block_dim_ + j;
output_data[idx] = (input_data[idx] - mean_data[j]) * var_data[j];
}
}
}, 0, num_rows, 1, 0, block_dim_, 1);
}
return MaceStatus::MACE_SUCCESS;
}
private:
const float epsilon_;
const float target_rms_;
int block_dim_;
const bool test_mode_;
protected:
MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
void RegisterKaldiBatchNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "KaldiBatchNorm", KaldiBatchNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void KaldiBatchNorm(
int iters, int batch, int chunk, int dim, int block_dim) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Scale", {block_dim}, true);
net.AddRandomInput<D, T>("Offset", {block_dim}, true);
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormBM")
.Input("Input")
.Input("Scale")
.Input("Offset")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 1)
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, TYPE, DEVICE) \
static void MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::MacsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
KaldiBatchNorm<DEVICE, TYPE>(iters, N, C, D, BD); \
} \
MACE_BENCHMARK(MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
##_##DEVICE)
#define MACE_BM_KALDI_BATCH_NORM(N, C, D, BD) \
MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, float, CPU);
MACE_BM_KALDI_BATCH_NORM(1, 1, 512, 512);
MACE_BM_KALDI_BATCH_NORM(1, 3, 128, 128);
MACE_BM_KALDI_BATCH_NORM(1, 3, 512, 128);
MACE_BM_KALDI_BATCH_NORM(1, 32, 112, 112);
MACE_BM_KALDI_BATCH_NORM(1, 64, 256, 256);
MACE_BM_KALDI_BATCH_NORM(1, 64, 512, 256);
MACE_BM_KALDI_BATCH_NORM(1, 128, 56, 56);
MACE_BM_KALDI_BATCH_NORM(1, 128, 256, 256);
MACE_BM_KALDI_BATCH_NORM(1, 256, 14, 14);
MACE_BM_KALDI_BATCH_NORM(1, 512, 14, 14);
MACE_BM_KALDI_BATCH_NORM(1, 1024, 7, 7);
MACE_BM_KALDI_BATCH_NORM(32, 1, 256, 128);
MACE_BM_KALDI_BATCH_NORM(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class KaldiBatchNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int block_dim,
const int dim,
const std::vector<float> &scale,
const std::vector<float> &offset,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
OpsTestNet net;
int scale_dim = block_dim;
if (scale_dim == -1) scale_dim = dim;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape,
input_value);
net.AddInputFromArray<D, float>("Scale", {scale_dim}, scale, true);
net.AddInputFromArray<D, float>("Offset", {scale_dim}, offset, true);
if (D == DeviceType::CPU) {
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else if (D == DeviceType::GPU) {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
}
template <DeviceType D>
void SimpleNotTestMode(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int block_dim,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape,
input_value);
if (D == DeviceType::CPU) {
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
.Input("Input")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 0)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else if (D == DeviceType::GPU) {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
}
} // namespace
TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUOneBlock) {
Simple<DeviceType::CPU>(
{1, 6, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15},
-1, 2,
{4.f, 6.f},
{2.f, 1.f},
{1, 6, 2},
{22, 31, 30, 43, 38, 55, 46, 67, 54, 79, 62, 91}); }
TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUTwoBlock) {
Simple<DeviceType::CPU>(
{1, 6, 4},
{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
2, 4,
{4.f, 6.f},
{2.f, 1.f},
{1, 6, 4},
{22, 31, 22, 31, 30, 43, 30, 43, 38, 55, 38, 55,
46, 67, 46, 67, 54, 79, 54, 79, 62, 91, 62, 91});
}
TEST_F(KaldiBatchNormOpTest, SimpleNotTestModeCPUTwoBlock) {
SimpleNotTestMode<DeviceType::CPU>(
{1, 6, 4},
{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
2,
{1, 6, 4},
{-1.46379, -1.46379, -1.46379, -1.46379,
-0.8783, -0.8783, -0.8783, -0.8783,
-0.29276, -0.29276, -0.29276, -0.29276,
0.29276, 0.29276, 0.29276, 0.29276,
0.8783, 0.8783, 0.8783, 0.8783,
1.46379, 1.46379, 1.46379, 1.46379});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -37,11 +37,14 @@ extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry);
extern void RegisterDynamicLSTM(OpRegistryBase *op_registry);
extern void RegisterEltwise(OpRegistryBase *op_registry);
extern void RegisterExpandDims(OpRegistryBase *op_registry);
extern void RegisterExtractPooling(OpRegistryBase *op_registry);
extern void RegisterFill(OpRegistryBase *op_registry);
extern void RegisterFullyConnected(OpRegistryBase *op_registry);
extern void RegisterGather(OpRegistryBase *op_registry);
extern void RegisterIdentity(OpRegistryBase *op_registry);
extern void RegisterDelay(OpRegistryBase *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
......@@ -107,11 +110,14 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterDynamicLSTM(this);
ops::RegisterEltwise(this);
ops::RegisterExpandDims(this);
ops::RegisterExtractPooling(this);
ops::RegisterFill(this);
ops::RegisterFullyConnected(this);
ops::RegisterGather(this);
ops::RegisterIdentity(this);
ops::RegisterDelay(this);
ops::RegisterInferConv2dShape(this);
ops::RegisterKaldiBatchNorm(this);
ops::RegisterLocalResponseNorm(this);
ops::RegisterLSTMNonlinear(this);
ops::RegisterMatMul(this);
......
......@@ -42,7 +42,7 @@ class PadContextOp<DeviceType::CPU, T> : public Operation {
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should >= 2.");
MACE_CHECK(left_context_ > 0 && right_context_ > 0,
MACE_CHECK(left_context_ >= 0 && right_context_ >= 0,
"left context and right context should be greater than zero");
const std::vector<index_t> &input_shape = input->shape();
const index_t batch =
......
......@@ -106,19 +106,6 @@ void SimpleMean12Test() {
10, 11, 12, 13}, ReduceType::MEAN);
}
// template <DeviceType D>
// void SimpleSum12Test() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1, 2},
// {2, 1, 1, 4},
// {60, 66, 72, 78,
// 60, 66, 72, 78}, ReduceType::SUM);
//}
template <DeviceType D>
void SimpleMin12Test() {
Simple<D>({2, 2, 3, 4},
......@@ -145,20 +132,6 @@ void SimpleMax12Test() {
20, 21, 22, 23}, ReduceType::MAX);
}
// template <DeviceType D>
// void SimpleSumSqr12Test() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1, 2},
// {2, 1, 1, 4},
// {880, 1006, 1144, 1294,
// 880, 1006, 1144, 1294}, ReduceType::SUM_SQR);
//}
template <DeviceType D>
void SimpleMean1Axis() {
Simple<D>({2, 2, 3, 4},
......@@ -170,102 +143,8 @@ void SimpleMean1Axis() {
{2, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {-3},
// {1, 1, 3, 4},
// {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {-2},
// {1, 3, 1, 3},
// {3, 4, 5, 12, 13, 14, 21, 22, 23}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {3},
// {1, 3, 3, 1},
// {1, 4, 7, 10, 13, 16, 19, 22, 25}, ReduceType::MEAN);
}
// template <DeviceType D>
// void SimpleSum1Axis() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23,
// 0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {1},
// {2, 1, 3, 4},
// {12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
// 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34});
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {12, 15, 18, 21, 48, 51, 54, 57}, ReduceType::SUM);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {6, 22, 38, 54, 70, 86}, ReduceType::SUM);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {27, 30, 33, 36, 39, 42, 45, 48, 51}, ReduceType::SUM);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {3},
// {1, 3, 3, 1},
// {3, 12, 21, 30, 39, 48, 57, 66, 75}, ReduceType::SUM);
//}
template <DeviceType D>
void SimpleMin1Axis() {
Simple<D>({2, 2, 3, 4},
......@@ -285,33 +164,6 @@ void SimpleMin1Axis() {
{2, 1, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, ReduceType::MIN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {0, 1, 2, 3, 12, 13, 14, 15}, ReduceType::MIN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {0, 4, 8, 12, 16, 20}, ReduceType::MIN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8}, ReduceType::MIN);
}
template <DeviceType D>
......@@ -337,53 +189,8 @@ void SimpleMax1Axis() {
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23}, ReduceType::MAX);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {8, 9, 10, 11, 20, 21, 22, 23}, ReduceType::MAX);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {3, 7, 11, 15, 19, 23}, ReduceType::MAX);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {18, 19, 20, 21, 22, 23, 24, 25, 26}, ReduceType::MAX);
}
// template <DeviceType D>
// void SimpleSumSqr1Axis() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1},
// {2, 1, 3, 4},
// {144, 170, 200, 234,
// 272, 314, 360, 410,
// 464, 522, 584, 650,
// 144, 170, 200, 234,
// 272, 314, 360, 410,
// 464, 522, 584, 650}, ReduceType::SUM_SQR);
//}
template <DeviceType D>
void Simple2Axis() {
Simple<D>({1, 2, 3, 4},
......@@ -396,16 +203,6 @@ void Simple2Axis() {
{0, 1},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple3D<D>({2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1},
// {1, 1, 4},
// {10, 11, 12, 13}, ReduceType::MEAN);
Simple3D<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
......@@ -426,37 +223,6 @@ void Simple2Axis() {
{0, 2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {1, 3},
// {1, 1, 3, 1},
// {7.5, 11.5, 15.5}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1, 2},
// {1, 1, 1, 3},
// {12, 13, 14}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1},
// {1, 1, 3, 3},
// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {2, 3},
// {1, 3, 1, 1},
// {4, 13, 22}, ReduceType::MEAN);
}
template <DeviceType D>
......@@ -471,64 +237,6 @@ void Simple3Axis() {
{1, 2, 3},
{1, 1, 1, 1},
{11.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 2, 3},
// {1, 2, 1, 1},
// {5.5, 17.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1, 3},
// {1, 1, 3, 1},
// {7.5, 11.5, 15.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1, 2},
// {1, 1, 1, 4},
// {10, 11, 12, 13}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1, 2, 3},
// {1, 1, 1, 1},
// {13}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 2, 3},
// {1, 3, 1, 1},
// {4, 13, 22}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1, 3},
// {1, 1, 3, 1},
// {10, 13, 16}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1, 2},
// {1, 1, 1, 3},
// {12, 13, 14}, ReduceType::MEAN);
}
} // namespace
......@@ -622,25 +330,17 @@ void RandomTest(const std::vector<index_t> &input_shape,
TEST_F(ReduceOpTest, GPURandomFloat) {
RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
// RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 480, 640, 32}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
}
TEST_F(ReduceOpTest, GPURandomHalf) {
RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
// RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 480, 640, 32}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
}
......
......@@ -71,7 +71,8 @@ class SpliceOp<DeviceType::CPU, T> : public Operation {
const index_t out_chunk = chunk - (right_context - left_context);
MACE_CHECK(input_dim > const_dim_,
"input dim should be greater than const dim.");
"input dim:", input_dim,
"should be greater than const dim:", const_dim_);
const index_t output_dim = dim * num_splice + const_dim_;
const index_t output_stride = out_chunk * output_dim;
......@@ -103,7 +104,7 @@ class SpliceOp<DeviceType::CPU, T> : public Operation {
const index_t input_offset = dim;
for (int b = 0; b < batch; ++b) {
for (index_t i = 0; i < out_chunk; ++i) {
T *output_base = output_data + + b * output_stride + i * output_dim;
T *output_base = output_data + b * output_stride + i * output_dim;
const T *input_base = input_data + b * input_stride + i * input_dim;
memcpy(output_base + output_offset,
input_base + input_offset,
......
......@@ -80,18 +80,22 @@ class SumGroupOp<DeviceType::CPU, T> : public Operation {
MACE_CHECK(cur_index <= input_dim)
<< "size value over-ranged:" << cur_index << "<=" << input_dim;
}
for (index_t i = 0; i < bh; ++i) {
for (index_t j = 0; j < output_dim; ++j) {
int start_col = sum_indexes[j].first;
int end_col = sum_indexes[j].second;
T sum = 0;
for (int src_col = start_col; src_col < end_col; ++src_col) {
sum += input_data[i * input_dim + src_col];
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
int start_col = sum_indexes[j].first;
int end_col = sum_indexes[j].second;
T sum = 0;
for (int src_col = start_col; src_col < end_col; ++src_col) {
sum += input_data[i * input_dim + src_col];
}
output_data[i * output_dim + j] = sum;
}
output_data[i * output_dim + j] = sum;
}
}
}, 0, bh, 1, 0, output_dim, 1);
return MaceStatus::MACE_SUCCESS;
}
......
......@@ -35,7 +35,11 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
public:
explicit TargetRMSNormOp(OpConstructContext *context)
: Operation(context),
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)) {}
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)),
add_log_stddev_(
static_cast<bool>(
Operation::GetOptionalArg<int>("add_log_stddev", 0))),
block_dim_(Operation::GetOptionalArg<int>("block_dim", 0)) {}
// Calculate the square sum of an array
float SquareSum(const float *data, const index_t data_len) {
......@@ -67,6 +71,24 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
return result;
}
void NormalizePerRow(const float *data,
const index_t data_len,
float d_scale,
bool add_log_stddev,
float *out_data) {
float scale = SquareSum(data, data_len);
scale = scale / d_scale;
scale = scale < 1.0e-6f ? 1.0e-6f : scale;
scale = static_cast<float>(1.0 / std::sqrt(scale));
for (index_t j = 0; j < data_len; ++j) {
out_data[j] = data[j] * scale;
}
if (add_log_stddev) {
out_data[data_len] = std::log(target_rms_) - std::log(scale);
}
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
......@@ -75,15 +97,18 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
const index_t dim_size = input->dim_size();
MACE_CHECK(dim_size >= 1,
"TargetRMSNorm's input dim size should be >= 1.");
const index_t dim = input_shape[dim_size -1];
MACE_CHECK(dim > 0 && target_rms_ > 0,
const index_t input_dim = input_shape[dim_size -1];
MACE_CHECK(input_dim > 0 && target_rms_ > 0,
"Both input dim and target rms should be greater than zero.");
const index_t bh =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
const float d_scale = dim * target_rms_ * target_rms_;
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
if (block_dim_ == 0) block_dim_ = static_cast<int>(input_dim);
const index_t output_dim = add_log_stddev_ ?
input_dim + (input_dim / block_dim_) : input_dim;
std::vector<index_t> output_shape = input->shape();
output_shape[dim_size - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
......@@ -91,19 +116,37 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t i = 0; i < bh; ++i) {
float scale = SquareSum(input_data + i * dim, dim);
scale = static_cast<float>(1.0 / std::sqrt(scale / d_scale));
for (index_t j = 0; j < dim; ++j) {
output_data[i * dim + j] = input_data[i * dim + j] * scale;
}
index_t num_rows = bh;
index_t output_block_dim = add_log_stddev_ ? block_dim_ + 1 : block_dim_;
if (block_dim_ != input_dim) {
index_t num_blocks = input_dim / block_dim_;
num_rows *= num_blocks;
}
const float d_scale = block_dim_ * target_rms_ * target_rms_;
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute1D([=](index_t start0, index_t end0, index_t step0) {
for (index_t i = start0; i < end0; i += step0) {
const float *input_ptr = input_data + i * block_dim_;
float *out_ptr = output_data + i * output_block_dim;
NormalizePerRow(input_ptr,
block_dim_,
d_scale,
add_log_stddev_,
out_ptr);
}
}, 0, num_rows, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
float target_rms_;
bool add_log_stddev_;
int block_dim_;
};
void RegisterTargetRMSNorm(OpRegistryBase *op_registry) {
......
......@@ -112,17 +112,20 @@ MaceSupportedOps = [
'Conv2D',
'Crop',
'Deconv2D',
'Delay',
'DepthToSpace',
'DepthwiseConv2d',
'DepthwiseDeconv2d',
'Dequantize',
'Eltwise',
'ExpandDims',
'ExtractPooling',
'Fill',
'FullyConnected',
'Gather',
'Identity',
'InferConv2dShape',
'KaldiBatchNorm',
'LocalResponseNorm',
'LSTMCell',
'LstmNonlinear',
......
......@@ -41,6 +41,15 @@ from numbers import Number
IS_PYTHON3 = sys.version_info > (3,)
class AttributeType(Enum):
INT = 100
FLOAT = 101
INTS = 102
FLOATS = 103
BOOL = 104
OnnxSupportedOps = [
'Abs',
# 'Acos',
......@@ -57,6 +66,7 @@ OnnxSupportedOps = [
# 'Atanh',
'AveragePool',
'BatchNormalization',
'BatchNorm',
'Cast',
# 'Ceil',
# 'Clip',
......@@ -72,11 +82,12 @@ OnnxSupportedOps = [
'DimRange',
'Div',
'Dropout',
'DynamicLstmCell',
'DynamicLSTM',
'Elu',
'Equal',
# 'Exp',
# 'Expand',
'ExtractPooling',
# 'EyeLike',
# 'Flatten',
# 'Floor',
......@@ -91,7 +102,7 @@ OnnxSupportedOps = [
# 'Hardmax',
'Identity',
# 'If',
# 'IfDefined',
'IfDefined',
'ImageScaler',
# 'InstanceNormalization',
# 'LRN',
......@@ -318,6 +329,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.ArgMin.name: self.convert_argmax,
OnnxOpType.AveragePool.name: self.convert_pooling,
OnnxOpType.BatchNormalization.name: self.convert_fused_batchnorm,
OnnxOpType.BatchNorm.name: self.convert_fused_batchnorm,
OnnxOpType.Cast.name: self.convert_cast,
OnnxOpType.Concat.name: self.convert_concat,
OnnxOpType.Conv.name: self.convert_conv2d,
......@@ -327,16 +339,18 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.DimRange.name: self.convert_dim_range,
OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
OnnxOpType.Gather.name: self.convert_gather,
OnnxOpType.Gemm.name: self.convert_gemm,
OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
OnnxOpType.Identity.name: self.convert_identity,
OnnxOpType.IfDefined.name: self.convert_ifdefined,
OnnxOpType.ImageScaler.name: self.convert_imagescaler,
OnnxOpType.LeakyRelu.name: self.convert_activation,
OnnxOpType.LogSoftmax.name: self.convert_softmax,
OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
OnnxOpType.DynamicLstmCell.name: self.convert_dynamic_lstm,
OnnxOpType.DynamicLSTM.name: self.convert_dynamic_lstm,
OnnxOpType.Max.name: self.convert_eltwise,
OnnxOpType.MaxPool.name: self.convert_pooling,
OnnxOpType.MatMul.name: self.convert_matmul,
......@@ -378,6 +392,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ir_version = onnx_model.ir_version
opset_imp = onnx_model.opset_import
self._isKaldi = False
polish_available = True
print("onnx model IR version: ", ir_version)
for imp in opset_imp:
......@@ -387,6 +403,7 @@ class OnnxConverter(base_converter.ConverterInterface):
if 'kaldi2onnx' in domain:
polish_available = False
self._data_format = DataFormat.DF_NONE
self._isKaldi = True
if polish_available:
onnx_model = onnx.utils.polish_model(onnx_model)
......@@ -643,10 +660,10 @@ class OnnxConverter(base_converter.ConverterInterface):
mace_check(axis_value == 1 or axis_value == -3,
"only support concat at channel dimension")
elif node.op_type == OnnxOpType.Append.name:
axis_value = 1
axis_value = -1
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value
axis_arg.i = axis_value
def convert_conv2d(self, node):
op = self.convert_general_op(node)
......@@ -772,15 +789,15 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.DynamicLSTM.name
if 'delay_a' in node.attrs:
prev_out_delay = node.attrs['delay_a']
if 'prev_out_delay' in node.attrs:
prev_out_delay = node.attrs['prev_out_delay']
mace_check(prev_out_delay < 0,
"dynamic's prev_out_delay should <= 0.")
prev_out_delay_arg = op.arg.add()
prev_out_delay_arg.name = 'prev_out_delay'
prev_out_delay_arg.i = prev_out_delay
if 'delay_b' in node.attrs:
prev_cell_delay = node.attrs['delay_b']
if 'prev_cell_delay' in node.attrs:
prev_cell_delay = node.attrs['prev_cell_delay']
mace_check(prev_cell_delay < 0,
"dynamic's prev_cell_delay should < 0.")
prev_cell_delay_arg = op.arg.add()
......@@ -788,20 +805,20 @@ class OnnxConverter(base_converter.ConverterInterface):
prev_cell_delay_arg.i = prev_cell_delay
if 'prev_out_offset' in node.attrs:
prev_out_offset = node.attrs['prev_out_offset']
mace_check(pre_out_offset >= 0,
mace_check(prev_out_offset >= 0,
"dynamic's prev_out_offset should >= 0.")
prev_out_offset_arg = op.arg.add()
prev_out_offset_arg.name = 'prev_out_offset'
prev_out_offset_arg.i = prev_out_offset
if 'prev_a_dim' in node.attrs:
prev_out_dim = node.attrs['prev_a_dim']
if 'prev_out_dim' in node.attrs:
prev_out_dim = node.attrs['prev_out_dim']
mace_check(prev_out_dim > 0,
"dynamic's prev_out_dim should > 0.")
prev_out_dim_arg = op.arg.add()
prev_out_dim_arg.name = 'prev_out_dim'
prev_out_dim_arg.i = prev_out_dim
if 'prev_b_dim' in node.attrs:
prev_cell_dim = node.attrs['prev_b_dim']
if 'prev_cell_dim' in node.attrs:
prev_cell_dim = node.attrs['prev_cell_dim']
mace_check(prev_cell_dim > 0,
"dynamic's prev_cell_dim should > 0.")
prev_cell_dim_arg = op.arg.add()
......@@ -844,11 +861,154 @@ class OnnxConverter(base_converter.ConverterInterface):
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = value
@staticmethod
def copy_node_attr(op, node, attr_name, dtype=AttributeType.INT,
default=None):
if attr_name in node.attrs or default is not None:
if attr_name in node.attrs:
value = node.attrs[attr_name]
else:
value = default
new_arg = op.arg.add()
new_arg.name = attr_name
if dtype == AttributeType.INT:
new_arg.i = int(value)
elif dtype == AttributeType.FLOAT:
new_arg.f = float(value)
elif dtype == AttributeType.INTS:
new_arg.ints.extend(value)
elif dtype == AttributeType.FLOATS:
new_arg.floats.extend(value)
return value
else:
return default
def convert_extract_pooling(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.ExtractPooling.name
self.copy_node_attr(op, node, 'include_variance', AttributeType.INT)
self.copy_node_attr(op, node, 'num_log_count', AttributeType.INT)
self.copy_node_attr(op, node, 'variance_floor', AttributeType.FLOAT)
self.copy_node_attr(op, node, 'input_time_range', AttributeType.INTS)
self.copy_node_attr(op, node, 'input_indexes', AttributeType.INTS)
if 'output_time_range' in node.attrs:
output_time_range = node.attrs['output_time_range']
mace_check(len(output_time_range) == 2,
"output time range should have two values.")
out_start_index = output_time_range[0]
out_end_index = output_time_range[1]
else:
mace_check('start_index' in node.attrs and
'end_index' in node.attrs,
"'start_index' and 'end_index'"
" are required in ExtractPooling.")
out_start_index = node.attrs['start_index']
out_end_index = node.attrs['end_index']
output_time_range = [out_start_index, out_end_index]
output_time_range_arg = op.arg.add()
output_time_range_arg.name = 'output_time_range'
output_time_range_arg.ints.extend(output_time_range)
mace_check('modulus' in node.attrs,
"'modulus' is required in ExtractPooling.")
mace_check('output_indexes' in node.attrs,
"'output_indexes' is required in ExtractPooling.")
mace_check('counts' in node.attrs,
"'counts' is required in ExtractPooling.")
mace_check('forward_indexes' in node.attrs,
"'forward_indexes' is required in ExtractPooling.")
modulus = node.attrs['modulus']
output_indexes = node.attrs['output_indexes']
counts = node.attrs['counts']
forward_indexes = node.attrs['forward_indexes']
mace_check(len(counts) == len(output_indexes) and
len(forward_indexes) == 2 * len(output_indexes),
"output_indexes length:%s "
"counts length:%s "
"forward_indexes length:%s"
% (len(output_indexes), len(counts), len(forward_indexes)))
new_output_indexes = []
new_forward_indexes = []
new_counts = []
for i in range(len(output_indexes)):
if output_indexes[i] + modulus > out_start_index and\
output_indexes[i] <= out_end_index:
new_output_indexes.append(output_indexes[i])
new_counts.append(counts[i])
new_forward_indexes.append(forward_indexes[2 * i])
new_forward_indexes.append(forward_indexes[2 * i + 1])
modulus_arg = op.arg.add()
modulus_arg.name = 'modulus'
modulus_arg.i = modulus
counts_arg = op.arg.add()
counts_arg.name = 'counts'
counts_arg.floats.extend(new_counts)
forward_indexes_arg = op.arg.add()
forward_indexes_arg.name = 'forward_indexes'
forward_indexes_arg.ints.extend(new_forward_indexes)
output_indexes_arg = op.arg.add()
output_indexes_arg.name = 'output_indexes'
output_indexes_arg.ints.extend(new_output_indexes)
def convert_flatten(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Reshape.name
def convert_kaldi_batchnorm(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.KaldiBatchNorm.name
dim = self.copy_node_attr(op, node,
'dim', AttributeType.INT, -1)
block_dim = self.copy_node_attr(op, node,
'block_dim',
AttributeType.INT, -1)
epsilon = self.copy_node_attr(op, node,
'epsilon',
AttributeType.FLOAT, 1e-3)
target_rms = self.copy_node_attr(op, node,
'target_rms',
AttributeType.FLOAT, 1.0)
test_mode = self.copy_node_attr(op, node,
'test_mode',
AttributeType.INT, 0)
mace_check(block_dim > 0 and
dim % block_dim == 0 and
epsilon > 0 and
target_rms > 0, "attributes invalid.")
if test_mode > 0:
mace_check(len(node.inputs) == 3,
"Kaldi's BatchNorm should have 3 inputs.")
stats_mean = np.array(self._consts[node.inputs[1]].float_data)
stats_var = np.array(self._consts[node.inputs[2]].float_data)
offset_value = -1.0 * stats_mean
scale_value = stats_var
scale_value[scale_value < 0] = 0
scale_value = np.power(scale_value + epsilon, -0.5) * target_rms
offset_value = offset_value * scale_value
scale_name = node.name + '_scale'
offset_name = node.name + '_offset'
self.add_tensor(scale_name, scale_value.shape,
mace_pb2.DT_FLOAT, scale_value)
self.add_tensor(offset_name, offset_value.shape,
mace_pb2.DT_FLOAT, offset_value)
del op.input[1:]
op.input.extend([scale_name, offset_name])
del op.output[1:]
del op.output_shape[1:]
def convert_fused_batchnorm(self, node):
if self._isKaldi:
self.convert_kaldi_batchnorm(node)
return
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
......@@ -946,6 +1106,21 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.Identity.name
def convert_ifdefined(self, node):
op = self.convert_general_op(node)
if 'offset' in node.attrs:
offset = node.attrs['offset']
else:
offset = 0
mace_check(offset <= 0, "IfDefined's offset should be <= 0.")
if offset == 0:
op.type = MaceOp.Identity.name
else:
op.type = MaceOp.Delay.name
offset_arg = op.arg.add()
offset_arg.name = 'offset'
offset_arg.i = node.attrs['offset']
def convert_imagescaler(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
......@@ -1100,7 +1275,6 @@ class OnnxConverter(base_converter.ConverterInterface):
def convert_softmax(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Softmax.name
# TODO: add logsoftmax in softmax op
if node.op_type == OnnxOpType.LogSoftmax.name:
use_log_arg = op.arg.add()
use_log_arg.name = 'use_log'
......@@ -1164,11 +1338,11 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.TargetRMSNorm.name
if 'target_rms' in node.attrs:
value = node.attrs['target_rms']
target_rms_arg = op.arg.add()
target_rms_arg.name = 'target_rms'
target_rms_arg.f = value
self.copy_node_attr(op, node, 'target_rms', AttributeType.FLOAT)
self.copy_node_attr(op, node, 'add_log_stddev', AttributeType.INT,
default=0)
self.copy_node_attr(op, node, 'block_dim', AttributeType.INT,
default=0)
def convert_transpose(self, node):
op = self.convert_general_op(node)
......
......@@ -1130,7 +1130,7 @@ class Transformer(base_converter.ConverterInterface):
rhs = op.input[1]
if rhs in self._consts and len(self._consts[rhs].dims) == 2:
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str) # noqa
six.print_("Transpose matmul weight %s" % rhs)
# six.print_("Transpose matmul weight %s" % rhs)
if arg is None:
arg = op.arg.add()
arg.name = MaceKeyword.mace_transpose_b_str
......@@ -1143,7 +1143,8 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
arg.i = 1
six.print_('transpose matmul weight')
six.print_('Transpose matmul weight to shape:',
filter.dims)
def transpose_filters(self):
net = self._model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册