Commit a7afda0e authored by liutuo

add kaldi op IfDefined

add extract pooling and kaldi batchnorm

add time index info in extract pooling

add benchmark & test for kaldi components

add batch support for kaldi

change IfDefined to Delay
Parent 77df54f2
...
@@ -86,7 +86,8 @@ class ConcatOp<DeviceType::CPU, T> : public ConcatOpBase {
         continue;
       }
       MACE_CHECK(input->dim(j) == input0->dim(j),
-                 "Dimensions of inputs should equal except axis.");
+                 "Dimensions of inputs should equal except axis: ",
+                 input->dim(j), "!=", input0->dim(j));
     }
     outer_sizes[i] = input->size() / inner_size;
     output_shape[axis] += input->dim(axis);
...
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for the IfDefined descriptor in Kaldi.
// It defines a time offset.
// If the time index <= offset, zeros are used as output.
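// Illustrative example (not part of the original source): with offset = -2
// and a chunk of five input frames [x0, x1, x2, x3, x4], the output is
// [0, 0, x0, x1, x2]: each frame is delayed by |offset| steps and the first
// |offset| output frames stay zero.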
#include <cstring>     // memcpy
#include <functional>
#include <memory>
#include <numeric>     // std::accumulate
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class DelayOp;
template <typename T>
class DelayOp<DeviceType::CPU, T> : public Operation {
public:
explicit DelayOp(OpConstructContext *context)
: Operation(context),
offset_(Operation::GetOptionalArg<int>("offset", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(offset_ < 0, "offset param should be negative.");
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should be >= 2.");
const std::vector<index_t> &input_shape = input->shape();
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
const index_t chunk = input_shape[rank - 2];
const index_t dim = input_shape[rank - 1];
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
output->Clear();
if (chunk <= -offset_)
return MaceStatus::MACE_SUCCESS;
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
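// Copy frame j of each batch to output frame j - offset_ (offset_ < 0, so
// the destination index is larger); output frames [0, -offset_) keep the
// zeros written by Clear() above.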
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
memcpy(output_data + (i * chunk + j - offset_) * dim,
input_data + (i * chunk + j) * dim,
dim * sizeof(T));
}
}
}, 0, batch, 1, 0, chunk + offset_, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
int offset_;
};
void RegisterDelay(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Delay", DelayOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void Delay(int iters,
int batch,
int chunk,
int dim,
int offset) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
OpDefBuilder("Delay", "DelayTest")
.Input("Input")
.Output("Output")
.AddIntArg("offset", -offset)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.Run();
}
net.Sync();
}
#define MACE_BM_DELAY_MACRO(N, C, D, OFFSET, TYPE, DEVICE) \
static void MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Delay<DEVICE, TYPE>(iters, N, C, D, OFFSET); \
} \
MACE_BENCHMARK(MACE_BM_DELAY_##N##_##C##_##D##_##OFFSET##_##TYPE\
##_##DEVICE)
#define MACE_BM_DELAY(N, C, D, OFFSET) \
MACE_BM_DELAY_MACRO(N, C, D, OFFSET, float, CPU);
MACE_BM_DELAY(8, 40, 512, 2);
MACE_BM_DELAY(16, 80, 100, 3);
MACE_BM_DELAY(32, 60, 200, 5);
} // namespace test
} // namespace ops
} // namespace mace
...
@@ -214,7 +214,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
     Tensor *output = this->Output(OUTPUT);
     std::vector<index_t> output_shape = input->shape();
-    output_shape[1] = output_dim;
+    output_shape[input_rank - 1] = output_dim;
     MACE_RETURN_IF_ERROR(output->Resize(output_shape));
...
@@ -235,8 +235,10 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
       affine_b_in.Clear();
       affine_b_out.Clear();
       for (int i = 0; i < chunk; ++i) {
+        const float *input_ptr = input_data + (b * chunk + i) * input_dim;
+        float *output_ptr = output_data + (b * chunk + i) * output_dim;
         // Append
-        memcpy(affine_a_in_data, input_data, input_dim * sizeof(float));
+        memcpy(affine_a_in_data, input_ptr, input_dim * sizeof(float));
         if (prev_out_idx >= 0) {
           memcpy(affine_a_in_data + input_dim,
                  prev_out_data + prev_out_idx % out_buf_chunk * prev_out_dim_,
...
@@ -283,7 +285,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
                      false,
                      &affine_b_out);
         // Output
-        memcpy(output_data,
+        memcpy(output_ptr,
                affine_b_out_data,
                output_dim * sizeof(float));
         // Update
...
@@ -292,8 +294,6 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
                        prev_out_dim_,
                        scale_,
                        curr_out_ptr);
-        input_data += input_dim;
-        output_data += output_dim;
         prev_out_idx++;
         prev_cell_idx++;
       }
...
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op fuses the StatisticsExtraction, StatisticsPooling and
// Round components in Kaldi.
// It is used to extract the moving-average mean and standard-deviation
// statistics of the input data.
// 'input_indexes' indicates which frames will be used to extract statistics.
// 'output_indexes' indicates which output frames will be used to
// save the statistics results.
// 'modulus' is used to extend the results to all frames.
// 'start_index' and 'end_index' indicate the time indexes of output frames.
// 'forward_indexes' and 'counts' come from Kaldi's precomputed indexes.
// Reference:
// http://kaldi-asr.org/doc/nnet-general-component_8h_source.html#l00158
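// Illustrative example (not part of the original source): with
// input_indexes = {0, 3, 6, 9}, forward_indexes = {0, 4} and counts = {4},
// the pooled row for output_indexes[0] is the mean (plus, optionally, the
// stddev and log-count) over input frames {0, 3, 6, 9}; that row is then
// replicated to the next 'modulus' output frames.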
#include <cmath>       // std::log, std::sqrt
#include <cstring>     // memcpy
#include <functional>
#include <memory>
#include <numeric>     // std::accumulate
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ExtractPoolingOp;
template <typename T>
class ExtractPoolingOp<DeviceType::CPU, T> : public Operation {
public:
explicit ExtractPoolingOp(OpConstructContext *context)
: Operation(context),
modulus_(Operation::GetOptionalArg<int>("modulus", 1)),
include_variance_(
static_cast<bool>(
Operation::GetOptionalArg<int>("include_variance", 0))),
num_log_count_(
Operation::GetOptionalArg<int>("num_log_count", 0)),
variance_floor_(
Operation::GetOptionalArg<float>("variance_floor", 1.0e-10)),
input_indexes_(Operation::GetRepeatedArgs<int>("input_indexes")),
output_indexes_(Operation::GetRepeatedArgs<int>("output_indexes")),
forward_indexes_(Operation::GetRepeatedArgs<int>("forward_indexes")),
counts_(Operation::GetRepeatedArgs<float>("counts")),
input_time_range_(Operation::GetRepeatedArgs<int>("input_time_range")),
output_time_range_(
Operation::GetRepeatedArgs<int>("output_time_range")) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const std::vector<index_t> &input_shape = input->shape();
const index_t dim_size = input_shape.size();
MACE_CHECK(dim_size >= 2,
"ExtractPooling only supports input dim size >= 2");
MACE_CHECK(modulus_ >= 1,
"ExtractPooling's pooling size should be greater than zero.");
MACE_CHECK(input_time_range_.size() == 2 && output_time_range_.size() == 2
&& counts_.size() * 2 == forward_indexes_.size()
&& counts_.size() == output_indexes_.size());
int in_start_index = input_time_range_[0];
int out_start_index = output_time_range_[0];
int out_end_index = output_time_range_[1];
MACE_CHECK(out_end_index >= out_start_index
&& input_time_range_[1] >= input_time_range_[0],
"end index should not be less than start index.");
const index_t output_chunk = out_end_index - out_start_index + 1;
const index_t input_dim = input_shape[dim_size - 1];
const index_t chunk = input_shape[dim_size - 2];
MACE_CHECK(chunk == input_time_range_[1] - input_time_range_[0] + 1,
"input chunk should be equal to end - start + 1.");
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies<index_t>());
index_t output_dim = include_variance_ ? 2 * input_dim : input_dim;
output_dim += num_log_count_;
std::vector<index_t> output_shape(input_shape);
output_shape[dim_size - 1] = output_dim;
output_shape[dim_size - 2] = output_chunk;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
const index_t num_input_indexes = input_indexes_.size();
const index_t num_output_indexes = output_indexes_.size();
MACE_CHECK(num_input_indexes > 0 && num_output_indexes > 0,
"ExtractPooling's input_indexes or output_indexes is empty.");
const index_t extract_out_size = PadAlignSize(output_dim * sizeof(float));
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(extract_out_size);
Tensor extract_out(scratch->Scratch(extract_out_size), DT_FLOAT);
extract_out.Reshape({1, output_dim});
extract_out.Clear();
float *extract_out_data = extract_out.mutable_data<float>();
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
for (index_t b = 0; b < batch; ++b) {
for (index_t i = 0; i < num_output_indexes; ++i) {
int start = forward_indexes_[2 * i];
int end = forward_indexes_[2 * i + 1];
float count = counts_[i];
float mean_scale = 1.f / count;
float log_count = std::log(count);
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t n = start0; n < end0; n += step0) {
extract_out_data[n] = log_count;
}
}, 0, num_log_count_, 1);
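// If variance is requested, accumulate E[x] and E[x*x] over the selected
// frames and derive stddev = sqrt(max(E[x*x] - mean^2, variance_floor_));
// otherwise only the mean is computed.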
if (include_variance_) {
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t d = start0; d < end0; d += step0) {
float mean = 0.f;
float variance = 0.f;
for (int t = start; t < end; ++t) {
index_t input_index =
(b * chunk + input_indexes_[t] - in_start_index)
* input_dim;
float x = input_data[input_index + d];
mean += x;
variance += x * x;
}
mean *= mean_scale;
variance *= mean_scale;
extract_out_data[d + num_log_count_] = mean;
variance = variance - mean * mean;
extract_out_data[d + input_dim + num_log_count_] =
variance < variance_floor_ ?
std::sqrt(variance_floor_) :
std::sqrt(variance);
}
}, 0, input_dim, 1);
} else {
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t d = start0; d < end0; d += step0) {
float mean = 0.f;
for (int t = start; t < end; ++t) {
index_t input_index =
(b * chunk + input_indexes_[t] - in_start_index)
* input_dim;
mean += input_data[input_index + d];
}
extract_out_data[d + num_log_count_] = mean * mean_scale;
}
}, 0, input_dim, 1);
}
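// Replicate the pooled row to output frames
// [output_indexes_[i], output_indexes_[i] + modulus_), clipped to the
// valid output time range.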
int output_start = output_indexes_[i] < out_start_index ?
out_start_index : output_indexes_[i];
int output_end = output_indexes_[i] + modulus_;
output_end = output_end > out_end_index ?
out_end_index + 1 :
output_end;
thread_pool.Compute1D([=](index_t start0,
index_t end0,
index_t step0) {
for (index_t idx = start0; idx < end0; idx += step0) {
memcpy(output_data + (b * output_chunk + idx - out_start_index)
* output_dim,
extract_out_data, output_dim * sizeof(float));
}
}, output_start, output_end, 1);
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int modulus_;
bool include_variance_;
int num_log_count_;
float variance_floor_;
std::vector<int> input_indexes_;
std::vector<int> output_indexes_;
std::vector<int> forward_indexes_;
std::vector<float> counts_;
std::vector<int> input_time_range_;
std::vector<int> output_time_range_;
};
void RegisterExtractPooling(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "ExtractPooling", ExtractPoolingOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void ExtractPooling(int iters,
int batch,
int chunk,
int dim,
int input_period,
int modulus) {
mace::testing::StopTiming();
OpsTestNet net;
size_t num_input_indexes = static_cast<size_t>(chunk / input_period);
std::vector<int> input_indexes(num_input_indexes, 0);
for (size_t i = 0; i < num_input_indexes; ++i) {
input_indexes[i] = static_cast<int>(i * input_period);
}
size_t num_output_indexes = static_cast<size_t>(chunk / modulus);
std::vector<int> output_indexes(num_output_indexes, 0);
std::vector<int> forward_indexes(num_output_indexes * 2, 0);
std::vector<float> counts(num_output_indexes, 0.f);
for (size_t i = 0; i < num_output_indexes; ++i) {
output_indexes[i] = static_cast<int>(i * modulus);
forward_indexes[2 * i] = 0;
forward_indexes[2 * i + 1] = static_cast<int>(num_input_indexes - 1);
counts[i] = static_cast<float>(num_input_indexes);
}
// Add input data
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
.Input("Input")
.AddIntArg("modulus", modulus)
.AddIntArg("include_variance", 1)
.AddIntArg("num_log_counts", 1)
.AddIntsArg("input_indexes", input_indexes)
.AddIntsArg("output_indexes", output_indexes)
.AddIntsArg("forward_indexes", forward_indexes)
.AddFloatsArg("counts", counts)
.AddIntsArg("input_time_range", {0, chunk - 1})
.AddIntsArg("output_time_range", {0, chunk - 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.Run();
}
net.Sync();
}
#define MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, TYPE, DEVICE) \
static void MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE##\
_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ExtractPooling<DEVICE, TYPE>(iters, N, C, D, INP, M); \
} \
MACE_BENCHMARK(MACE_BM_EXTRACTPOOLING_##N##_##C##_##D##_##INP##_##M##_##TYPE\
##_##DEVICE)
#define MACE_BM_EXTRACTPOOLING(N, C, D, INP, M) \
MACE_BM_EXTRACTPOOLING_MACRO(N, C, D, INP, M, float, CPU);
MACE_BM_EXTRACTPOOLING(8, 40, 512, 2, 4);
MACE_BM_EXTRACTPOOLING(16, 80, 100, 3, 9);
MACE_BM_EXTRACTPOOLING(32, 60, 200, 6, 18);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ExtractPoolingTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestExtractPooling(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int modulus,
const int num_log_count,
const int include_variance,
const std::vector<int> &input_time_range,
const std::vector<int> &input_indexes,
const std::vector<int> &forward_indexes,
const std::vector<float> &counts,
const std::vector<int> &output_indexes,
const std::vector<int> &output_time_range,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
// Construct graph
OpsTestNet net;
net.AddInputFromArray<D, float>("Input", input_shape, input_value);
OpDefBuilder("ExtractPooling", "ExtractPoolingTest")
.Input("Input")
.AddIntArg("modulus", modulus)
.AddIntArg("include_variance", include_variance)
.AddIntArg("num_log_count", num_log_count)
.AddIntsArg("input_indexes", input_indexes)
.AddIntsArg("output_indexes", output_indexes)
.AddIntsArg("forward_indexes", forward_indexes)
.AddFloatsArg("counts", counts)
.AddIntsArg("input_time_range", input_time_range)
.AddIntsArg("output_time_range", output_time_range)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(ExtractPoolingTest, SimpleCPU) {
TestExtractPooling<DeviceType::CPU, float>(
{3, 20, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60},
9, 0, 0,
{-2, 17},
{0, 3, 6, 9, 12, 15},
{0, 6, 2, 6},
{6, 4},
{0, 9},
{0, 17},
{3, 18, 3},
{29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5,
38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5});
}
TEST_F(ExtractPoolingTest, SimpleCPUWithVariance) {
TestExtractPooling<DeviceType::CPU, float>(
{3, 20, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60},
9, 1, 1,
{-2, 17},
{0, 3, 6, 9, 12, 15},
{0, 6, 2, 6},
{6, 4},
{0, 9},
{0, 17},
{3, 18, 7},
{1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623,
1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for Kaldi's BatchNormComponent
// More details about forward computation are here:
// http://kaldi-asr.org/doc/nnet-normalize-component_8cc_source.html#l00320
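// Sketch of the train-mode computation implemented below, per block
// dimension:
//   mean = E[x]
//   var  = (E[x^2] - mean^2 + epsilon) / target_rms^2
//   y    = (x - mean) * var^(-1/2)
// In test mode the precomputed 'scale' and 'offset' inputs are applied
// instead: y = x * scale + offset.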
#include <cmath>       // std::pow
#include <functional>
#include <memory>
#include <numeric>     // std::accumulate
#include <string>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class KaldiBatchNormOp;
template <>
class KaldiBatchNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit KaldiBatchNormOp(OpConstructContext *context)
: Operation(context),
epsilon_(Operation::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-3))),
target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0f)),
block_dim_(Operation::GetOptionalArg<int>("block_dim", -1)),
test_mode_(static_cast<bool>(
Operation::GetOptionalArg<int>("test_mode", 0))) {}
void CalculateMeanVar(const float *input_data,
index_t length,
index_t stride,
float mean_scale,
float var_scale,
float *mean_data,
float *var_data) {
float mean_value = 0.f;
float var_value = 0.f;
for (index_t i = 0; i < length; ++i) {
float x = input_data[i * stride];
mean_value += x;
var_value += x * x;
}
mean_value = mean_value * mean_scale;
var_value = var_value * mean_scale;
float mean_sqr = mean_value * mean_value;
var_value = (var_value > mean_sqr) ?
var_scale * (var_value - mean_sqr + epsilon_) :
var_scale * epsilon_;
var_data[0] = std::pow(var_value, -0.5f);
mean_data[0] = mean_value;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
const index_t rank = input->dim_size();
const index_t dim = input_shape[rank - 1];
if (block_dim_ == -1) block_dim_ = static_cast<int>(dim);
MACE_CHECK(target_rms_ > 0 && dim > 0 && dim % block_dim_ == 0);
MACE_CHECK(rank >= 2, "KaldiBatchNorm's input's rank must >= 2.");
index_t num_rows =
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
const index_t blocks = dim / block_dim_;
if (blocks > 1) num_rows *= blocks;
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
if (test_mode_) {
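// Test mode: apply the precomputed per-block 'scale' and 'offset' inputs
// directly.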
MACE_CHECK(this->InputSize() == 3, "KaldiBatchNorm should have 3 inputs");
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
scale->dim_size());
MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
offset->dim_size());
MACE_CHECK(scale->size() == offset->size()
&& scale->size() == block_dim_);
Tensor::MappingGuard scale_guard(scale);
Tensor::MappingGuard offset_guard(offset);
const float *scale_data = scale->data<float>();
const float *offset_data = offset->data<float>();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
index_t idx = i * block_dim_ + j;
output_data[idx] = input_data[idx] * scale_data[j] + offset_data[j];
}
}
}, 0, num_rows, 1, 0, block_dim_, 1);
} else {
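// Train mode: compute per-block mean and inverse stddev over all rows on
// the fly, then normalize.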
const index_t buf_size =
PadAlignSize(block_dim_ * sizeof(float));
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(2 * buf_size);
Tensor mean(scratch->Scratch(buf_size), DT_FLOAT);
mean.Reshape({block_dim_});
float *mean_data = mean.mutable_data<float>();
Tensor var(scratch->Scratch(buf_size), DT_FLOAT);
var.Reshape({block_dim_});
float *var_data = var.mutable_data<float>();
float var_scale = 1.0f / (target_rms_ * target_rms_);
float mean_scale = 1.0f / num_rows;
thread_pool.Compute1D([=](index_t start0, index_t end0, index_t step0) {
for (index_t i = start0; i < end0; i += step0) {
CalculateMeanVar(input_data + i,
num_rows,
block_dim_,
mean_scale,
var_scale,
mean_data + i,
var_data + i);
}
}, 0, block_dim_, 1);
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
index_t idx = i * block_dim_ + j;
output_data[idx] = (input_data[idx] - mean_data[j]) * var_data[j];
}
}
}, 0, num_rows, 1, 0, block_dim_, 1);
}
return MaceStatus::MACE_SUCCESS;
}
private:
const float epsilon_;
const float target_rms_;
int block_dim_;
const bool test_mode_;
protected:
MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
void RegisterKaldiBatchNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "KaldiBatchNorm", KaldiBatchNormOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void KaldiBatchNorm(
int iters, int batch, int chunk, int dim, int block_dim) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, chunk, dim});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Scale", {block_dim}, true);
net.AddRandomInput<D, T>("Offset", {block_dim}, true);
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormBM")
.Input("Input")
.Input("Scale")
.Input("Offset")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 1)
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, TYPE, DEVICE) \
static void MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * D; \
mace::testing::MacsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
KaldiBatchNorm<DEVICE, TYPE>(iters, N, C, D, BD); \
} \
MACE_BENCHMARK(MACE_BM_KALDI_BATCH_NORM_##N##_##C##_##D##_##BD##_##TYPE\
##_##DEVICE)
#define MACE_BM_KALDI_BATCH_NORM(N, C, D, BD) \
MACE_BM_KALDI_BATCH_NORM_MACRO(N, C, D, BD, float, CPU);
MACE_BM_KALDI_BATCH_NORM(1, 1, 512, 512);
MACE_BM_KALDI_BATCH_NORM(1, 3, 128, 128);
MACE_BM_KALDI_BATCH_NORM(1, 3, 512, 128);
MACE_BM_KALDI_BATCH_NORM(1, 32, 112, 112);
MACE_BM_KALDI_BATCH_NORM(1, 64, 256, 256);
MACE_BM_KALDI_BATCH_NORM(1, 64, 512, 256);
MACE_BM_KALDI_BATCH_NORM(1, 128, 56, 56);
MACE_BM_KALDI_BATCH_NORM(1, 128, 256, 256);
MACE_BM_KALDI_BATCH_NORM(1, 256, 14, 14);
MACE_BM_KALDI_BATCH_NORM(1, 512, 14, 14);
MACE_BM_KALDI_BATCH_NORM(1, 1024, 7, 7);
MACE_BM_KALDI_BATCH_NORM(32, 1, 256, 128);
MACE_BM_KALDI_BATCH_NORM(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class KaldiBatchNormOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int block_dim,
const int dim,
const std::vector<float> &scale,
const std::vector<float> &offset,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
OpsTestNet net;
int scale_dim = block_dim;
if (scale_dim == -1) scale_dim = dim;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape,
input_value);
net.AddInputFromArray<D, float>("Scale", {scale_dim}, scale, true);
net.AddInputFromArray<D, float>("Offset", {scale_dim}, offset, true);
if (D == DeviceType::CPU) {
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else if (D == DeviceType::GPU) {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
}
template <DeviceType D>
void SimpleNotTestMode(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const int block_dim,
const std::vector<index_t> &output_shape,
const std::vector<float> &output_value) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape,
input_value);
if (D == DeviceType::CPU) {
OpDefBuilder("KaldiBatchNorm", "KaldiBatchNormOpTest")
.Input("Input")
.AddIntArg("block_dim", block_dim)
.AddIntArg("test_mode", 0)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else if (D == DeviceType::GPU) {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = net.CreateTensor<float>(output_shape, output_value);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
}
} // namespace
TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUOneBlock) {
Simple<DeviceType::CPU>(
{1, 6, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15},
-1, 2,
{4.f, 6.f},
{2.f, 1.f},
{1, 6, 2},
{22, 31, 30, 43, 38, 55, 46, 67, 54, 79, 62, 91});
}
TEST_F(KaldiBatchNormOpTest, SimpleTestModeCPUTwoBlock) {
Simple<DeviceType::CPU>(
{1, 6, 4},
{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
2, 4,
{4.f, 6.f},
{2.f, 1.f},
{1, 6, 4},
{22, 31, 22, 31, 30, 43, 30, 43, 38, 55, 38, 55,
46, 67, 46, 67, 54, 79, 54, 79, 62, 91, 62, 91});
}
TEST_F(KaldiBatchNormOpTest, SimpleNotTestModeCPUTwoBlock) {
SimpleNotTestMode<DeviceType::CPU>(
{1, 6, 4},
{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9,
11, 11, 11, 11, 13, 13, 13, 13, 15, 15, 15, 15},
2,
{1, 6, 4},
{-1.46379, -1.46379, -1.46379, -1.46379,
-0.8783, -0.8783, -0.8783, -0.8783,
-0.29276, -0.29276, -0.29276, -0.29276,
0.29276, 0.29276, 0.29276, 0.29276,
0.8783, 0.8783, 0.8783, 0.8783,
1.46379, 1.46379, 1.46379, 1.46379});
}
} // namespace test
} // namespace ops
} // namespace mace
...
@@ -37,11 +37,14 @@ extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry);
 extern void RegisterDynamicLSTM(OpRegistryBase *op_registry);
 extern void RegisterEltwise(OpRegistryBase *op_registry);
 extern void RegisterExpandDims(OpRegistryBase *op_registry);
+extern void RegisterExtractPooling(OpRegistryBase *op_registry);
 extern void RegisterFill(OpRegistryBase *op_registry);
 extern void RegisterFullyConnected(OpRegistryBase *op_registry);
 extern void RegisterGather(OpRegistryBase *op_registry);
 extern void RegisterIdentity(OpRegistryBase *op_registry);
+extern void RegisterDelay(OpRegistryBase *op_registry);
 extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
+extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry);
 extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
 extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry);
 extern void RegisterMatMul(OpRegistryBase *op_registry);
...
@@ -107,11 +110,14 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterDynamicLSTM(this);
   ops::RegisterEltwise(this);
   ops::RegisterExpandDims(this);
+  ops::RegisterExtractPooling(this);
   ops::RegisterFill(this);
   ops::RegisterFullyConnected(this);
   ops::RegisterGather(this);
   ops::RegisterIdentity(this);
+  ops::RegisterDelay(this);
   ops::RegisterInferConv2dShape(this);
+  ops::RegisterKaldiBatchNorm(this);
   ops::RegisterLocalResponseNorm(this);
   ops::RegisterLSTMNonlinear(this);
   ops::RegisterMatMul(this);
...
...
@@ -42,7 +42,7 @@ class PadContextOp<DeviceType::CPU, T> : public Operation {
     index_t rank = input->dim_size();
     MACE_CHECK(rank >= 2, "input's rank should >= 2.");
-    MACE_CHECK(left_context_ > 0 && right_context_ > 0,
-               "left context and right context should be greater than zero");
+    MACE_CHECK(left_context_ >= 0 && right_context_ >= 0,
+               "left context and right context should be non-negative");
     const std::vector<index_t> &input_shape = input->shape();
     const index_t batch =
...
...
@@ -106,19 +106,6 @@ void SimpleMean12Test() {
             10, 11, 12, 13}, ReduceType::MEAN);
 }
 
-// template <DeviceType D>
-// void SimpleSum12Test() {
-//  Simple<D>({2, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
-//             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {1, 2},
-//            {2, 1, 1, 4},
-//            {60, 66, 72, 78,
-//             60, 66, 72, 78}, ReduceType::SUM);
-//}
-
 template <DeviceType D>
 void SimpleMin12Test() {
   Simple<D>({2, 2, 3, 4},
...
@@ -145,20 +132,6 @@ void SimpleMax12Test() {
             20, 21, 22, 23}, ReduceType::MAX);
 }
 
-// template <DeviceType D>
-// void SimpleSumSqr12Test() {
-//  Simple<D>({2, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
-//             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {1, 2},
-//            {2, 1, 1, 4},
-//            {880, 1006, 1144, 1294,
-//             880, 1006, 1144, 1294}, ReduceType::SUM_SQR);
-//}
-
 template <DeviceType D>
 void SimpleMean1Axis() {
   Simple<D>({2, 2, 3, 4},
...
@@ -170,102 +143,8 @@ void SimpleMean1Axis() {
             {2, 1, 3, 4},
             {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
              6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {-3},
-//            {1, 1, 3, 4},
-//            {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {2},
-//            {1, 2, 1, 4},
-//            {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {-1},
-//            {1, 2, 3, 1},
-//            {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1},
-//            {1, 1, 3, 3},
-//            {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {-2},
-//            {1, 3, 1, 3},
-//            {3, 4, 5, 12, 13, 14, 21, 22, 23}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {3},
-//            {1, 3, 3, 1},
-//            {1, 4, 7, 10, 13, 16, 19, 22, 25}, ReduceType::MEAN);
 }
 
-// template <DeviceType D>
-// void SimpleSum1Axis() {
-//  Simple<D>({2, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23,
-//             0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {1},
-//            {2, 1, 3, 4},
-//            {12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
-//             12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34});
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {2},
-//            {1, 2, 1, 4},
-//            {12, 15, 18, 21, 48, 51, 54, 57}, ReduceType::SUM);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {-1},
-//            {1, 2, 3, 1},
-//            {6, 22, 38, 54, 70, 86}, ReduceType::SUM);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1},
-//            {1, 1, 3, 3},
-//            {27, 30, 33, 36, 39, 42, 45, 48, 51}, ReduceType::SUM);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {3},
-//            {1, 3, 3, 1},
-//            {3, 12, 21, 30, 39, 48, 57, 66, 75}, ReduceType::SUM);
-//}
-
 template <DeviceType D>
 void SimpleMin1Axis() {
   Simple<D>({2, 2, 3, 4},
...
@@ -285,33 +164,6 @@ void SimpleMin1Axis() {
             {2, 1, 3, 4},
             {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, ReduceType::MIN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {2},
-//            {1, 2, 1, 4},
-//            {0, 1, 2, 3, 12, 13, 14, 15}, ReduceType::MIN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {-1},
-//            {1, 2, 3, 1},
-//            {0, 4, 8, 12, 16, 20}, ReduceType::MIN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1},
-//            {1, 1, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8}, ReduceType::MIN);
 }
 
 template <DeviceType D>
...
@@ -337,53 +189,8 @@ void SimpleMax1Axis() {
             {12, 13, 14, 15,
              16, 17, 18, 19,
              20, 21, 22, 23}, ReduceType::MAX);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {2},
-//            {1, 2, 1, 4},
-//            {8, 9, 10, 11, 20, 21, 22, 23}, ReduceType::MAX);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {-1},
-//            {1, 2, 3, 1},
-//            {3, 7, 11, 15, 19, 23}, ReduceType::MAX);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1},
-//            {1, 1, 3, 3},
-//            {18, 19, 20, 21, 22, 23, 24, 25, 26}, ReduceType::MAX);
 }
 
-// template <DeviceType D>
-// void SimpleSumSqr1Axis() {
-//  Simple<D>({2, 2, 3, 4},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
-//             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
-//             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-//            {1},
-//            {2, 1, 3, 4},
-//            {144, 170, 200, 234,
-//             272, 314, 360, 410,
-//             464, 522, 584, 650,
-//             144, 170, 200, 234,
-//             272, 314, 360, 410,
-//             464, 522, 584, 650}, ReduceType::SUM_SQR);
-//}
-
 template <DeviceType D>
 void Simple2Axis() {
   Simple<D>({1, 2, 3, 4},
...
@@ -396,16 +203,6 @@ void Simple2Axis() {
             {0, 1},
             {1, 1, 3, 4},
             {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
-//  Simple3D<D>({2, 3, 4},
-//              {0, 1, 2, 3,
-//               4, 5, 6, 7,
-//               8, 9, 10, 11,
-//               12, 13, 14, 15,
-//               16, 17, 18, 19,
-//               20, 21, 22, 23},
-//              {0, 1},
-//              {1, 1, 4},
-//              {10, 11, 12, 13}, ReduceType::MEAN);
   Simple3D<D>({2, 3, 4},
               {0, 1, 2, 3,
                4, 5, 6, 7,
...
@@ -426,37 +223,6 @@ void Simple2Axis() {
             {0, 2},
             {1, 2, 1, 4},
             {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {1, 3},
-//            {1, 1, 3, 1},
-//            {7.5, 11.5, 15.5}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1, 2},
-//            {1, 1, 1, 3},
-//            {12, 13, 14}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {0, 1},
-//            {1, 1, 3, 3},
-//            {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {2, 3},
-//            {1, 3, 1, 1},
-//            {4, 13, 22}, ReduceType::MEAN);
 }
 
 template <DeviceType D>
...
@@ -471,64 +237,6 @@ void Simple3Axis() {
             {1, 2, 3},
             {1, 1, 1, 1},
             {11.5}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {0, 2, 3},
-//            {1, 2, 1, 1},
-//            {5.5, 17.5}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {0, 1, 3},
-//            {1, 1, 3, 1},
-//            {7.5, 11.5, 15.5}, ReduceType::MEAN);
-//  Simple<D>({1, 2, 3, 4},
-//            {0, 1, 2, 3,
-//             4, 5, 6, 7,
-//             8, 9, 10, 11,
-//             12, 13, 14, 15,
-//             16, 17, 18, 19,
-//             20, 21, 22, 23},
-//            {0, 1, 2},
-//            {1, 1, 1, 4},
-//            {10, 11, 12, 13}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {1, 2, 3},
-//            {1, 1, 1, 1},
-//            {13}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {0, 2, 3},
-//            {1, 3, 1, 1},
-//            {4, 13, 22}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {0, 1, 3},
-//            {1, 1, 3, 1},
-//            {10, 13, 16}, ReduceType::MEAN);
-//  Simple<D>({1, 3, 3, 3},
-//            {0, 1, 2, 3, 4, 5, 6, 7, 8,
-//             9, 10, 11, 12, 13, 14, 15, 16, 17,
-//             18, 19, 20, 21, 22, 23, 24, 25, 26},
-//            {0, 1, 2},
-//            {1, 1, 1, 3},
-//            {12, 13, 14}, ReduceType::MEAN);
 }
 
 } // namespace
...
@@ -622,25 +330,17 @@ void RandomTest(const std::vector<index_t> &input_shape,
 
 TEST_F(ReduceOpTest, GPURandomFloat) {
   RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
-//  RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
   RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
-//  RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
   RandomTest<DeviceType::GPU, float>({1, 480, 640, 32}, {1, 2});
-//  RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
   RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
-//  RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
   RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
 }
 
 TEST_F(ReduceOpTest, GPURandomHalf) {
   RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
-//  RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
   RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
-//  RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
   RandomTest<DeviceType::GPU, half>({1, 480, 640, 32}, {1, 2});
-//  RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
   RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
-//  RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
   RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
 }
...
...
@@ -71,7 +71,8 @@ class SpliceOp<DeviceType::CPU, T> : public Operation {
     const index_t out_chunk = chunk - (right_context - left_context);
     MACE_CHECK(input_dim > const_dim_,
-               "input dim should be greater than const dim.");
+               "input dim:", input_dim,
+               "should be greater than const dim:", const_dim_);
     const index_t output_dim = dim * num_splice + const_dim_;
     const index_t output_stride = out_chunk * output_dim;
...
@@ -103,7 +104,7 @@ class SpliceOp<DeviceType::CPU, T> : public Operation {
     const index_t input_offset = dim;
     for (int b = 0; b < batch; ++b) {
       for (index_t i = 0; i < out_chunk; ++i) {
-        T *output_base = output_data + + b * output_stride + i * output_dim;
+        T *output_base = output_data + b * output_stride + i * output_dim;
         const T *input_base = input_data + b * input_stride + i * input_dim;
         memcpy(output_base + output_offset,
                input_base + input_offset,
...
...
@@ -80,18 +80,22 @@ class SumGroupOp<DeviceType::CPU, T> : public Operation {
       MACE_CHECK(cur_index <= input_dim)
         << "size value over-ranged:" << cur_index << "<=" << input_dim;
     }
-    for (index_t i = 0; i < bh; ++i) {
-      for (index_t j = 0; j < output_dim; ++j) {
-        int start_col = sum_indexes[j].first;
-        int end_col = sum_indexes[j].second;
-        T sum = 0;
-        for (int src_col = start_col; src_col < end_col; ++src_col) {
-          sum += input_data[i * input_dim + src_col];
-        }
-        output_data[i * output_dim + j] = sum;
-      }
-    }
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t i = start0; i < end0; i += step0) {
+        for (index_t j = start1; j < end1; j += step1) {
+          int start_col = sum_indexes[j].first;
+          int end_col = sum_indexes[j].second;
+          T sum = 0;
+          for (int src_col = start_col; src_col < end_col; ++src_col) {
+            sum += input_data[i * input_dim + src_col];
+          }
+          output_data[i * output_dim + j] = sum;
+        }
+      }
+    }, 0, bh, 1, 0, output_dim, 1);
     return MaceStatus::MACE_SUCCESS;
   }
...
@@ -35,7 +35,11 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
  public:
   explicit TargetRMSNormOp(OpConstructContext *context)
       : Operation(context),
-        target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)) {}
+        target_rms_(Operation::GetOptionalArg<float>("target_rms", 1.0)),
+        add_log_stddev_(
+            static_cast<bool>(
+                Operation::GetOptionalArg<int>("add_log_stddev", 0))),
+        block_dim_(Operation::GetOptionalArg<int>("block_dim", 0)) {}

   // Calculate the square sum of an array
   float SquareSum(const float *data, const index_t data_len) {
@@ -67,6 +71,24 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
     return result;
   }

+  void NormalizePerRow(const float *data,
+                       const index_t data_len,
+                       float d_scale,
+                       bool add_log_stddev,
+                       float *out_data) {
+    float scale = SquareSum(data, data_len);
+    scale = scale / d_scale;
+    scale = scale < 1.0e-6f ? 1.0e-6f : scale;
+    scale = static_cast<float>(1.0 / std::sqrt(scale));
+    for (index_t j = 0; j < data_len; ++j) {
+      out_data[j] = data[j] * scale;
+    }
+    if (add_log_stddev) {
+      out_data[data_len] = std::log(target_rms_) - std::log(scale);
+    }
+  }
+
   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
     const Tensor *input = this->Input(0);
@@ -75,15 +97,18 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
     const index_t dim_size = input->dim_size();
     MACE_CHECK(dim_size >= 1,
                "TargetRMSNorm's input dim size should be >= 1.");
-    const index_t dim = input_shape[dim_size - 1];
-    MACE_CHECK(dim > 0 && target_rms_ > 0,
+    const index_t input_dim = input_shape[dim_size - 1];
+    MACE_CHECK(input_dim > 0 && target_rms_ > 0,
                "Both input dim and target rms should be greater than zero.");
     const index_t bh =
         std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
                         std::multiplies<index_t>());
-    const float d_scale = dim * target_rms_ * target_rms_;
-
-    MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+    if (block_dim_ == 0) block_dim_ = static_cast<int>(input_dim);
+    const index_t output_dim = add_log_stddev_ ?
+        input_dim + (input_dim / block_dim_) : input_dim;
+    std::vector<index_t> output_shape = input->shape();
+    output_shape[dim_size - 1] = output_dim;
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));

     Tensor::MappingGuard guard_input(input);
     Tensor::MappingGuard guard_output(output);
@@ -91,19 +116,37 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
     const float *input_data = input->data<float>();
     float *output_data = output->mutable_data<float>();

-    for (index_t i = 0; i < bh; ++i) {
-      float scale = SquareSum(input_data + i * dim, dim);
-      scale = static_cast<float>(1.0 / std::sqrt(scale / d_scale));
-      for (index_t j = 0; j < dim; ++j) {
-        output_data[i * dim + j] = input_data[i * dim + j] * scale;
-      }
-    }
+    index_t num_rows = bh;
+    index_t output_block_dim = add_log_stddev_ ? block_dim_ + 1 : block_dim_;
+    if (block_dim_ != input_dim) {
+      index_t num_blocks = input_dim / block_dim_;
+      num_rows *= num_blocks;
+    }
+    const float d_scale = block_dim_ * target_rms_ * target_rms_;
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+    thread_pool.Compute1D([=](index_t start0, index_t end0, index_t step0) {
+      for (index_t i = start0; i < end0; i += step0) {
+        const float *input_ptr = input_data + i * block_dim_;
+        float *out_ptr = output_data + i * output_block_dim;
+        NormalizePerRow(input_ptr,
+                        block_dim_,
+                        d_scale,
+                        add_log_stddev_,
+                        out_ptr);
+      }
+    }, 0, num_rows, 1);

     return MaceStatus::MACE_SUCCESS;
   }

  private:
   float target_rms_;
+  bool add_log_stddev_;
+  int block_dim_;
 };

 void RegisterTargetRMSNorm(OpRegistryBase *op_registry) {
...
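
Net effect of the rewrite above: every block of block_dim_ values is rescaled so that its root-mean-square becomes target_rms, and with add_log_stddev each block additionally appends the log of its pre-normalization stddev. A hedged NumPy sketch of that per-block math (not the MACE API; the 1e-6 floor mirrors NormalizePerRow):

import numpy as np

def target_rms_norm_block(block, target_rms=1.0, add_log_stddev=False):
    # Equivalent to NormalizePerRow above, since
    # squaresum / (block_dim * target_rms^2) == mean(block^2) / target_rms^2.
    ms = max(np.mean(block * block) / (target_rms * target_rms), 1e-6)
    scale = 1.0 / np.sqrt(ms)
    out = block * scale
    if add_log_stddev:
        # log stddev of the *input* block equals target_rms / scale.
        out = np.append(out, np.log(target_rms) - np.log(scale))
    return out

print(target_rms_norm_block(np.array([3.0, 4.0])))
# [0.8485... 1.1313...] -- the block's RMS (~3.54) is scaled to 1.0
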
@@ -112,17 +112,20 @@ MaceSupportedOps = [
     'Conv2D',
     'Crop',
     'Deconv2D',
+    'Delay',
     'DepthToSpace',
     'DepthwiseConv2d',
     'DepthwiseDeconv2d',
     'Dequantize',
     'Eltwise',
     'ExpandDims',
+    'ExtractPooling',
     'Fill',
     'FullyConnected',
     'Gather',
     'Identity',
     'InferConv2dShape',
+    'KaldiBatchNorm',
     'LocalResponseNorm',
     'LSTMCell',
     'LstmNonlinear',
...
@@ -41,6 +41,15 @@ from numbers import Number

 IS_PYTHON3 = sys.version_info > (3,)

+
+class AttributeType(Enum):
+    INT = 100
+    FLOAT = 101
+    INTS = 102
+    FLOATS = 103
+    BOOL = 104
+
+
 OnnxSupportedOps = [
     'Abs',
     # 'Acos',
@@ -57,6 +66,7 @@ OnnxSupportedOps = [
     # 'Atanh',
     'AveragePool',
     'BatchNormalization',
+    'BatchNorm',
     'Cast',
     # 'Ceil',
     # 'Clip',
@@ -72,11 +82,12 @@ OnnxSupportedOps = [
     'DimRange',
     'Div',
     'Dropout',
-    'DynamicLstmCell',
+    'DynamicLSTM',
     'Elu',
     'Equal',
     # 'Exp',
     # 'Expand',
+    'ExtractPooling',
     # 'EyeLike',
     # 'Flatten',
     # 'Floor',
@@ -91,7 +102,7 @@ OnnxSupportedOps = [
     # 'Hardmax',
     'Identity',
     # 'If',
-    # 'IfDefined',
+    'IfDefined',
     'ImageScaler',
     # 'InstanceNormalization',
     # 'LRN',
@@ -318,6 +329,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.ArgMin.name: self.convert_argmax,
             OnnxOpType.AveragePool.name: self.convert_pooling,
             OnnxOpType.BatchNormalization.name: self.convert_fused_batchnorm,
+            OnnxOpType.BatchNorm.name: self.convert_fused_batchnorm,
             OnnxOpType.Cast.name: self.convert_cast,
             OnnxOpType.Concat.name: self.convert_concat,
             OnnxOpType.Conv.name: self.convert_conv2d,
@@ -327,16 +339,18 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.DimRange.name: self.convert_dim_range,
             OnnxOpType.Div.name: self.convert_eltwise,
             OnnxOpType.Equal.name: self.convert_eltwise,
+            OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
             OnnxOpType.Gather.name: self.convert_gather,
             OnnxOpType.Gemm.name: self.convert_gemm,
             OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
             OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
             OnnxOpType.Identity.name: self.convert_identity,
+            OnnxOpType.IfDefined.name: self.convert_ifdefined,
             OnnxOpType.ImageScaler.name: self.convert_imagescaler,
             OnnxOpType.LeakyRelu.name: self.convert_activation,
             OnnxOpType.LogSoftmax.name: self.convert_softmax,
             OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
-            OnnxOpType.DynamicLstmCell.name: self.convert_dynamic_lstm,
+            OnnxOpType.DynamicLSTM.name: self.convert_dynamic_lstm,
             OnnxOpType.Max.name: self.convert_eltwise,
             OnnxOpType.MaxPool.name: self.convert_pooling,
             OnnxOpType.MatMul.name: self.convert_matmul,
@@ -378,6 +392,8 @@ class OnnxConverter(base_converter.ConverterInterface):
         ir_version = onnx_model.ir_version
         opset_imp = onnx_model.opset_import

+        self._isKaldi = False
+
         polish_available = True
         print("onnx model IR version: ", ir_version)
         for imp in opset_imp:
@@ -387,6 +403,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             if 'kaldi2onnx' in domain:
                 polish_available = False
                 self._data_format = DataFormat.DF_NONE
+                self._isKaldi = True
         if polish_available:
             onnx_model = onnx.utils.polish_model(onnx_model)
@@ -643,10 +660,10 @@ class OnnxConverter(base_converter.ConverterInterface):
             mace_check(axis_value == 1 or axis_value == -3,
                        "only support concat at channel dimension")
         elif node.op_type == OnnxOpType.Append.name:
-            axis_value = 1
+            axis_value = -1
         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
-        axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value
+        axis_arg.i = axis_value

     def convert_conv2d(self, node):
         op = self.convert_general_op(node)
@@ -772,15 +789,15 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.DynamicLSTM.name

-        if 'delay_a' in node.attrs:
-            prev_out_delay = node.attrs['delay_a']
+        if 'prev_out_delay' in node.attrs:
+            prev_out_delay = node.attrs['prev_out_delay']
             mace_check(prev_out_delay < 0,
                        "dynamic's prev_out_delay should <= 0.")
             prev_out_delay_arg = op.arg.add()
             prev_out_delay_arg.name = 'prev_out_delay'
             prev_out_delay_arg.i = prev_out_delay
-        if 'delay_b' in node.attrs:
-            prev_cell_delay = node.attrs['delay_b']
+        if 'prev_cell_delay' in node.attrs:
+            prev_cell_delay = node.attrs['prev_cell_delay']
             mace_check(prev_cell_delay < 0,
                        "dynamic's prev_cell_delay should < 0.")
             prev_cell_delay_arg = op.arg.add()
@@ -788,20 +805,20 @@ class OnnxConverter(base_converter.ConverterInterface):
             prev_cell_delay_arg.i = prev_cell_delay
         if 'prev_out_offset' in node.attrs:
             prev_out_offset = node.attrs['prev_out_offset']
-            mace_check(pre_out_offset >= 0,
+            mace_check(prev_out_offset >= 0,
                        "dynamic's prev_out_offset should >= 0.")
             prev_out_offset_arg = op.arg.add()
             prev_out_offset_arg.name = 'prev_out_offset'
             prev_out_offset_arg.i = prev_out_offset
-        if 'prev_a_dim' in node.attrs:
-            prev_out_dim = node.attrs['prev_a_dim']
+        if 'prev_out_dim' in node.attrs:
+            prev_out_dim = node.attrs['prev_out_dim']
             mace_check(prev_out_dim > 0,
                        "dynamic's prev_out_dim should > 0.")
             prev_out_dim_arg = op.arg.add()
             prev_out_dim_arg.name = 'prev_out_dim'
             prev_out_dim_arg.i = prev_out_dim
-        if 'prev_b_dim' in node.attrs:
-            prev_cell_dim = node.attrs['prev_b_dim']
+        if 'prev_cell_dim' in node.attrs:
+            prev_cell_dim = node.attrs['prev_cell_dim']
             mace_check(prev_cell_dim > 0,
                        "dynamic's prev_cell_dim should > 0.")
             prev_cell_dim_arg = op.arg.add()
@@ -844,11 +861,154 @@ class OnnxConverter(base_converter.ConverterInterface):
         value_arg.name = MaceKeyword.mace_scalar_input_str
         value_arg.f = value

+    @staticmethod
+    def copy_node_attr(op, node, attr_name, dtype=AttributeType.INT,
+                       default=None):
+        if attr_name in node.attrs or default is not None:
+            if attr_name in node.attrs:
+                value = node.attrs[attr_name]
+            else:
+                value = default
+            new_arg = op.arg.add()
+            new_arg.name = attr_name
+            if dtype == AttributeType.INT:
+                new_arg.i = int(value)
+            elif dtype == AttributeType.FLOAT:
+                new_arg.f = float(value)
+            elif dtype == AttributeType.INTS:
+                new_arg.ints.extend(value)
+            elif dtype == AttributeType.FLOATS:
+                new_arg.floats.extend(value)
+            return value
+        else:
+            return default
+
+    def convert_extract_pooling(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.ExtractPooling.name
+
+        self.copy_node_attr(op, node, 'include_variance', AttributeType.INT)
+        self.copy_node_attr(op, node, 'num_log_count', AttributeType.INT)
+        self.copy_node_attr(op, node, 'variance_floor', AttributeType.FLOAT)
+        self.copy_node_attr(op, node, 'input_time_range', AttributeType.INTS)
+        self.copy_node_attr(op, node, 'input_indexes', AttributeType.INTS)
+
+        if 'output_time_range' in node.attrs:
+            output_time_range = node.attrs['output_time_range']
+            mace_check(len(output_time_range) == 2,
+                       "output time range should have two values.")
+            out_start_index = output_time_range[0]
+            out_end_index = output_time_range[1]
+        else:
+            mace_check('start_index' in node.attrs and
+                       'end_index' in node.attrs,
+                       "'start_index' and 'end_index'"
+                       " are required in ExtractPooling.")
+            out_start_index = node.attrs['start_index']
+            out_end_index = node.attrs['end_index']
+            output_time_range = [out_start_index, out_end_index]
+        output_time_range_arg = op.arg.add()
+        output_time_range_arg.name = 'output_time_range'
+        output_time_range_arg.ints.extend(output_time_range)
+
+        mace_check('modulus' in node.attrs,
+                   "'modulus' is required in ExtractPooling.")
+        mace_check('output_indexes' in node.attrs,
+                   "'output_indexes' is required in ExtractPooling.")
+        mace_check('counts' in node.attrs,
+                   "'counts' is required in ExtractPooling.")
+        mace_check('forward_indexes' in node.attrs,
+                   "'forward_indexes' is required in ExtractPooling.")
+        modulus = node.attrs['modulus']
+        output_indexes = node.attrs['output_indexes']
+        counts = node.attrs['counts']
+        forward_indexes = node.attrs['forward_indexes']
+        mace_check(len(counts) == len(output_indexes) and
+                   len(forward_indexes) == 2 * len(output_indexes),
+                   "output_indexes length:%s "
+                   "counts length:%s "
+                   "forward_indexes length:%s"
+                   % (len(output_indexes), len(counts),
+                      len(forward_indexes)))
+        new_output_indexes = []
+        new_forward_indexes = []
+        new_counts = []
+        for i in range(len(output_indexes)):
+            if output_indexes[i] + modulus > out_start_index and\
+                    output_indexes[i] <= out_end_index:
+                new_output_indexes.append(output_indexes[i])
+                new_counts.append(counts[i])
+                new_forward_indexes.append(forward_indexes[2 * i])
+                new_forward_indexes.append(forward_indexes[2 * i + 1])
+        modulus_arg = op.arg.add()
+        modulus_arg.name = 'modulus'
+        modulus_arg.i = modulus
+        counts_arg = op.arg.add()
+        counts_arg.name = 'counts'
+        counts_arg.floats.extend(new_counts)
+        forward_indexes_arg = op.arg.add()
+        forward_indexes_arg.name = 'forward_indexes'
+        forward_indexes_arg.ints.extend(new_forward_indexes)
+        output_indexes_arg = op.arg.add()
+        output_indexes_arg.name = 'output_indexes'
+        output_indexes_arg.ints.extend(new_output_indexes)
+
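
The filtering loop above trims the precomputed pooling schedule to the requested output window: an entry appears to be kept iff its validity span [index, index + modulus) overlaps (out_start_index, out_end_index]. A small worked example with hypothetical attribute values:

modulus = 9
output_indexes = [0, 9, 18, 27]
out_start_index, out_end_index = 10, 25
kept = [t for t in output_indexes
        if t + modulus > out_start_index and t <= out_end_index]
print(kept)  # [9, 18]: index 0 ends too early, index 27 starts too late
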
     def convert_flatten(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Reshape.name

+    def convert_kaldi_batchnorm(self, node):
+        op = self.convert_general_op(node)
+        op.type = MaceOp.KaldiBatchNorm.name
+        dim = self.copy_node_attr(op, node,
+                                  'dim', AttributeType.INT, -1)
+        block_dim = self.copy_node_attr(op, node,
+                                        'block_dim',
+                                        AttributeType.INT, -1)
+        epsilon = self.copy_node_attr(op, node,
+                                      'epsilon',
+                                      AttributeType.FLOAT, 1e-3)
+        target_rms = self.copy_node_attr(op, node,
+                                         'target_rms',
+                                         AttributeType.FLOAT, 1.0)
+        test_mode = self.copy_node_attr(op, node,
+                                        'test_mode',
+                                        AttributeType.INT, 0)
+        mace_check(block_dim > 0 and
+                   dim % block_dim == 0 and
+                   epsilon > 0 and
+                   target_rms > 0, "attributes invalid.")
+        if test_mode > 0:
+            mace_check(len(node.inputs) == 3,
+                       "Kaldi's BatchNorm should have 3 inputs.")
+            stats_mean = np.array(self._consts[node.inputs[1]].float_data)
+            stats_var = np.array(self._consts[node.inputs[2]].float_data)
+            offset_value = -1.0 * stats_mean
+            scale_value = stats_var
+            scale_value[scale_value < 0] = 0
+            scale_value = np.power(scale_value + epsilon, -0.5) * target_rms
+            offset_value = offset_value * scale_value
+            scale_name = node.name + '_scale'
+            offset_name = node.name + '_offset'
+            self.add_tensor(scale_name, scale_value.shape,
+                            mace_pb2.DT_FLOAT, scale_value)
+            self.add_tensor(offset_name, offset_value.shape,
+                            mace_pb2.DT_FLOAT, offset_value)
+            del op.input[1:]
+            op.input.extend([scale_name, offset_name])
+        del op.output[1:]
+        del op.output_shape[1:]
+
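
In test mode, Kaldi's batchnorm collapses into a plain affine transform, which is what the scale/offset folding above computes. A NumPy sketch with illustrative statistics (mirroring the formulas in the code, not new behaviour):

import numpy as np

stats_mean = np.array([0.5, -1.0])
stats_var = np.array([4.0, 1.0])
epsilon, target_rms = 1e-3, 1.0

scale = np.power(np.maximum(stats_var, 0.0) + epsilon, -0.5) * target_rms
offset = -stats_mean * scale
# Inference then reduces to y = x * scale + offset,
# i.e. target_rms * (x - mean) / sqrt(var + eps).
x = np.array([2.5, 0.0])
print(x * scale + offset)  # ~[1.0, 1.0]
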
     def convert_fused_batchnorm(self, node):
+        if self._isKaldi:
+            self.convert_kaldi_batchnorm(node)
+            return
         op = self.convert_general_op(node)
         op.type = MaceOp.BatchNorm.name
@@ -946,6 +1106,21 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.Identity.name

+    def convert_ifdefined(self, node):
+        op = self.convert_general_op(node)
+        if 'offset' in node.attrs:
+            offset = node.attrs['offset']
+        else:
+            offset = 0
+        mace_check(offset <= 0, "IfDefined's offset should be <= 0.")
+        if offset == 0:
+            op.type = MaceOp.Identity.name
+        else:
+            op.type = MaceOp.Delay.name
+            offset_arg = op.arg.add()
+            offset_arg.name = 'offset'
+            offset_arg.i = node.attrs['offset']
+
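
So an IfDefined node becomes an Identity when its offset is 0, and otherwise a Delay op carrying the negative offset. As a rough illustration of a negative-offset delay on time-major frames (a sketch of the intended behaviour with made-up data, not the MACE kernel):

import numpy as np

def delay(x, offset):
    # x: (chunk, dim); offset < 0 shifts frames later in time and
    # zero-fills the first -offset output frames.
    out = np.zeros_like(x)
    if x.shape[0] > -offset:
        out[-offset:] = x[:offset]
    return out

x = np.arange(8.0).reshape(4, 2)
print(delay(x, -1))  # frame 0 is zeros; frames 0..2 move to slots 1..3
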
     def convert_imagescaler(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.BatchNorm.name
@@ -1100,7 +1275,6 @@ class OnnxConverter(base_converter.ConverterInterface):
     def convert_softmax(self, node):
         op = self.convert_general_op(node)
         op.type = MaceOp.Softmax.name
-        # TODO: add logsoftmax in softmax op
         if node.op_type == OnnxOpType.LogSoftmax.name:
             use_log_arg = op.arg.add()
             use_log_arg.name = 'use_log'
@@ -1164,11 +1338,11 @@ class OnnxConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(node)
         op.type = MaceOp.TargetRMSNorm.name

-        if 'target_rms' in node.attrs:
-            value = node.attrs['target_rms']
-            target_rms_arg = op.arg.add()
-            target_rms_arg.name = 'target_rms'
-            target_rms_arg.f = value
+        self.copy_node_attr(op, node, 'target_rms', AttributeType.FLOAT)
+        self.copy_node_attr(op, node, 'add_log_stddev', AttributeType.INT,
+                            default=0)
+        self.copy_node_attr(op, node, 'block_dim', AttributeType.INT,
+                            default=0)

     def convert_transpose(self, node):
         op = self.convert_general_op(node)
...
@@ -1130,7 +1130,7 @@ class Transformer(base_converter.ConverterInterface):
                 rhs = op.input[1]
                 if rhs in self._consts and len(self._consts[rhs].dims) == 2:
                     arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str)  # noqa
-                    six.print_("Transpose matmul weight %s" % rhs)
+                    # six.print_("Transpose matmul weight %s" % rhs)
                     if arg is None:
                         arg = op.arg.add()
                         arg.name = MaceKeyword.mace_transpose_b_str
@@ -1143,7 +1143,8 @@ class Transformer(base_converter.ConverterInterface):
                         filter.float_data[:] = filter_data.flat
                         filter.dims[:] = filter_data.shape
                         arg.i = 1
-                        six.print_('transpose matmul weight')
+                        six.print_('Transpose matmul weight to shape:',
+                                   filter.dims)

     def transpose_filters(self):
         net = self._model
...
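
For context, the transformation above stores a constant 2-D matmul weight pre-transposed and sets the transpose_b flag, presumably so the GEMM reads the weight with friendlier memory access. The equivalence, as a minimal NumPy sketch (illustrative only):

import numpy as np

a = np.random.rand(3, 4).astype(np.float32)
b = np.random.rand(4, 5).astype(np.float32)

b_t = np.ascontiguousarray(b.T)       # weight stored transposed
assert np.allclose(a @ b, a @ b_t.T)  # matmul with transpose_b == 1
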